Zygote gradient accumulation?

I ported a neural network from PyTorch to Flux, and now I immediately run out of GPU memory.
I would like to better understand allocations in the backward pass in Flux/Zygote.

Say I have the following toy function:

```julia
function densenet(x0)
    x1 = x0
    x2 = x0 + x1
    x3 = x0 + x1 + x2
    x4 = x0 + x1 + x2 + x3
    x5 = x0 + x1 + x2 + x3 + x4
    return x5  # the output x_n, here with n = 5
end
```

Now I want to compute the pullback of densenet, and I am interested in the peak memory usage when I do this.

Mathematically, the differential x̄3 at, say, x3 is the sum of the differentials x̄4…x̄n (here n = 5).
Is it correct that x̄4…x̄n all need to be alive in memory for Zygote to compute x̄3?
I think in PyTorch this is not the case, and x̄3 is accumulated incrementally as the x̄i materialize. Is that correct?
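Written out for the toy function, the adjoints satisfy the recurrence below. This is a sketch assuming the output is y = x_n with incoming adjoint ȳ, and using the fact that each x_j in densenet is the sum of all earlier x_k (x1 = x0 is the case j = 1):

$$
x_j = \sum_{k=0}^{j-1} x_k \quad (1 \le j \le n), \qquad y = x_n,
$$

$$
\bar{x}_n = \bar{y}, \qquad \bar{x}_i = \sum_{j=i+1}^{n} \bar{x}_j = 2^{\,n-1-i}\,\bar{y} \quad (0 \le i \le n-1).
$$

So x̄0 = 16ȳ for n = 5. Because x̄j only depends on adjoints with a larger index, a reverse sweep over j = n, n − 1, … can add each x̄j into running sums for x̄0, …, x̄(j−1) as soon as it is final, so x̄3 can be held as a single running sum rather than as separately stored terms.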

If it’s not necessary for PyTorch, then in theory it’s not necessary for Zygote either, since both perform reverse-mode AD. However, some implementation differences likely affect memory usage. For one, Zygote does quite a bit of copying for certain operations. In-place accumulation is on the roadmap, but it isn’t fully fleshed out yet. Also, PyTorch’s autograd engine is (according to their docs and my observations in practice) very proactive about reusing buffers and freeing allocations early. Lastly, CUDA.jl’s allocator integrates with the Julia GC and thus isn’t guaranteed to immediately “free” (or release to a pool/arena) temporary allocations the way a reference-counting system would. This doesn’t matter in most cases, but it can lead to OOMs when you’re close to the VRAM limit and a library like cuDNN decides to carry out its own little allocation on the side.
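To make the difference between accumulating in place and copying concrete, here is a hypothetical hand-written reverse pass for the toy densenet in plain Julia (no Flux, Zygote or CUDA.jl). The name densenet_pullback and its inplace keyword are made up for illustration; this is a model of the two strategies, not how Zygote or PyTorch actually implement their pullbacks. It counts how many arrays get allocated and how many adjoint buffers are alive at once.

```julia
# Hypothetical hand-written reverse pass for the toy densenet
# (x_j = x_0 + ... + x_{j-1}, output x_n). Models two accumulation strategies;
# this is NOT the code that Zygote or PyTorch generate.
#   inplace = true : running sums are updated with `.+=` (each buffer is reused)
#   inplace = false: every accumulation `a + b` allocates a fresh array
# Returns x̄_0, the number of arrays allocated, and the peak number of
# adjoint buffers referenced at once.
function densenet_pullback(ybar::Vector{Float64}, n::Int; inplace::Bool = true)
    adj = Vector{Union{Nothing,Vector{Float64}}}(nothing, n + 1)  # adj[k+1] holds x̄_k
    adj[n+1] = copy(ybar)                   # x̄_n = ȳ
    allocs, peak = 1, 1
    for j in n:-1:1
        xbar_j = adj[j+1]::Vector{Float64}  # all uses of x_j come later, so x̄_j is final
        for k in 0:j-1                      # x_j reads x_0, ..., x_{j-1}
            acc = adj[k+1]
            if acc === nothing
                adj[k+1] = copy(xbar_j)     # first contribution: allocate the running sum
                allocs += 1
            elseif inplace
                acc .+= xbar_j              # accumulate into the existing buffer
            else
                adj[k+1] = acc + xbar_j     # new array; the old one is left for the GC
                allocs += 1
            end
        end
        peak = max(peak, count(!isnothing, adj))
        adj[j+1] = nothing                  # x̄_j has been propagated; drop the reference
    end
    return adj[1], allocs, peak
end

ybar = [1.0, 2.0]
g_in,  allocs_in,  peak_in  = densenet_pullback(ybar, 5)                   # [16.0, 32.0], 6, 6
g_out, allocs_out, peak_out = densenet_pullback(ybar, 5; inplace = false)  # [16.0, 32.0], 16, 6
```

Under these assumptions, both strategies reference at most n + 1 adjoint buffers at once for this densely connected graph, but the out-of-place version allocates 1 + n + n(n−1)/2 arrays instead of n + 1. Each discarded array then waits for the GC, which on the GPU is the kind of delayed freeing that can push you over the VRAM limit.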

To help evaluate where your network allocates in the backward pass, you could use Zygote.@code_adjoint, or @code_typed, @code_lowered or @code_warntype on the pullback itself. Do post your findings in #domain:ml or on GitHub; it would be great if we could catalog and prioritize some of these performance gaps.


Thanks for the answer! Following your suggestion, I opened an issue: “Inplace accumulation of gradients?” (FluxML/Zygote.jl issue #905 on GitHub).