# Memory usage of Zygote in list comprehensions

Hi! I’ve been running into memory/runtime issues with Zygote when using list comprehensions. I’d like to ask a simple question. Suppose `x` is an array of length n. When Zygote differentiates backwards through the following code
`f(x) = sum([x[i] - x[i+1] for i in 1:2:n])`,

does it use O(n^2) memory/time, or O(n)?

The argument for O(n^2) is that a list comprehension is essentially the application of the function `f(i) = x[i] - x[i+1]` over `1:2:n`. For each entry in the list comprehension, Zygote will compute the sensitivity of the output with respect to the variables that `f` closes over (namely, `x`), which has size O(n). With O(n) entries, this leads to a total of O(n^2) memory. Essentially, Zygote does not realize that the Jacobian of the list comprehension is sparse.

The argument for O(n) is that Zygote does actually realize this, for example by representing the sensitivity Δf for each entry with O(1) memory using a sparse vector with two nonzero entries.

I haven’t been able to parse through the Zygote code well enough to understand which of these cases is true. More generally, my question is: will Zygote always be smart enough not to use orders of magnitude more memory than necessary when differentiating through list comprehensions? Or should I be wary of this issue and try to use maps, broadcasting, etc. whenever possible?

It looks like Zygote does indeed take orders of magnitude more memory and time. MWE:

```julia
using Zygote
using BenchmarkTools
f(x) = sum([x[i+1] - x[i] for i in 1:2:length(x)])
x = randn(10000)
@btime f(x) # 4.089 μs (3 allocations: 39.16 KiB)
@btime gradient(f, x) # 77.189 ms (30058 allocations: 765.12 MiB)
```

Does this mean that the above pattern (a list comprehension / map over an iterator, where a large array is accessed by an index) is not a good idea in Zygote – and is this emphasized anywhere? It seems like an easy mistake to make.

The indexing is the problem. Zygote doesn’t really notice that it happens inside a comprehension.

Naively, the gradient of `y = x[i]` needs to be roughly `dx = zero(x); dx[i] = dy; dx`, which allocates a whole array. Zygote now has a very simple sparse array for this, holding one entry.
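To make the dense version concrete, here is a minimal hand-written sketch of a pullback for `y = x[i]` in plain Julia. The name `index_with_pullback` is made up for illustration, and this is not Zygote's actual implementation; it only shows why a naive rule allocates a full array per indexing operation.

```julia
# Hypothetical helper, not part of Zygote: returns x[i] together with a
# closure that maps the output sensitivity dy to the input sensitivity dx.
function index_with_pullback(x::AbstractVector, i::Integer)
    y = x[i]
    function pullback(dy)
        dx = zero(x)   # allocates a full array of length(x)
        dx[i] = dy     # only one entry is nonzero
        return dx
    end
    return y, pullback
end

x = [10.0, 20.0, 30.0, 40.0]
y, back = index_with_pullback(x, 3)
dx = back(1.0)   # dense array with a single nonzero entry at index 3
```

A rule like this, called once per comprehension entry, would allocate one length-`n` array per entry.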

But the next problem is adding all the contributions `dx`. At present each step just does `dx + dx2`, which allocates, so I think it’s N^2.
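As a rough back-of-envelope check, assuming (as a simplification) that each term of the comprehension materialises one dense `Float64` array of length `n`, the numbers for the MWE with `n = 10000` work out as follows:

```julia
n = 10_000
nterms = length(1:2:n)                        # number of comprehension entries
bytes_per_array = n * sizeof(Float64)         # one dense array of length n
total_mib = nterms * bytes_per_array / 2^20   # total allocated, in MiB
println(round(total_mib; digits = 2), " MiB") # prints "381.47 MiB"
```

That is the same order of magnitude as the roughly 765 MiB measured for `gradient(f, x)` in the MWE, and far more than the 78 KiB that a single gradient array of this length occupies.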

PR#981 wanted to make it mutate instead, which in principle ought to make it O(N) I think. But in practice this only gets a factor of 2 here, so other things must be going wrong still. See also #644. This also might be solved other ways, e.g. by storing in-place-thunks & applying them all at the end.
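The difference between these accumulation strategies can be sketched in plain Julia, without any AD package. The function names below are hypothetical, the gradient of `f` is written out by hand, and the code illustrates the cost model only; it is not what Zygote or the PR actually do internally. It assumes `length(x)` is even, as in the MWE.

```julia
# Out-of-place: every term builds a dense array and is added with `+`,
# which allocates O(n) per term and O(n^2) in total.
function grad_out_of_place(x::AbstractVector)
    dx = zero(x)
    for i in 1:2:length(x)
        term = zero(x)
        term[i+1] = 1
        term[i] = -1
        dx = dx + term   # allocates a new array on each iteration
    end
    return dx
end

# In-place: a single buffer is mutated, two entries per term; O(n) in total.
function grad_in_place(x::AbstractVector)
    dx = zero(x)
    for i in 1:2:length(x)
        dx[i+1] += 1
        dx[i] -= 1
    end
    return dx
end

# Deferred: record one small closure per term, then apply them all
# to a single buffer at the end.
function grad_deferred(x::AbstractVector)
    thunks = [dx -> (dx[i+1] += 1; dx[i] -= 1; dx) for i in 1:2:length(x)]
    return foldl((dx, t) -> t(dx), thunks; init = zero(x))
end

x = randn(8)
g1 = grad_out_of_place(x)
g2 = grad_in_place(x)
g3 = grad_deferred(x)
```

All three return the same alternating `-1.0, 1.0` pattern; they differ only in how much they allocate along the way.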

None of this is well-documented, sadly. #1077 is looking for someone to start…

```julia
julia> @btime f(x)  # 4.089 μs (3 allocations: 39.16 KiB)
  4.816 μs (3 allocations: 39.12 KiB)
13.44747151439001

julia> @btime gradient(f, x)
  99.440 ms (30067 allocations: 764.81 MiB)   # latest, v0.6.29
  45.855 ms (20069 allocations: 383.19 MiB)   # with PR 981
([-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0  …  -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0],)

# Another way:
julia> using Tullio

julia> g(x) = @tullio _ := x[2i] - x[2i-1]
g (generic function with 1 method)

julia> @btime g(x)
  1.212 μs (1 allocation: 16 bytes)
13.447471514389736

julia> @btime gradient(g, x)
  6.792 μs (21 allocations: 78.62 KiB)
([-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0  …  -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0],)

julia> x5 = randn(100_000);  # 10x

julia> @btime g($x5);
  11.666 μs (0 allocations: 0 bytes)

julia> @btime gradient(g, $x5);  # linear?
  53.333 μs (6 allocations: 781.36 KiB)

julia> @btime gradient(f, $x5);  # N^2 memory
  41.749 s (200069 allocations: 37.27 GiB)

# A third way:
julia> h(x) = sum(view(reshape(x,2,:),2,:) - view(reshape(x,2,:),1,:));

julia> @btime h($x)
  3.104 μs (6 allocations: 39.30 KiB)
13.44747151439001

julia> @btime gradient(h, $x);
  13.875 μs (15 allocations: 195.81 KiB)

julia> @btime gradient(h, $x5);
  134.750 μs (15 allocations: 1.91 MiB)
```

That makes sense, thank you! I understand that AD is still in a state of flux, and it’s good to know what the current landscape is in terms of PRs. I’d love to contribute once I’ve gotten the hang of Julia a bit more.

I want to mention a nefarious case which I’m struggling to rewrite using maps, broadcasting, and so on (which has been my usual strategy for avoiding list comprehensions). I have an array of values, for example `[1,2,3,4,5,6]`, and an array of lengths, for example `[2,1,3]`. I’d like to turn the values into an array of arrays of the specified lengths, which would be `[[1,2], [3], [4,5,6]]` in the example, since this is the format needed for the next stage of the pipeline. But I have no idea how to do this without array indexing.

What’s frustrating is that this function is actually just a bijection; there are no interesting adjoints going on at all. Yet I’m unable to think of a way of writing it that is fast (i.e. whose run time does not increase by orders of magnitude in the backward pass)! Am I missing something here?

I think it really pays to write gradients for such functions by hand. The 5-minute version is something like this:

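```julia
function divideinto(xs::AbstractVector, ls::AbstractVector{<:Integer})
    # The lengths must add up to exactly the number of values.
    sum(ls) == length(xs) || throw("not a nice division!")
    i = 0
    # Take consecutive chunks; `i += l` advances the offset past each chunk.
    [xs[i+1:(i+=l)] for l in ls]
end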

divideinto([1,2,3,4,5,6], [2,1,3])