Nested AD with Lux etc

For future reference, here is a manual implementation of a forward-over-reverse calculation of parameter gradients (mixed second derivatives). In particular, if you have a scalar-valued h(x,p) = g(\nabla_x f) for some scalar-valued f(x,p), then one can similarly derive:

\left. \nabla_p h \right|_{x,p} = \left. \frac{\partial}{\partial\alpha} \left. \nabla_p f \right|_{x + \alpha \left. \nabla g \right|_{z},p} \right|_{\alpha = 0} \, ,

where z = \left. \nabla_x f \right|_{x,p}.
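For completeness, a sketch of where this identity comes from (just the chain rule plus equality of mixed partials). Applying the chain rule to h(x,p) = g(\nabla_x f(x,p)) gives

\nabla_p h = \left( \frac{\partial^2 f}{\partial x \, \partial p} \right)^T \left. \nabla g \right|_{z} = \frac{\partial^2 f}{\partial p \, \partial x} \left. \nabla g \right|_{z} \, ,

i.e. the mixed Hessian applied to \nabla g. But multiplying the mixed Hessian by a fixed vector \nabla g is exactly the directional derivative of \nabla_p f with respect to x in the direction \nabla g, which is the \partial/\partial\alpha expression above: a single forward-mode pass over the reverse-mode gradient \nabla_p f, with no Hessian ever formed explicitly.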

Here is an example calculation via ForwardDiff over Zygote, along with a finite-difference check:

julia> using ForwardDiff, Zygote, LinearAlgebra

julia> f(x,p) = sum(p)^2/norm(x) + p[1]*x[2];  # example function

julia> g(∇ₓf) = sum(∇ₓf)^3;   # example ℝⁿ → ℝ function

julia> h(x,p) = g(Zygote.gradient(x -> f(x,p), x)[1]);  # evaluate h by reverse mode

julia> function ∇ₚh(x,p)
           ∇ₚf(y,q) = Zygote.gradient(u -> f(y,u), q)[1]
           ∇g = Zygote.gradient(g, Zygote.gradient(x -> f(x,p), x)[1])[1]
           return ForwardDiff.derivative(α -> ∇ₚf(x + α*∇g, p), 0)
       end;

julia> x = randn(5); p = randn(4); δp = randn(4) * 1e-8;

julia> h(x,p)
-6.538498714556666e-5

julia> ∇ₚh(x,p)
4-element Vector{Float64}:
  0.0025659422776640596
 -0.0023030596877005173
 -0.0023030596877005173
 -0.0023030596877005173

julia> h(x,p+δp) - h(x,p)   # finite-difference check
-5.696379205464251e-11

julia> ∇ₚh(x,p)'δp           # exact directional derivative
-5.6963775418992394e-11
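For tiny problems like this one, the identity can also be sanity-checked by materializing the mixed Hessian ∂²f/∂p∂x explicitly with nested ForwardDiff and multiplying by ∇g. A sketch using the same f and g as above (this builds a dense length(p) × length(x) matrix, so it is only for validation, not something you'd do at scale):

```julia
using ForwardDiff, LinearAlgebra

f(x, p) = sum(p)^2 / norm(x) + p[1] * x[2]   # same example function as above
g(∇ₓf) = sum(∇ₓf)^3                           # same example ℝⁿ → ℝ function

# mixed Hessian ∂²f/∂p∂x, of size length(p) × length(x), via nested forward mode
mixed_hessian(x, p) =
    ForwardDiff.jacobian(y -> ForwardDiff.gradient(q -> f(y, q), p), x)

x = randn(5); p = randn(4)
∇g = ForwardDiff.gradient(g, ForwardDiff.gradient(y -> f(y, p), x))

# ∇ₚh as a dense matrix–vector product; should match ∇ₚh(x, p) above
mixed_hessian(x, p) * ∇g
```

The matrix–vector product agrees with the forward-over-reverse ∇ₚh(x,p) to machine precision, since both compute the same mixed second derivative, just with different amounts of work.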

PS. This also seems like a good example of the clarity benefits of Unicode variable names for mapping math to code.
