Nested AD with Lux etc

I think I see what you are saying. More explicitly, suppose we want the gradient of $h(x) = g(\left. \nabla f \right|_x)$, where $f: \mathbb{R}^n \to \mathbb{R}$ and $g: \mathbb{R}^n \to \mathbb{R}$ are both scalar-valued, so that $h$ depends on $x \in \mathbb{R}^n$ only through the gradient $\nabla f$. Then:

  1. For the calculation of $h(x)$, we use reverse mode to compute $\nabla f$ and then plug it into $g$.
  2. For the calculation of $\nabla h$, the chain rule corresponds to first linearizing $g(\left. \nabla f \right|_{y}) \approx g(\left. \nabla f \right|_{x}) + (\left. \nabla g \right|_{\left. \nabla f \right|_{x}})^T \left[ \left. \nabla f \right|_{y} - \left. \nabla f \right|_{x} \right]$ and then taking the gradient of the second term with respect to $y$, evaluated at $y = x$. The only $y$-dependent part of that term, $(\nabla g)^T \left. \nabla f \right|_y$ with $\nabla g$ held fixed at $\left. \nabla g \right|_{\left. \nabla f \right|_{x}}$, is a single directional derivative of $f$. So we can compute it with forward mode using a single dual number (at a cost comparable to evaluating $f$), as the scalar derivative $\left. \frac{d}{d\alpha} f(y + \alpha \nabla g) \right|_{\alpha=0}$, and then apply reverse-over-forward to find the gradient $\nabla h$. Equivalently, you can interchange the derivatives to do forward-over-reverse: $\nabla h = \left. \frac{d}{d\alpha} \left( \left. \nabla f \right|_{x + \alpha \nabla g} \right) \right|_{\alpha=0}$.
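Both orderings compute the same quantity, a Hessian-vector product of $f$, so the full Hessian is never formed. Writing $v$ (a shorthand introduced here) for the fixed vector $\left. \nabla g \right|_{\left. \nabla f \right|_{x}}$, the chain of equalities is:

```latex
\begin{aligned}
\nabla h(x)
  &= \left. \nabla_y \left[ v^T \left. \nabla f \right|_{y} \right] \right|_{y=x}
   = \left. \nabla_y \left[ \left. \frac{d}{d\alpha} f(y + \alpha v) \right|_{\alpha=0} \right] \right|_{y=x}
  && \text{(reverse over forward)} \\
  &= \left( \left. \nabla^2 f \right|_{x} \right) v
  && \text{(Hessian-vector product; } \nabla^2 f \text{ is symmetric)} \\
  &= \left. \frac{d}{d\alpha} \left( \left. \nabla f \right|_{x + \alpha v} \right) \right|_{\alpha=0}
  && \text{(forward over reverse)}
\end{aligned}
```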

Here is an explicit example.

using ForwardDiff, Zygote, LinearAlgebra
f(x) = 1/norm(x)    # example ℝⁿ → ℝ function
g(∇f) = sum(∇f)^3   # example ℝⁿ → ℝ function
h(x) = g(Zygote.gradient(f, x)[1])
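function ∇h(x)
    ∇f(y) = Zygote.gradient(f, y)[1]       # inner reverse pass: ∇f at y
    ∇g = Zygote.gradient(g, ∇f(x))[1]      # ∇g evaluated at ∇f(x), then held fixed
    # outer forward pass: d/dα ∇f(x + α ∇g) at α = 0, i.e. a Hessian-vector product
    return ForwardDiff.derivative(α -> ∇f(x + α*∇g), 0)
end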

gives

julia> x = randn(5); δx = randn(5) * 1e-8;

julia> h(x)
-0.005284687528953334

julia> ∇h(x)
5-element Vector{Float64}:
 -0.006779692698531759
  0.007176439898271982
 -0.006610264199241697
 -0.0012162087082746558
  0.007663756720005014

julia> ∇h(x)'δx  # directional derivative
-3.0273434457397667e-10

julia> h(x+δx) - h(x)  # finite-difference check
-3.0273433933303284e-10
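For this particular $f$ and $g$, the result can also be cross-checked without either AD package. Since $\left. \nabla f \right|_x = -x/\|x\|^3$, the Hessian is $\left. \nabla^2 f \right|_x = (3 x x^T / \|x\|^2 - I)/\|x\|^3$, and the gradient of $g$ at $v$ is $3 (\sum_i v_i)^2$ times the all-ones vector, so $\nabla h$ is the Hessian-vector product of these two. A minimal sketch using only the LinearAlgebra standard library; the `_exact` and `_fd` helpers and the test point `x0` are hypothetical names introduced here, not part of the code above:

```julia
using LinearAlgebra

# Hypothetical closed-form helpers for f(x) = 1/norm(x) and g(v) = sum(v)^3
∇f_exact(x) = -x / norm(x)^3                            # ∇(1/‖x‖) = -x/‖x‖³
Hf_exact(x) = (3 * x * x' / norm(x)^2 - I) / norm(x)^3  # Hessian of 1/‖x‖ (symmetric)
∇g_exact(v) = fill(3 * sum(v)^2, length(v))             # ∇(Σv)³ = 3(Σv)² times all-ones

h_exact(x)  = sum(∇f_exact(x))^3                        # same h as in the post
∇h_exact(x) = Hf_exact(x) * ∇g_exact(∇f_exact(x))       # Hessian-vector product

# Central finite differences of h_exact, one coordinate at a time
function ∇h_fd(x; ε = 1e-6)
    map(eachindex(x)) do i
        e = zeros(length(x))
        e[i] = ε
        (h_exact(x + e) - h_exact(x - e)) / (2ε)
    end
end

x0 = [0.3, -1.2, 0.7, 2.0, -0.5]
println(∇h_exact(x0))
println(isapprox(∇h_exact(x0), ∇h_fd(x0); rtol = 1e-6))
```

In a session where the post's definitions are also loaded, `∇h(x0) ≈ ∇h_exact(x0)` should hold up to rounding for any nonzero `x0`.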

Note that I used the forward-over-reverse formulation above because Zygote can't currently differentiate through ForwardDiff.derivative, whereas the converse (ForwardDiff differentiating through Zygote.gradient) works fine.
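To illustrate why ForwardDiff over Zygote composes, note that the outer forward pass only requires the inner gradient code to accept dual numbers: it evaluates $\nabla f$ at $x + \alpha \nabla g$ with a dual-valued $\alpha$ and reads off the derivative part. The sketch below shows only that outer layer, under two stated assumptions: `Dual` is a minimal hand-rolled type (hypothetical, far less complete than ForwardDiff's), and the inner reverse pass is replaced by the hand-coded gradient $-y/\|y\|^3$ of $f(y) = 1/\|y\|$, written generically so it accepts both `Float64` and `Dual` vectors. The helper names `∇f_generic` and `hvp_dual` are likewise hypothetical.

```julia
# Hypothetical minimal forward-mode dual number a + b·ε with ε² = 0
# (illustrative only; ForwardDiff's real Dual type is far more complete)
struct Dual <: Number
    val::Float64   # primal value
    der::Float64   # derivative part (coefficient of ε)
end

Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:+(a::Dual, b::Real) = Dual(a.val + b, a.der)
Base.:+(a::Real, b::Dual) = b + a
Base.:-(a::Dual) = Dual(-a.val, -a.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.val * b.der + a.der * b.val)
Base.:*(a::Dual, b::Real) = Dual(a.val * b, a.der * b)
Base.:*(a::Real, b::Dual) = b * a
Base.:/(a::Dual, b::Dual) = Dual(a.val / b.val, (a.der * b.val - a.val * b.der) / b.val^2)
Base.sqrt(a::Dual) = (s = sqrt(a.val); Dual(s, a.der / (2s)))

# Hypothetical hand-coded gradient of f(y) = 1/‖y‖, standing in for Zygote's
# reverse pass. Written without `norm` so it accepts Float64 or Dual vectors.
function ∇f_generic(y)
    r = sqrt(sum(yi * yi for yi in y))
    r3 = r * r * r
    return map(yi -> -yi / r3, y)
end

# Outer forward pass: d/dα ∇f(x + α v) at α = 0, using a single Dual seed.
# The result is the Hessian-vector product (∇²f at x) * v.
function hvp_dual(x, v)
    α = Dual(0.0, 1.0)                        # α = 0 with dα/dα = 1
    return map(d -> d.der, ∇f_generic(x .+ α .* v))
end

println(hvp_dual([0.3, -1.2, 0.7, 2.0, -0.5], [1.0, 2.0, -1.0, 0.5, 0.25]))
```

With the post's packages loaded, `hvp_dual(x, v)` should agree with `ForwardDiff.derivative(α -> Zygote.gradient(f, x + α*v)[1], 0)` up to rounding for this particular `f`.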
