Disclaimer: the idea looks simple in theory, but it may be quite hard to implement.
Imagine a machine learning model made up of several sets of parameters. Normally, you differentiate the loss with respect to all of the parameters and then run gradient descent, or some other gradient-based optimizer, on them.
However, I realized that this process uses more memory, and performs more memory accesses, than necessary.
One realization is that once backpropagation has reached a parameter set, i.e. every path that leads out of that set has been backpropagated through, those parameters won't be read again. For example, if a parameter set X is used in f(X, Y) and g(X, Z), the gradient of X is only complete after backpropagating through the results of both functions, so every gradient calculation that depends on the value of X has already propagated back to Y and Z. That means you are free to mutate X the moment its full gradient arrives.
This means that instead of writing the gradient of X to memory to be used later, you can update X now!
This saves memory, since gradients never need to be materialized for the whole model at once, and it avoids extra passes over the parameters and gradients in a separate optimizer step.
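To make the ordering concrete, here is a minimal sketch with a hand-written backward pass for a tiny two-layer network (plain NumPy; the network shape, `train_step`, and the hyperparameters are just illustrative). Each weight matrix is updated the moment its gradient is available, so no gradient buffer outlives the backward pass. Note that `dh` has to be computed from the old `W2` before `W2` is mutated, which is exactly the constraint described above.

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.normal(size=(4, 8)) * 0.1
W2 = rng.normal(size=(8, 1)) * 0.1
lr = 1e-2

def train_step(x, y):
    global W1, W2
    # Forward pass, keeping the activations needed for backprop.
    h = np.tanh(x @ W1)
    pred = h @ W2
    loss = 0.5 * np.mean((pred - y) ** 2)

    # Backward pass with the optimizer step fused in.
    dpred = (pred - y) / len(x)
    dW2 = h.T @ dpred
    dh = dpred @ W2.T          # uses the old W2, so compute it before the update below
    W2 -= lr * dW2             # dW2 is never stored beyond this point

    dW1 = x.T @ (dh * (1 - h ** 2))
    W1 -= lr * dW1             # likewise, dW1 is consumed immediately

    return loss

x = rng.normal(size=(16, 4))
y = rng.normal(size=(16, 1))
for _ in range(5):
    print(train_step(x, y))
```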
What should we do with this idea?
I’m not sure what can be done about it. Automatic differentiation is quite hard already as it is, and adding this to the mix could make it even harder. Still, if you want to do something with it, you’re welcome to.
[EDIT]: This has been done before in PyTorch; see the tutorial “How to save memory by fusing the optimizer step into the backward pass” in the PyTorch documentation. However, given Julia’s autodiff ecosystem, it might be difficult to apply this idea there.
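For reference, the linked tutorial does this with per-parameter hooks that fire once a parameter's gradient has been accumulated. A rough sketch along those lines (assuming PyTorch >= 2.1, where `register_post_accumulate_grad_hook` is available; the toy model and hyperparameters are placeholders):

```python
import torch

# Toy model; architecture and learning rate are placeholders.
model = torch.nn.Linear(512, 512)

# One tiny optimizer per parameter, so each step touches only that parameter.
optimizer_dict = {p: torch.optim.SGD([p], lr=1e-3) for p in model.parameters()}

def optimizer_hook(param):
    # Runs as soon as param.grad has been fully accumulated during backward().
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()  # drop the gradient immediately to free memory

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

x = torch.randn(8, 512)
loss = model(x).sum()
loss.backward()  # parameters are updated inside this call; no separate optimizer.step()
```

With this setup there is no separate `optimizer.step()` after `backward()`: each parameter is updated, and its gradient freed, as soon as that gradient is ready.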