Hi! I’ve been running into memory/runtime issues with Zygote when using list comprehensions, and I’d like to ask a simple question. Suppose `x` is an array of length `n`. When Zygote differentiates backwards through the following code
`f(x) = sum([x[i] - x[i+1] for i in 1:2:length(x)-1])`,
does it use O(n^2) memory/time, or O(n)?
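For concreteness, here is roughly how I could imagine checking this empirically (my own sketch; the absolute numbers from `@allocated` include constant overhead, so only how they scale with `n` would matter):

```julia
using Zygote

f(x) = sum([x[i] - x[i+1] for i in 1:2:length(x)-1])

for n in (1_000, 2_000, 4_000)
    x = randn(n)
    Zygote.gradient(f, x)                     # warm up so compilation isn't counted
    bytes = @allocated Zygote.gradient(f, x)  # allocations of one forward + reverse pass
    println("n = $n: $bytes bytes allocated")
end
```

But I’d rather understand what the expected asymptotics are than guess from raw allocation counts.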
The argument for O(n^2) is that a list comprehension is essentially the application of the function `g(i) = x[i] - x[i+1]` over `1:2:n-1`. For each entry of the comprehension, Zygote will compute the sensitivity of the output with respect to the variables that `g` closes over (namely, `x`), and that sensitivity has size O(n). This leads to a total of O(n^2) memory. Essentially, Zygote does not realize that the Jacobian of the list comprehension is sparse.
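To make the worry concrete, by a dense per-element adjoint I mean something like the following (my own illustration, not a claim about what Zygote actually does internally):

```julia
# Adjoint of one comprehension element x -> x[i] - x[i+1], materialized densely:
# a length-n vector that is zero everywhere except at positions i and i+1.
function dense_element_adjoint(Δ, i, n)
    dx = zeros(n)      # O(n) memory for each of the ~n/2 elements => O(n^2) total
    dx[i]   =  Δ
    dx[i+1] = -Δ
    return dx
end
```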
The argument for O(n) is that Zygote does actually realize this – for example, by representing the sensitivity Δg for each entry with O(1) memory, using a sparse vector with two nonzero entries.
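By a sparse representation I mean something like this (again just my own illustration):

```julia
using SparseArrays

# The same per-element adjoint as above, but stored with only two nonzero
# entries, i.e. O(1) memory per comprehension element.
sparse_element_adjoint(Δ, i, n) = sparsevec([i, i+1], [Δ, -Δ], n)
```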
I haven’t been able to parse the Zygote code well enough to work out which of these is actually the case. More generally, my question is: will Zygote always be smart enough not to use orders of magnitude more memory than necessary when differentiating through list comprehensions? Or should I be wary of this issue and try to use maps, broadcasting, etc. whenever possible?
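For reference, the kind of rewrite I have in mind as the “safe” alternative is something like the following, which (as far as I understand) Zygote can handle with its array-level indexing rules instead of differentiating a closure per element:

```julia
# Same computation as f above, but with slices and elementwise subtraction
# instead of a list comprehension.
f_sliced(x) = sum(x[1:2:end-1] - x[2:2:end])
```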