If you want the gradient of a scalar-valued function that depends on the gradient of another scalar-valued function, you can use forward-over-reverse differentiation: ForwardDiff for the outer (forward-mode) derivative, combined with a reverse-mode package such as Zygote, Enzyme, or ReverseDiff for the inner gradient. See:
(You can also use this approach for general Hessians, but it was less obvious to me that it is efficient for scalar-valued functions.)
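
To make the idea concrete, here is a minimal sketch, assuming Zygote as the reverse-mode package. The functions `f`, `g`, `h` and the helper names `∇f` and `∇h` are made up for illustration. For h(x) = g(∇f(x)), the chain rule gives ∇h(x) = H_f(x) v with v = ∇g evaluated at ∇f(x), and that Hessian-vector product is just the directional derivative of ∇f along v, which ForwardDiff can take with a single scalar dual number:

```julia
using ForwardDiff, Zygote

# Hypothetical example functions (chosen so the answer can be checked by hand).
f(x) = sum(abs2, x)^2 / 4      # inner scalar function; ∇f(x) = (x⋅x) x
g(v) = sum(abs2, v) / 2        # outer scalar function, applied to ∇f(x)

# Inner gradient by reverse mode.
∇f(x) = Zygote.gradient(f, x)[1]

# The function whose gradient we actually want: h(x) = g(∇f(x)).
h(x) = g(∇f(x))

# Forward-over-reverse: ∇h(x) = H_f(x) * v, where v = ∇g(∇f(x)),
# computed as d/dα ∇f(x + α v) at α = 0 without ever forming the Hessian.
function ∇h(x)
    v = Zygote.gradient(g, ∇f(x))[1]
    return ForwardDiff.derivative(α -> ∇f(x + α * v), 0.0)
end

x = [1.0, 2.0]
# For these particular f and g, h(x) = (x⋅x)^3 / 2, so ∇h(x) = 3 (x⋅x)^2 x.
∇h(x) ≈ 3 * sum(abs2, x)^2 * x
```

The cost should be roughly that of a couple of reverse-mode gradient evaluations of `f`, independent of the length of `x`, since only one forward-mode direction is propagated. Whether the reverse-mode package accepts dual-number inputs for a given `f` is something to check case by case.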