For future reference, here is a manual implementation of a forward-over-reverse calculation of parameter gradients (mixed second derivatives). In particular, if you have a scalar-valued h(x,p) = g(\nabla_x f) for some scalar-valued f(x,p) and some \mathbb{R}^n \to \mathbb{R} function g, then one can similarly derive:
\left. \nabla_p h \right|_{x,p} =
\left. \frac{\partial}{\partial\alpha} \left. \nabla_p f \right|_{x + \alpha \left. \nabla g \right|_{z},p} \right|_{\alpha = 0} \, ,
where z = \left. \nabla_x f \right|_{x,p}.
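For completeness, here is a short sketch of why this identity holds (my own restatement, using only the chain rule and symmetry of mixed partials). Writing z = \left. \nabla_x f \right|_{x,p} and applying the chain rule componentwise:
(\nabla_p h)_i = \sum_j \frac{\partial z_j}{\partial p_i} \left( \left. \nabla g \right|_z \right)_j
= \sum_j \frac{\partial^2 f}{\partial p_i \, \partial x_j} \left( \nabla g \right)_j
= \left. \frac{\partial}{\partial \alpha} \frac{\partial f}{\partial p_i}\left(x + \alpha \nabla g, \, p\right) \right|_{\alpha = 0} \, ,
where the last step uses equality of mixed partials to recognize the sum as a directional derivative of \nabla_p f in the \nabla g direction — exactly the thing that a single forward-mode pass computes cheaply.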
Here is an example calculation via ForwardDiff over Zygote, along with a finite-difference check:
julia> using ForwardDiff, Zygote, LinearAlgebra
julia> f(x,p) = sum(p)^2/norm(x) + p[1]*x[2]; # example function
julia> g(∇ₓf) = sum(∇ₓf)^3; # example ℝⁿ → ℝ function
julia> h(x,p) = g(Zygote.gradient(x -> f(x,p), x)[1]); # evaluate h by reverse mode
julia> function ∇ₚh(x,p)
           ∇ₚf(y,q) = Zygote.gradient(u -> f(y,u), q)[1]  # reverse-mode parameter gradient ∇ₚf
           ∇g = Zygote.gradient(g, Zygote.gradient(x -> f(x,p), x)[1])[1]  # ∇g at z = ∇ₓf
           return ForwardDiff.derivative(α -> ∇ₚf(x + α*∇g, p), 0)  # forward over reverse
       end;
julia> x = randn(5); p = randn(4); δp = randn(4) * 1e-8;
julia> h(x,p)
-6.538498714556666e-5
julia> ∇ₚh(x,p)
4-element Vector{Float64}:
0.0025659422776640596
-0.0023030596877005173
-0.0023030596877005173
-0.0023030596877005173
julia> h(x,p+δp) - h(x,p) # finite-difference check
-5.696379205464251e-11
julia> ∇ₚh(x,p)'δp # exact directional derivative
-5.6963775418992394e-11
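The same pattern can be packaged as a small reusable helper that takes any scalar-valued f(x,p) and g. This is just a sketch repackaging the calls above — the function name `grad_p_composed` is my own, not an API from either package:

```julia
using ForwardDiff, Zygote

# Hypothetical helper (name is mine): computes ∇ₚ g(∇ₓf) by forward-over-reverse,
# for scalar-valued f(x,p) and g: ℝⁿ → ℝ.
function grad_p_composed(f, g, x, p)
    z = Zygote.gradient(x -> f(x, p), x)[1]           # z = ∇ₓf |_{x,p} (reverse mode)
    ∇g = Zygote.gradient(g, z)[1]                      # ∇g |_z (reverse mode)
    ∇ₚf(y, q) = Zygote.gradient(u -> f(y, u), q)[1]    # reverse-mode parameter gradient
    # forward-mode directional derivative of ∇ₚf in the ∇g direction:
    return ForwardDiff.derivative(α -> ∇ₚf(x + α * ∇g, p), 0)
end
```

With the f and g defined above, `grad_p_composed(f, g, x, p)` should reproduce `∇ₚh(x,p)`.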
PS. This also seems like a good example of the clarity benefits of Unicode variable names for mapping math to code.