Just-in-time gradient optimization idea

Disclaimer: the idea may look simple in theory, but it could be very hard to implement.

Imagine you have a machine learning model made up of several groups of parameters. Normally, you differentiate the loss with respect to all of the parameters, store the full gradient, and then run gradient descent or some other gradient-based optimizer on it.

However, I realized something: this process uses and accesses more memory than it needs to.

One realization is that once backpropagation reaches a parameter set, and you’ve backpropagated through every path that leads to that set of parameters, those parameters won’t be read again. For example, if a parameter set X is used in both f(X, Y) and g(X, Z), its gradient is only complete once you’ve backpropagated through the results of both functions; by then, every gradient calculation that depends on the value of X has already propagated back to Y and Z. In other words, you’re free to mutate X the moment its gradient arrives.

This means that instead of writing the gradient of X to memory to be used later, you can update X now!

This saves memory and avoids unnecessary accesses to the parameter set and the gradients.
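
To make this concrete, here is a minimal hand-written sketch (no AD library; the two-layer network, the names, and the plain SGD step are purely illustrative) in which each weight matrix is updated in place as soon as its gradient is available, instead of being written to a separate gradient structure:

function fused_sgd_step!(W1, W2, x, y, lr)
    # forward pass, caching the activation needed for the backward pass
    h = tanh.(W1 * x)
    ŷ = W2 * h
    # backward pass for the loss sum(abs2, ŷ .- y)
    dŷ = 2 .* (ŷ .- y)
    dh = W2' * dŷ                # still reads the old W2, so it must come first
    W2 .-= lr .* (dŷ * h')       # every use of W2's value is done: update it now
    dz = dh .* (1 .- h .^ 2)     # tanh'
    W1 .-= lr .* (dz * x')       # same for W1: its gradient is consumed immediately
    return nothing
end

W1, W2 = randn(Float32, 16, 1), randn(Float32, 1, 16)
x, y = rand(Float32, 1, 1000), rand(Float32, 1, 1000)
fused_sgd_step!(W1, W2, x, y, 1f-2)

The gradients still exist as short-lived temporaries here; the point is that nothing the size of the full parameter set has to be kept around until a separate optimizer step.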

What should we do with this idea?

I’m not sure what can be done about it. Automatic differentiation is quite hard already as it is, and adding this to the mix could make it even harder. Still, if you want to do something with it, you’re welcome to.

[EDIT]: This has been done before in PyTorch; see the tutorial “How to save memory by fusing the optimizer step into the backward pass” (PyTorch Tutorials 2.7.0+cu126 documentation). However, given Julia’s autodiff ecosystem, it might be difficult to apply this idea here.

I’m pretty sure we already allocate the gradient once and mutate it every iteration, no? Maybe not with Zygote, and maybe not completely either, but I have no idea where the remaining allocations could come from:

julia> using Lux, Enzyme

julia> model = Chain(Dense(1=>16,tanh),Dense(16=>16,tanh),Dense(16=>1));

julia> using Random; rng = Random.default_rng(123);

julia> ps,st = Lux.setup(rng,model);

julia> loss(m,p,s,x,y) = sum(abs2,m(x,p,s)[1].-y);

julia> x = rand(Float32,1,1000);

julia> y = rand(Float32,1,1000);

julia> g = Enzyme.make_zero(ps);

julia> Enzyme.autodiff(Reverse,loss,Const(model),Duplicated(ps,g),Const(st),Const(x),Const(y));

julia> using BenchmarkTools

julia> @btime loss($model,$ps,$st,$x,$y);
  24.299 μs (16 allocations: 133.41 KiB)

julia> @btime Enzyme.autodiff(Reverse,loss,Const($model),Duplicated($ps,$g),Const($st),Const($x),Const($y));
  124.533 μs (79 allocations: 523.35 KiB)
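
For what it’s worth, here is a hedged sketch of the “allocate the gradient once and reuse it” pattern on top of the setup above; it assumes Enzyme.make_zero! (the in-place counterpart of make_zero) and the multi-argument form of Functors.fmap are available in your versions, and the plain SGD update is just for illustration:

using Functors  # to walk the matching parameter/gradient NamedTuples together

function train!(model, ps, st, g, x, y; lr = 1f-2, iters = 100)
    for _ in 1:iters
        Enzyme.make_zero!(g)   # reset the existing shadow instead of allocating a new one
        Enzyme.autodiff(Reverse, loss, Const(model),
                        Duplicated(ps, g), Const(st), Const(x), Const(y))
        # in-place SGD update over matching parameter/gradient leaves
        fmap(ps, g) do p, dp
            p .-= lr .* dp
        end
    end
    return ps
end

train!(model, ps, st, g, x, y)

With that, the parameters and the gradient are each allocated once; whatever remains in the @btime numbers above has to come from inside the autodiff call itself.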

This is exactly what happens if you are using Lux + Reactant with the training utilities (Utilities | Lux.jl Docs) and the return_gradients argument of single_train_step! is set to Val(false).

Essentially, Reactant/XLA will automatically analyze the lifetime of intermediate buffers and reuse their memory for other operations where possible. This is one of the many reasons why we outperform the CUDA.jl baselines (Lux.jl/perf at main · LuxDL/Lux.jl · GitHub).
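
For completeness, a rough sketch of what that looks like; the TrainState constructor, the MSELoss objective, and the return_gradients keyword follow my reading of the linked Lux docs and may differ between versions:

using Lux, Reactant, Optimisers, Random
using ADTypes: AutoEnzyme

model = Chain(Dense(1 => 16, tanh), Dense(16 => 16, tanh), Dense(16 => 1))
ps, st = Lux.setup(Random.default_rng(), model)

# move parameters, states and data onto the Reactant/XLA backend
ps_ra, st_ra = Reactant.to_rarray(ps), Reactant.to_rarray(st)
x_ra = Reactant.to_rarray(rand(Float32, 1, 1000))
y_ra = Reactant.to_rarray(rand(Float32, 1, 1000))

ts = Lux.Training.TrainState(model, ps_ra, st_ra, Optimisers.Adam(1f-3))

# With return_gradients = Val(false) the caller never receives a gradient buffer, so the
# compiler is free to fuse the optimizer update into the backward pass and reuse memory.
_, loss_val, _, ts = Lux.Training.single_train_step!(
    AutoEnzyme(), MSELoss(), (x_ra, y_ra), ts; return_gradients = Val(false))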

maybe not completely either, but I have no idea where the remaining allocations could come from

We need to allocate buffers to hold the partial derivatives of the intermediate ops. Unless you cache all of those intermediate allocations, an autodiff call will always allocate.
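
As a concrete illustration of that caching, here is a self-contained variant of the hand-written example from earlier in the thread with every intermediate preallocated, so the step itself does essentially no allocation (the buffer names and the mul!-based layout are just one way to do it):

using LinearAlgebra: mul!

function fused_sgd_step_cached!(W1, W2, x, y, lr, h, dŷ, dh, dz, dW1, dW2)
    mul!(h, W1, x);  h .= tanh.(h)           # forward intermediate in a cached buffer
    mul!(dŷ, W2, h); dŷ .= 2 .* (dŷ .- y)    # ŷ and its gradient share one buffer
    mul!(dh, W2', dŷ)                        # still needs the old W2
    mul!(dW2, dŷ, h'); W2 .-= lr .* dW2
    dz .= dh .* (1 .- h .^ 2)
    mul!(dW1, dz, x'); W1 .-= lr .* dW1
    return nothing
end

W1, W2 = randn(Float32, 16, 1), randn(Float32, 1, 16)
x, y = rand(Float32, 1, 1000), rand(Float32, 1, 1000)
h = Matrix{Float32}(undef, 16, 1000); dh = similar(h); dz = similar(h)
dŷ = similar(y); dW1 = similar(W1); dW2 = similar(W2)
fused_sgd_step_cached!(W1, W2, x, y, 1f-2, h, dŷ, dh, dz, dW1, dW2)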
