I have noticed that when I use Zygote to do automatic differentiation in my code (typically I use Zygote to obtain the gradient of my loss function and pass it to Optim.optimize to minimise the loss), I often get a high number of memory allocations. I suspect I might be doing something wrong, but I haven't managed to figure it out yet.
During my investigations, I created the following mock example:
using Zygote, BenchmarkTools, Random, LinearAlgebra

function f(A, x) # mock loss function
    aux = zero(eltype(x))
    for a in A
        aux += dot(x, a*x)
    end
    return aux
end

function runme() # instantiate parameters and benchmark calls to mock loss
    rg = MersenneTwister(1)
    helper(M) = M'*M # make matrix positive definite
    D = 500 # number of dimensions of x in function f
    K = 10  # number of matrices in function f
    A = [helper(randn(rg, D, D)) for k in 1:K]
    x₀ = randn(rg, D) # randomly picked point at which to evaluate mock loss function and its gradient
    g(x) = f(A, x) # convenience function
    @btime $g($x₀)
    @btime Zygote.gradient($g, $x₀)
end
Using Julia 1.9.3 and Zygote v0.6.67, I get the following results:
for the mock-loss evaluation and its gradient respectively.
My question: considering that the call to the mock loss function is fairly fast and allocates fairly little, is it reasonable that the gradient calculation results in (what seems to me) such high memory consumption?
FWIW I recall seeing something similar when using ForwardDiff: more memory allocations than I thought ought to be necessary. In that case (since speed mattered in that part of the code) I eventually resorted to just hand-rolling the gradients.
But I'm still not confident that there isn't a way around this. In particular, it seems that there's no technical barrier preventing an autodiff package from pre-allocating any memory needed for gradients and repeatedly reusing the same chunks. I believe that's effectively what JAX accomplishes, for example.
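For what it's worth, here is what a hand-rolled, preallocating gradient for the mock loss above could look like. This is my own sketch (the names `grad_f!` and `buf` are made up, not from the thread); it uses the analytic gradient ∇f(x) = Σₖ (Aₖ + Aₖ')x and writes into caller-supplied buffers so repeated calls allocate nothing:

```julia
using LinearAlgebra

# Analytic gradient of f(A, x) = Σₖ dot(x, Aₖ*x), i.e. ∇f(x) = Σₖ (Aₖ + Aₖ')*x,
# written into a caller-supplied buffer so repeated calls allocate nothing.
function grad_f!(g, buf, A, x)
    fill!(g, zero(eltype(g)))
    for a in A
        mul!(buf, a, x)   # buf = a * x
        g .+= buf
        mul!(buf, a', x)  # buf = a' * x
        g .+= buf
    end
    return g
end
```

Since `helper(M) = M'*M` makes each matrix symmetric, for the benchmark above this reduces to 2·Σₖ Aₖx, but the two-term form works for arbitrary square matrices.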
This reassures me that I am not doing something terribly wrong.
This is a great observation that I totally missed. Indeed, this explains the high allocations. However this confuses me: why is it that we should worry about allocations concerning A? I would have thought that A is already allocated in the body of function runme().
Ah I may have mis-read this. My answer is why the first of these allocates so much, but maybe your question is why the second isn’t any faster:
@btime gradient(f, $A, $x); # must allocate at least copy.(A)
@btime gradient(x -> f($A, x), $x); # here A is constant
Zygote computes but discards the gradient of A. I think the hope was that the compiler would learn to eliminate this work. There was also an elaborate system of delaying work (via ChainRules’s thunks) which isn’t used.
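As a back-of-envelope check (my own arithmetic, not from the thread): the discarded gradient with respect to A is K dense D×D matrices, which at the benchmark's sizes is already about 20 MB per gradient call before counting anything else:

```julia
K, D = 10, 500                       # sizes from the benchmark above
bytes = K * D^2 * sizeof(Float64)    # one dense D×D Float64 cotangent per matrix
bytes / 1e6                          # ≈ 20.0 MB for the A-gradient alone
```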
In principle Tracker / ReverseDiff know that A is constant. But both allocate more here, so they don’t seem to be exploiting this.
I know Enzyme keeps track of activity. Haven’t made it run today but you could try Enzyme.gradient(Reverse, x -> f(A, x), x).
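If you go down that route, one way to tell Enzyme explicitly that A is constant is via its activity annotations. A sketch, using a toy-sized A and the same loop-based loss as above (I haven't benchmarked this):

```julia
using Enzyme, LinearAlgebra

function f(A, x) # same mock loss as above
    aux = zero(eltype(x))
    for a in A
        aux += dot(x, a*x)
    end
    return aux
end

A = [Matrix(2.0I, 2, 2)]   # toy sizes just to illustrate
x = [1.0, 2.0]
dx = zero(x)               # shadow buffer that receives the gradient

# Const marks A as inactive (no gradient allocated for it);
# Duplicated pairs x with its shadow dx; Active is the scalar return.
Enzyme.autodiff(Reverse, f, Active, Const(A), Duplicated(x, dx))
# dx should now hold ∂f/∂x = (A₁ + A₁')*x
```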
Some slight corrections/clarifications on these points.
ADs are not magic and I don’t know of any which make guarantees about whether/how much they’ll allocate on top of what your code does. JAX definitely doesn’t, and you can see that if you don’t use jax.jit (which itself is a black box and may or may not eliminate any allocations).
That issue is from 2019, and since then a lot of work has been put into making Zygote more type stable as well as optimizing rules. What hasn't changed is that Zygote will generate inherently type-unstable code whenever it encounters control flow such as conditionals, loops, etc. This results in a lot of small allocations from dynamic dispatch, but those may or may not have a performance impact depending on whether they occur on a hot codepath or in a less frequently run outer loop.
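One mitigation along these lines (my own sketch, not from this thread) is to express the loop as a `sum` with a function argument, which Zygote can differentiate through a dedicated rule rather than by tracing the loop's control flow:

```julia
using Zygote, LinearAlgebra

# Loop-free formulation of the mock loss from above.
f2(A, x) = sum(a -> dot(x, a * x), A)

A = [Matrix(1.0I, 2, 2), Matrix(2.0I, 2, 2)]  # toy sizes just to illustrate
x = [1.0, 2.0]
grad = Zygote.gradient(x -> f2(A, x), x)[1]   # = (Σₖ (Aₖ + Aₖ'))*x
```

Whether this actually reduces allocations in a given case is worth benchmarking; it only sidesteps the loop, not the other sources of overhead discussed here.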
This doesn’t help if what you’re concerned about is total memory usage (indeed I don’t think jax’s AD makes any guarantees there), but there is indeed a guarantee regarding how much time will be spent allocating memory post-compilation: none.
Yes, in the sense that XLA will generate a static graph and hoist out all the buffers required ahead of time. It can also do tricks such as buffer reuse since JAX arrays are semantically immutable.
However, there's some nuance: that's not an apples-to-apples comparison with what's being reported for allocations when you run Julia code. Importantly, primitive operations and their registered jvp/vjp rules are allowed to do whatever they want under the hood. This includes dynamic memory allocations! Otherwise, it would be difficult to, say, use cuDNN for conv ops on GPU, because such libraries rely on dynamically allocated workspaces which aren't represented in the XLA graph. "No memory allocation happening post-compilation" is only true from the high-level perspective of the XLA compiler staring at the graph DSL.