Nested AD with Lux etc

Let me clarify. Say we are defining the following loss function:

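using Zygote

function loss_function(m, x, ps, st)
    # A Lux layer returns (output, new_state); reduce the output to a scalar for Zygote.gradient
    f = x -> sum(first(m(x, ps, st)))
    return sum(abs2, only(Zygote.gradient(f, x)))
end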

If we call Zygote.gradient(loss_function, ....) directly (reverse mode over reverse mode), it will be very slow in the best case and error in the vast majority of cases, so we want to use a JVP (forward mode) over Zygote.gradient to compute the gradients with respect to x and ps.

function CRC.rrule(cfg::..., ::typeof(Zygote.gradient), f::F, x) where {F}
     # Here `f` is potentially a closure
     # Q1: How do we extract the values like `m`, `ps` and `st`  in `f`?
     # Let's proceed assuming we did it
     y = Zygote.gradient(f, x)
     function gradient_pullback(Delta)
           # We use our jvp trick to compute dm, dps, dst, dx
           return NoTangent(), NoTangent(), Tangent for `f`, dx
           # Q2: What is the appropriate tangent for `f`
     end
     return y, gradient_pullback
end

Since I did not know how to solve this, what Lux does (and what I am proposing for DI to do) is to wrap the function and its parameters in a ParamsStruct:

function Zygote.gradient(f::ParamsStruct, x)
     return __my_custom_gradient_func(f.func, x, f.ps)  # In the forward pass we just call Zygote.gradient
end

# Defining the rrule is now as simple as:
function CRC.rrule(cfg::..., ::typeof(__my_custom_gradient_func), f::F, x, ps) where {F}
     # Use the JVP trick here
end
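For concreteness, here is one way the "JVP trick" inside that rrule could be filled in. This is a minimal sketch assuming ChainRulesCore, ForwardDiff and Zygote, not Lux's actual implementation: `ParamsStruct` is redefined here as a hypothetical two-field struct (`func`, `ps`) where `func(x, ps)` returns a scalar, and `inner_grad` is a hypothetical stand-in for `__my_custom_gradient_func`. The forward pass is plain `Zygote.gradient`; the pullback perturbs `x` along the incoming cotangent `Δ` with a dual number and differentiates the inner gradient, which gives the Hessian-vector product for `x` and the mixed second derivative for `ps`.

```julia
using ChainRulesCore, ForwardDiff, Zygote, LinearAlgebra

# Hypothetical wrapper: `func(x, ps)` must return a scalar.
struct ParamsStruct{F,P}
    func::F
    ps::P
end

# Hypothetical stand-in for `__my_custom_gradient_func`:
# the forward pass is ordinary reverse mode with respect to x.
inner_grad(s::ParamsStruct, x) = only(Zygote.gradient(z -> s.func(z, s.ps), x))

function ChainRulesCore.rrule(::typeof(inner_grad), s::ParamsStruct, x)
    y = inner_grad(s, x)
    function inner_grad_pullback(Δ)
        Δ = unthunk(Δ)
        # dx = ∂/∂α ∇ₓf(x + αΔ, ps) at α = 0, a Hessian-vector product
        dx = ForwardDiff.derivative(α -> inner_grad(s, x + α * Δ), 0.0)
        # dps = ∂/∂α ∇ₚf(x + αΔ, ps) at α = 0, the mixed second derivative
        dps = ForwardDiff.derivative(0.0) do α
            xα = x + α * Δ
            only(Zygote.gradient(q -> s.func(xα, q), s.ps))
        end
        return NoTangent(), Tangent{typeof(s)}(; func = NoTangent(), ps = dps), dx
    end
    return y, inner_grad_pullback
end

# Usage with a toy scalar function (an assumption, not a Lux model)
func(x, ps) = sum(ps)^2 / norm(x) + ps[1] * x[2]
s = ParamsStruct(func, [0.5, -0.4, 1.1])
x = [0.3, -1.2, 0.8]

loss(s, x) = sum(abs2, inner_grad(s, x))
gs, gx = Zygote.gradient(loss, s, x)
```

Because the pullback returns a `Tangent{typeof(s)}`, Zygote should hand back the parameter gradient as `gs.ps`. Only the outer reverse pass is replaced by forward mode; the inner gradient is still computed by Zygote.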

It seems like a very specific hack, and I wouldn’t want DI to become a collection of such hacks. Can we maybe take a step back and figure out the general principle behind it? Switching the mode for one of two input arguments?

It is not really a hack, nor is it switching the backend entirely. The claim is that reverse over reverse comes with limited benefits, and none of the current AD tools in Julia can really handle it (except maybe Enzyme). Dual numbers, however, work with pretty much any of the AD tools that we have.

Let’s say you have an outer_func which takes N inputs. Inside outer_func, an inner function inner_func closes over M < N of those inputs, computes the gradient with respect to the remaining N - M inputs, and performs some arbitrary operation on it that returns a scalar.

If I call Zygote.gradient(outer_func, args...), I expect gradients with respect to all of args... (unlike what Zygote does for ForwardDiff operations, where it drops the gradients with a warning). If you try reverse mode over the inner gradient function, you are pretty much screwed.

So what we should be doing instead is to unwrap the closure, construct dual numbers from the input to the pullback (the incoming cotangent), and compute the inner gradient for all N inputs (note that we only replaced the outer gradient call with a JVP; see Nested AD with Lux etc - #12 by stevengj for why this is valid), then extract the partials and return them as the gradients via an rrule. We could easily have done this without unwrapping the closure, but returning a tangent type corresponding to the closure seemed messy.
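A small sketch of the "compute the inner gradient for all N inputs" step, assuming ForwardDiff and Zygote; `f3`, `allgrads` and the numbers are made up for illustration. Only `x` is seeded with a dual perturbation, yet differentiating the full reverse-mode gradient yields the partials for `p` and `q` as well, which is what the rrule needs to return.

```julia
using ForwardDiff, Zygote, LinearAlgebra

# Toy outer function with N = 3 inputs (an assumption for illustration)
f3(x, p, q) = sum(p)^2 / norm(x) + p[1] * x[2] + q * x[1]^2

# Inner gradient with respect to *all* inputs, flattened into one vector
allgrads(x, p, q) = vcat(Zygote.gradient(f3, x, p, q)...)

x = [0.3, -1.2, 0.8]; p = [0.5, -0.4]; q = 0.7
v = [1.0, 0.5, -2.0]   # plays the role of the cotangent arriving at ∇ₓf3

# One dual-number perturbation of x alone gives d/dα ∇f3(x + αv, p, q) at α = 0,
# i.e. the mixed second derivatives with respect to x, p and q together.
jvp_all = ForwardDiff.derivative(α -> allgrads(x + α * v, p, q), 0.0)
dx, dp, dq = jvp_all[1:3], jvp_all[4:5], jvp_all[6]
```

In an actual rrule, `v` would be the cotangent passed to the pullback, and `dx`, `dp`, `dq` would be returned as the respective tangents.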

Do you have a link to where it does it?

I see what you mean by the NamedTuple part:

julia> dd = [1.0, 2.0];

julia> f = let dd=dd
           x -> dd .+ x
       end
#3 (generic function with 1 method)

julia> fieldnames(typeof(f))
(:dd,)

Here is the test

It just uses the normal path for constructing objects, and then the normal path for differentiating calls.

 # Q1: How do we extract the values like `m`, `ps` and `st`  in `f`?

f.m, f.ps, f.st, etc.
This is, as I think you have worked out, what I mean by closures being callable NamedTuples.
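To make the "callable NamedTuple" view concrete, here is a small sketch in plain Julia (no AD packages). `toy_model`, `make_f`, `ps0` and `st0` are hypothetical stand-ins for a Lux-style model, its parameters and its state; the closure built by `make_f` mirrors `f` in `loss_function`.

```julia
# Hypothetical Lux-style model: returns (output, new_state)
toy_model(x, ps, st) = (ps.w * x + st.b, st)

# Build a closure that captures m, ps and st, like `f` in `loss_function`
make_f(m, ps, st) = x -> first(m(x, ps, st))

ps0 = (w = 2.0,)
st0 = (b = 1.0,)
f = make_f(toy_model, ps0, st0)

# The captured variables are ordinary fields of the closure's type...
captured = fieldnames(typeof(f))   # contains :m, :ps and :st
# ...so they can be read with property syntax or collected into a NamedTuple
f.ps === ps0                       # true
fields = NamedTuple{captured}(Tuple(getfield(f, n) for n in captured))
fields.st === st0                  # true
f(3.0)                             # 7.0, i.e. 2.0 * 3.0 + 1.0
```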

# We use our jvp trick to compute dm, dps, dst, dx
# Q2: What is the appropriate tangent for `f`

It is Tangent{F}(m=dm, ps=dps, st=dst)
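A minimal sketch of what returning `Tangent{F}(...)` for a closure looks like in practice, assuming ChainRulesCore and Zygote. `apply_closure` and `make_scale` are hypothetical names, and the rule only makes sense for closures of the form `x -> a * x` (a single captured field `a`); in the real case the fields would be `m`, `ps` and `st`, and `dm`, `dps`, `dst` would come from the JVP trick.

```julia
using ChainRulesCore, Zygote

# Hypothetical helper: apply a closure to an argument.
apply_closure(f, x) = f(x)

# Illustrative rule that ASSUMES f is a closure of the form `x -> a * x`,
# i.e. its only captured field is `a`. It is not valid for other closures.
function ChainRulesCore.rrule(::typeof(apply_closure), f::F, x::Real) where {F}
    y = apply_closure(f, x)
    function apply_closure_pullback(Δ)
        Δ = unthunk(Δ)
        # Tangent for the closure: one entry per captured field (here only `a`)
        df = Tangent{F}(; a = Δ * x)
        return NoTangent(), df, Δ * f.a
    end
    return y, apply_closure_pullback
end

# Hypothetical closure factory capturing `a`
make_scale(a) = x -> a * x

da = only(Zygote.gradient(a -> apply_closure(make_scale(a), 3.0), 2.0))
dx = only(Zygote.gradient(x -> apply_closure(make_scale(2.0), x), 3.0))
```

Zygote maps the `Tangent{F}` entry for `a` back to the variable captured by `make_scale`, so `da` should equal `x = 3.0` and `dx` should equal `a = 2.0`.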


For future reference, here is a manual implementation of forward-over-reverse calculation with parameter gradients (mixed second derivatives). In particular, if you have a scalar-valued h(x,p) = g(\nabla_x f) for some scalar-valued f(x,p), then one can similarly derive:

\left. \nabla_p h \right|_{x,p} = \left. \frac{\partial}{\partial\alpha} \left. \nabla_p f \right|_{x + \alpha \left. \nabla g \right|_{z},p} \right|_{\alpha = 0} \, ,

where z = \left. \nabla_x f \right|_{x,p}.
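The identity follows from the chain rule applied to h(x,p) = g(\nabla_x f) and the symmetry of mixed partials; as a short check:

\left. \nabla_p h \right|_{x,p} = \left( \frac{\partial \nabla_x f}{\partial p} \right)^{T} \left. \nabla g \right|_{z} = \frac{\partial^2 f}{\partial p \, \partial x} \left. \nabla g \right|_{z} = \left. \frac{\partial}{\partial\alpha} \left. \nabla_p f \right|_{x + \alpha \left. \nabla g \right|_{z},p} \right|_{\alpha = 0} \, ,

where the last step holds because the directional derivative of \nabla_p f in x along \left. \nabla g \right|_{z} applies the same mixed-derivative matrix \partial^2 f / \partial p \, \partial x to that direction.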

Here is an example calculation via ForwardDiff over Zygote, along with a finite-difference check:

julia> using ForwardDiff, Zygote, LinearAlgebra

julia> f(x,p) = sum(p)^2/norm(x) + p[1]*x[2];  # example function

julia> g(∇ₓf) = sum(∇ₓf)^3;   # example ℝⁿ → ℝ function

julia> h(x,p) = g(Zygote.gradient(x -> f(x,p), x)[1]);  # evaluate h by reverse mode

julia> function ∇ₚh(x,p)
           ∇ₚf(y,q) = Zygote.gradient(u -> f(y,u), q)[1]
           ∇g = Zygote.gradient(g, Zygote.gradient(x -> f(x,p), x)[1])[1]
           return ForwardDiff.derivative(α -> ∇ₚf(x + α*∇g, p), 0)
       end;

julia> x = randn(5); p = randn(4); δp = randn(4) * 1e-8;

julia> h(x,p)
-6.538498714556666e-5

julia> ∇ₚh(x,p)
4-element Vector{Float64}:
  0.0025659422776640596
 -0.0023030596877005173
 -0.0023030596877005173
 -0.0023030596877005173

julia> h(x,p+δp) - h(x,p)   # finite-difference check
-5.696379205464251e-11

julia> ∇ₚh(x,p)'δp           # exact directional derivative
-5.6963775418992394e-11

PS. This also seems like a good example of the clarity benefits of Unicode variable names for mapping math to code.


I still don’t have a very thorough understanding of what’s needed here, but if @avikpal wants to submit a PR to DifferentiationInterface.jl doing that, I’ll happily review and discuss more.