Nested AD with Lux etc

Nested AD (Starting v0.5.38)

Starting with v0.5.38, Lux automatically captures some common patterns of AD calls inside loss functions over Lux layers and converts them into a faster version that does a JVP over the gradient (instead of the default reverse over reverse). In short, this finally handles several complaints we have had over the years about not being able to handle nested AD of neural networks efficiently. See Nested Automatic Differentiation | Lux.jl Documentation

using Lux, Random, Zygote

function loss_function2(model, t, ps, st)
    smodel = StatefulLuxLayer(model, ps, st)
    # inner reverse-mode gradient of sum(abs2, smodel(t)) with respect to t
    ŷ = only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, t))
    return sum(abs2, ŷ .- cos.(t))
end

model = Chain(Dense(1 => 12, tanh), Dense(12 => 12, tanh), Dense(12 => 12, tanh),
    Dense(12 => 1))
ps, st = Lux.setup(Xoshiro(0), model)
t = rand(Xoshiro(0), Float32, 1, 16)

_, ∂t, ∂ps, _ = Zygote.gradient(loss_function2, model, t, ps, st)

Very cool! This sounds similar to the DifferentiateWith(f, backend) I just introduced in DifferentiationInterface.jl, maybe it can help?

https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/dev/api/#DifferentiationInterface.DifferentiateWith


It is actually more similar to SecondOrder in DI than to DifferentiateWith. The implementation in Lux doesn’t really change the backend. Say you have an innermost Zygote.gradient call: if you call Zygote.gradient over it, Lux just makes sure we do a single JVP over that inner gradient to get the outer gradients, instead of reverse over reverse.

The tricky thing handled here is that for a general closure you would drop the parameter gradients. But if we “own” the type (here StatefulLuxLayer; this is also one of the reasons BatchedRoutines has an explicit batched_*(f, x, p) signature), we can restructure the function call to also compute the parameter gradients in a single call.
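As a framework-free sketch of what owning the f(x, p) signature buys you (illustrative only: f, g, and the helper names are made up, and this is not the Lux/BatchedRoutines implementation), here is the parameter gradient of a gradient-dependent loss computed with one reverse pass w.r.t. p pushed forward along a single direction, using equality of mixed partials:

using ForwardDiff, Zygote, LinearAlgebra

f(x, p) = dot(p, x)^2 / 2                        # stand-in for a "network" output
∇ₓf(x, p) = Zygote.gradient(y -> f(y, p), x)[1]  # inner reverse-mode gradient wrt x
g(v) = sum(v)^3                                  # scalar function of that gradient
loss(x, p) = g(∇ₓf(x, p))

# ∇ₚ loss = d/dα ∇ₚ f(x + α ∇g, p) |_{α=0}, with ∇g evaluated at ∇ₓf(x, p),
# i.e. one reverse pass wrt p, forward-differentiated along the direction ∇g.
function ∇p_loss(x, p)
    ∇g = Zygote.gradient(g, ∇ₓf(x, p))[1]
    function ∇p_at(α)
        xα = x .+ α .* ∇g                        # dual-seeded point (outside Zygote)
        return Zygote.gradient(q -> f(xα, q), p)[1]
    end
    return ForwardDiff.derivative(∇p_at, 0.0)
end

x, p = randn(3), randn(3)
∇p_loss(x, p)   # can be checked against finite differences of loss(x, p)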


Can you explain why a JVP is supposed to be more performant than two reverse-mode passes? In theory at least the latter should be quasi-optimal.

Reverse mode always has more overhead than forward mode because it generally implies some form of delayed operation, and the caching and the like is just overhead. After one reverse-mode pass you go from R^n → R^1 to R^n → R^n, so there is no longer an inherent scaling advantage to reverse mode, because the next derivative is of something square. Hence you might as well use the mode with less inherent overhead, i.e. forward mode, at that level.

IIRC Griewank’s book has a proof or theorem or something around the fact that it’s always optimal to do one reverse and many forward.
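To make the "one reverse, then forward" recipe concrete, here is a minimal Hessian-vector-product sketch (illustrative only; the test function and names are made up, and this is not the Lux code path): one reverse-mode gradient, then a single JVP of that gradient along a direction v.

using ForwardDiff, Zygote

f(x) = sum(abs2, x) / (1 + sum(x)^2)   # illustrative ℝⁿ → ℝ function
∇f(x) = Zygote.gradient(f, x)[1]       # one reverse-mode pass: ℝⁿ → ℝⁿ

# Hv = d/dα ∇f(x + α v) |_{α=0}: a single forward-mode pass over the gradient,
# so the cost is roughly one extra evaluation of ∇f, independent of length(x).
hvp(x, v) = ForwardDiff.derivative(α -> ∇f(x .+ α .* v), 0.0)

x, v = randn(4), randn(4)
hvp(x, v)   # ≈ Hessian(f)(x) * v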

But if you are differentiating a scalar-valued function of the gradient you go back to R, so reverse-mode should be advantageous again.

I was referring to the setting @stevengj mentions. I thought the context was that we are differentiating a loss that depends on a gradient. If I misunderstood this, then we can just let this go.

I can appreciate the overhead argument, but I would not want a default in an ML framework based on that; I'd rather focus on fixing the overhead. Naively, if I implement reverse-over-reverse manually then I never have this overhead problem, so I'm assuming it must be possible.

Right, so in this case reverse mode should work as well, but from a practical standpoint even reverse-over-reverse has higher overhead (at least in JAX, where nested reverse works well). There was an ICLR blog post which compared HVP costs with different nesting orders, and forward-over-reverse was generally faster for large models like BERT and ViTs.

Also, in this case it is just a single JVP over a gradient call, so I don’t see why reverse mode should have much benefit here (considering that it generally has increased overhead w.r.t. caching intermediates, even if you let the other overheads go).

Because for scalar outputs reverse-mode AD should be proportional to the cost of the forward computation, whereas forward-mode AD is proportional to the cost of the forward computation times the number of parameters. This is why reverse mode should eventually always win if you have enough parameters and only a single scalar output (or a small number of scalar outputs).

Even if you implement reverse-mode manually there is generally a storage overhead, because reverse-mode differentiation requires you to keep the results of all the intermediate calculations so that you can backpropagate after the forward calculation is complete.


> Because for scalar outputs reverse-mode AD should be proportional to the cost of the forward computation, whereas forward-mode AD is proportional to the cost of the forward computation times the number of parameters.

Correct me if I am wrong here: that is coming from propagating as many dual numbers as there are parameters, independently, to avoid perturbation confusion. Here you can rewrite it as a directional derivative, so the forward mode is a single forward computation.
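For instance (a generic sketch, nothing Lux-specific): a single dual number seeded along a direction v gives the directional derivative (∇f)ᵀ v at roughly the cost of one evaluation of f, whereas the full forward-mode gradient needs one partial per input.

using ForwardDiff, LinearAlgebra

f(x) = sum(abs2, x) / 2                  # illustrative ℝⁿ → ℝ function
x, v = randn(5), randn(5)

∇f = ForwardDiff.gradient(f, x)          # propagates length(x) partials

# Directional derivative with a single dual number: one forward pass.
dirderiv = ForwardDiff.derivative(α -> f(x .+ α .* v), 0.0)

dirderiv ≈ dot(∇f, v)                    # true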


That’s why I was making a strong distinction from this feature. This replaces a Zygote.gradient call with a ForwardDiff.gradient call. However, for the case I mentioned above,

Zygote.gradient(x) do x
    ... = Zygote.gradient(f, x)
    return scalar
end

gets converted to

    x_duals = ....
    Zygote.gradient(f, x_duals)
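
Concretely, a hand-rolled sketch of what that dual seeding amounts to (illustrative only, not the actual Lux code path; in the real conversion the seed direction comes from the outer reverse pass, here v is just an arbitrary direction):

using ForwardDiff, Zygote

f(x) = sum(abs2, x)^2 / 4                 # illustrative ℝⁿ → ℝ function
x, v = randn(3), randn(3)

x_duals = ForwardDiff.Dual.(x, v)         # seed x with one partial each, along v
∇f_duals = Zygote.gradient(f, x_duals)[1]

# Primal parts give ∇f(x); dual parts give the JVP of the gradient,
# i.e. Hessian(f)(x) * v, from a single reverse pass over dual numbers.
ForwardDiff.value.(∇f_duals), ForwardDiff.partials.(∇f_duals, 1)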

I think I see what you are saying. More explicitly, suppose we want the gradient of a scalar-valued function h(x) = g(\left. \nabla f \right|_x) of the gradient \nabla f of a scalar-valued f(x) of x \in \mathbb{R}^n. Then

  1. For the calculation of h(x), we use reverse-mode to compute \nabla f and then plug it into g
  2. For the calculation of \nabla h, the chain rule corresponds to first linearizing g(\left. \nabla f \right|_{y}) \approx g(\left. \nabla f \right|_{x}) + (\left. \nabla g \right|_{\left. \nabla f \right|_{x}})^T \left[ \left. \nabla f \right|_{y} - \left. \nabla f \right|_{x} \right] and then taking the gradient of the second term with respect to y (& evaluated at y = x). But the latter expression (\nabla g)^T \left. \nabla f \right|_y is a single directional derivative of f, so we can compute this with forward mode using a single dual number (cost comparable to evaluating f), as a scalar derivative \left. \frac{d}{d\alpha} f(y + \alpha \nabla g) \right|_{\alpha=0}, and then apply reverse-over-forward to find the gradient \nabla h. Equivalently, you can interchange derivatives to do forward-over-reverse: \nabla h = \left. \frac{d}{d\alpha} \left( \left. \nabla f \right|_{x + \alpha \nabla g} \right) \right|_{\alpha=0}.

Here is an explicit example.

using ForwardDiff, Zygote, LinearAlgebra
f(x) = 1/norm(x)    # example ℝⁿ → ℝ function
g(∇f) = sum(∇f)^3   # example ℝⁿ → ℝ function
h(x) = g(Zygote.gradient(f, x)[1])
function ∇h(x)
    ∇f(y) = Zygote.gradient(f, y)[1]
    ∇g = Zygote.gradient(g, ∇f(x))[1]
    # forward-over-reverse: a single JVP of ∇f along the direction ∇g
    return ForwardDiff.derivative(α -> ∇f(x + α*∇g), 0)
end

gives

julia> x = randn(5); δx = randn(5) * 1e-8;

julia> h(x)
-0.005284687528953334

julia> ∇h(x)
5-element Vector{Float64}:
 -0.006779692698531759
  0.007176439898271982
 -0.006610264199241697
 -0.0012162087082746558
  0.007663756720005014

julia> ∇h(x)'δx  # directional derivative
-3.0273434457397667e-10

julia> h(x+δx) - h(x)  # finite-difference check
-3.0273433933303284e-10

Note that I used the forward-over-reverse formulation above, because Zygote can’t currently differentiate a ForwardDiff.derivative while the converse is fine.


Thanks for clarifying. This makes a lot of sense. I never thought of mixing reverse and forward in this way.

I feel like this is a commonplace enough pattern (while at the same time being very nonobvious!) that it should be in a tutorial somewhere if it isn’t supported automatically…


Partly captured at Zygote.jl/src/lib/forward.jl at c0daccded5b9f91d31ceb889e4a97e74dd722a4e · FluxML/Zygote.jl · GitHub, but with a bad implementation that fully materializes the Hessian.

The main issue, however, is the parameter gradients, which commonly show up in most optimization setups and can’t be captured directly since you need a closure over the function. Two ways to handle this would be:

  1. Have a jacobian/gradient(f, x, p) which takes p as input but only computes the jacobian or gradient w.r.t. x. This is what GitHub - LuxDL/BatchedRoutines.jl: Routines for batching regular code and make them fast! implements. The downside is that you now need to rewrite all your code to use a new API.

  2. If you “own” the function, in this case StatefulLuxLayer for Lux, you can capture the Zygote.gradient or other similar AD calls without committing type piracy and rewrite the operations to be efficient (as done in Lux.jl/ext/LuxZygoteExt.jl at 36b362a482002a2525024d9254cec184315fd5ac · LuxDL/Lux.jl · GitHub). Currently I do it only for Zygote and ForwardDiff, but soon DI will be supported as well (Capture DifferentiationInterface calls for efficient Nested AD · Issue #600 · LuxDL/Lux.jl · GitHub).

We could consider an implementation in DI for this as well, where to do nested AD users write f_pseudo_closure = FixedParamsStruct(f, ps); DI.jacobian(<backend>, f_pseudo_closure, x), and then DI defines the proper chain rules for it.
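
As a rough sketch of that pseudo-closure idea (FixedParamsStruct is the hypothetical name from above, not an existing DI or Lux type):

# A pseudo-closure that carries the parameters as a field instead of capturing
# them, so an AD rule can "see" ps and return a tangent for it as well.
struct FixedParamsStruct{F,P}
    f::F
    ps::P
end
(fp::FixedParamsStruct)(x) = fp.f(x, fp.ps)

# Usage as in the post (backend setup and the chain rules themselves omitted):
# f_pseudo_closure = FixedParamsStruct(f, ps)
# DI.jacobian(<backend>, f_pseudo_closure, x)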

Doesn’t Zygote.gradient(x -> f(x, p), x) do this already? Why do you need p as an explicit argument?

It does, but it will try to do reverse-over-reverse, which mostly doesn’t work with Zygote (except in simple cases) and has terrifying compile times.

(And I am not sure we can extract the parameters out of the closure in a nice way to rewrite the function call.)

The parameters are needed to define the CRC.rrule; I don’t think we can return a Tangent type for a general closure. cc @oxinabox, who can correct me here if it is possible.

You mean adding a kind of partial stop_gradient to DI?

Not stop_gradient, but you would need to unwrap the call for the reverse pass to define the rrule nicely. See Lux.jl/ext/LuxZygoteExt.jl at 36b362a482002a2525024d9254cec184315fd5ac · LuxDL/Lux.jl · GitHub

Not sure I understand the question.
In general there is nothing particularly problematic about defining a tangent type for a closure.
Diffractor does it under the hood for one of the tests.
Closures are basically just named tuples with call overloaded.

Can be a bit annoying for dispatch reasons I guess
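
For example (a toy sketch of that equivalence):

make_scaler(a) = x -> a * x      # closure capturing `a`

struct Scaler{T}                 # the hand-written equivalent: data + call overload
    a::T
end
(s::Scaler)(x) = s.a * x

make_scaler(2.0)(3.0) == Scaler(2.0)(3.0)   # both give 6.0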