Memory usage of Zygote in list comprehensions

Hi! I’ve been running into memory/runtime issues with Zygote when using list comprehensions. I’d like to ask a simple question. Suppose x is an array of length n. When Zygote differentiates backwards through the following code
f(x) = sum([x[i] - x[i+1] for i in 1:2:n]),

does it use O(n^2) memory/time, or O(n)?

The argument for O(n^2) is that a list comprehension is essentially the application of the closure i -> x[i] - x[i+1] over 1:2:n. For each entry in the list comprehension, Zygote will compute the sensitivity of the output with respect to the variables that this closure captures (namely, x), which has size O(n). This leads to a total of O(n^2) memory. Essentially, Zygote does not realize that the Jacobian of the list comprehension is sparse.
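
For concreteness, here is a minimal sketch of my own (not Zygote output) spelling out that worry:

# What the comprehension effectively does (illustrative sketch only):
x = randn(10)
n = length(x)
body = i -> x[i] - x[i+1]   # closure capturing all of x
y = map(body, 1:2:n)        # same result as the comprehension

# Under a naive reverse pass, the pullback of each body(i) call would return a
# cotangent for everything the closure captures, i.e. a length-n vector for x,
# so the n/2 entries together would cost O(n^2) memory.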

The argument for O(n) is that Zygote does actually realize this – for example, by representing the cotangent of each entry with O(1) memory, as a sparse vector with two nonzero entries.
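
For example, a per-entry cotangent could be stored like this (using SparseArrays purely for illustration; this is not a claim about Zygote's actual internal representation):

using SparseArrays

n = 10
i = 3
dy = 1.0                               # upstream sensitivity for this entry
Δ = sparsevec([i, i+1], [dy, -dy], n)  # cotangent of x[i] - x[i+1] w.r.t. x
# Only two nonzeros are stored, so n/2 such cotangents take O(n) memory in total.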

I haven’t been able to parse through the Zygote code well enough to understand which of these cases is true. More generally, my question is: will Zygote always be smart enough not to use orders of magnitude more memory than necessary when differentiating through list comprehensions? Or should I be wary of this issue, and try to use maps, broadcasting, etc. whenever possible?

It looks like Zygote does indeed take orders of magnitude more memory and time? MWE:

using Zygote
using BenchmarkTools
f(x) = sum([x[i+1] - x[i] for i in 1:2:length(x)])
x = randn(10000)
@btime f(x) # 4.089 μs (3 allocations: 39.16 KiB)
@btime gradient(f, x) # 77.189 ms (30058 allocations: 765.12 MiB)

Does this mean that the above pattern (a list comprehension / map over an iterator, where a large array is accessed by an index) is not a good idea in Zygote – and is this emphasized anywhere? It seems like an easy mistake to make.

It’s the indexing that’s the problem. Zygote doesn’t really notice that it is done inside a comprehension.

Naively the gradient of y = x[i] needs to be roughly dx = zero(x); dx[i] = dy; dx, which allocates a whole array. Zygote now has a very simple sparse array for this, for one entry.

But the next problem is adding all the contributions dx. At present each step just does dx + dx2, which allocates, so I think it’s N^2.

PR#981 wanted to make it mutate instead, which in principle ought to make it O(N) I think. But in practice this only gets a factor of 2 here, so other things must be going wrong still. See also #644. This also might be solved other ways, e.g. by storing in-place-thunks & applying them all at the end.
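
To make this concrete, here is a rough sketch of out-of-place vs. mutating accumulation (the names naive_entry_grad, accumulate_oop, accumulate_mut are mine for illustration, not Zygote internals):

# Dense per-entry gradient of x[i] - x[i+1]: a length-n array with two nonzeros.
naive_entry_grad(n, i) = (dx = zeros(n); dx[i] = 1.0; dx[i+1] = -1.0; dx)

# Out-of-place accumulation: acc + dx builds a brand-new length-n array for
# every entry, so the whole backward pass costs O(n^2) in memory traffic.
function accumulate_oop(n)
    acc = zeros(n)
    for i in 1:2:n-1
        acc = acc + naive_entry_grad(n, i)
    end
    acc
end

# Mutating accumulation (roughly the PR#981 idea): reuse one buffer and touch
# only the nonzero positions, keeping the backward pass at O(n) memory.
function accumulate_mut(n)
    acc = zeros(n)
    for i in 1:2:n-1
        acc[i] += 1.0
        acc[i+1] -= 1.0
    end
    acc
end

accumulate_oop(10_000) == accumulate_mut(10_000)  # true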

None of this is well-documented, sadly. #1077 is looking for someone to start…

julia> @btime f(x) # 4.089 μs (3 allocations: 39.16 KiB)
  4.816 μs (3 allocations: 39.12 KiB)
13.44747151439001

julia> @btime gradient(f, x) 
  99.440 ms (30067 allocations: 764.81 MiB)   # latest, v0.6.29
  45.855 ms (20069 allocations: 383.19 MiB)   # with PR 981
([-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0  …  -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0],)

# Another way:
julia> using Tullio

julia> g(x) = @tullio _ := x[2i] - x[2i-1]
g (generic function with 1 method)

julia> @btime g(x)
  1.212 μs (1 allocation: 16 bytes)
13.447471514389736

julia> @btime gradient(g, x) 
  6.792 μs (21 allocations: 78.62 KiB)
([-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0  …  -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0],)

julia> x5 = randn(100_000);  # 10x

julia> @btime g($x5);
  11.666 μs (0 allocations: 0 bytes)

julia> @btime gradient(g, $x5);  # linear?
  53.333 μs (6 allocations: 781.36 KiB)

julia> @btime gradient(f, $x5);  # N^2 memory 
  41.749 s (200069 allocations: 37.27 GiB)

# A third way:
julia> h(x) = sum(view(reshape(x,2,:),2,:) - view(reshape(x,2,:),1,:));

julia> @btime h($x)
  3.104 μs (6 allocations: 39.30 KiB)
13.44747151439001

julia> @btime gradient(h, $x);
  13.875 μs (15 allocations: 195.81 KiB)

julia> @btime gradient(h, $x5);
  134.750 μs (15 allocations: 1.91 MiB)


That makes sense, thank you! I understand that AD is still in a state of flux, and it’s good to know what the current landscape is in terms of PRs. I’d love to contribute once I’ve gotten the hang of Julia a bit more.

I want to mention a tricky case which I’m struggling to rewrite using maps, broadcasting, and so on (which has been my usual strategy for avoiding list comprehensions). I have an array of values, for example [1,2,3,4,5,6], and an array of lengths, for example, [2,1,3]. I’d like to turn the values into an array of arrays of the specified lengths, which would be [[1,2], [3], [4,5,6]] in the example, since this is the format needed for the next stage of the pipeline. But I have no idea how to do this without array indexing :(

What’s frustrating is that this function is actually just a bijection; there are no interesting adjoints going on at all. Yet I’m unable to think of a way of writing it that is fast (i.e. one whose time does not increase by orders of magnitude in the backwards pass)! Am I missing something here?

I think it really pays to write gradients for such functions by hand. The 5-minute version is something like this:

function divideinto(xs::AbstractVector, ls::AbstractVector{<:Integer})
  sum(ls) == length(xs) || error("not a nice division!")
  i = 0
  # take successive slices of xs, one of length l for each l in ls
  [xs[i+1:(i+=l)] for l in ls]
end

divideinto([1,2,3,4,5,6], [2,1,3])  # [[1, 2], [3], [4, 5, 6]]

Zygote.@adjoint function divideinto(xs::AbstractVector, ls::AbstractVector)
  # the pullback just concatenates the per-block cotangents back into one flat
  # vector; the lengths ls are not differentiable, hence `nothing`
  divideinto(xs, ls), dys -> (reduce(vcat, dys), nothing)
end

gradient(x -> sum(first, divideinto(x, [2,1,3])), [1,2,3,4,5,6])
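
As a quick sanity check (my addition, not from the original post): divideinto only rearranges its input, so summing every element after the split should give a gradient of all ones.

gradient(x -> sum(sum, divideinto(x, [2,1,3])), [1.0, 2, 3, 4, 5, 6])
# expected: all-ones gradient, ([1.0, 1.0, 1.0, 1.0, 1.0, 1.0],) up to array type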