Improving pullback speed with list comprehension

Here is an MWE of the function I have that needs AD.

using Zygote, ChainRulesCore, BenchmarkTools

N = 100
function compute_many_prods(f, g)
    return [f * g * i for i = 1:N]
end
function compute_many_prods2(f, g)
    return compute_many_prods(f, g)
end
# Custom rule: take one Zygote pullback per element of the comprehension
function ChainRulesCore.rrule(::typeof(compute_many_prods),
                              f, g)
    pullbacks = [Zygote.pullback((f,g) -> f * g * i,
                                    f, g)[2] for i = 1:N]
    function many_prods_pullback(ΔΩ)
        # Each element pullback returns (df_i, dg_i); sum them element-wise
        df, dg = reduce((a,b) -> a .+ b, [pullbacks[i](ΔΩ[i]) for i = 1:N])
        return (ChainRulesCore.NoTangent(), df, dg)
    end
    return compute_many_prods(f, g), many_prods_pullback
end
f = 9; g = 12;
pb1 = Zygote.pullback(compute_many_prods, f, g)[2];
pb2 = Zygote.pullback(compute_many_prods, f, g)[2];
@btime pb1([7 for i = 1:N])
@btime pb2([7 for i = 1:N])

The results are 806.352 ns (15 allocations: 3.03 KiB) and 806.209 ns (15 allocations: 3.03 KiB).

In this post, it was suggested that control flow makes the pullback slow, so I tried writing a custom adjoint to avoid it. But it looks like this does not solve the problem.

Please note that in my problem I have a much more complicated function in the list comprehension than scalar multiplication.

@ToucheSir

Maybe I was unclear in that thread. The problem is control flow inside the body of the array comprehension, not the comprehension itself. Notice how in “Suggestions to improve Zygote performance for simple vector map/broadcast/comprehension?” (#5 by ToucheSir) I replaced the ternary inside the [ ... for ... ] with an ifelse that is still inside the [ ... for ... ].
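To make the distinction concrete, here is a small sketch. The functions with_ternary and with_ifelse and the input data are hypothetical, not taken from the linked post: both sum over a comprehension and give the same gradient, but only the first has a branch in the comprehension body. No timings are claimed here; comparing the two with @btime on realistic data is how you would check whether the branch matters.

```julia
using Zygote

# Hypothetical body with a ternary: Zygote sees branching control flow per element
with_ternary(xs) = sum([x > 0 ? x^2 : -x for x in xs])

# Same values, but ifelse is an ordinary function call: both x^2 and -x are
# evaluated and one is selected, so the body has no branch for Zygote to record
with_ifelse(xs) = sum([ifelse(x > 0, x^2, -x) for x in xs])

xs = [1.5, -2.0, 3.0, -0.5]
g1 = Zygote.gradient(with_ternary, xs)[1]
g2 = Zygote.gradient(with_ifelse, xs)[1]
```

Note that ifelse only fits when evaluating both branches is cheap and safe (no errors or side effects in the branch that is not selected).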

So you wouldn’t need a custom rule for compute_many_prods, but one for (f,g) -> f * g * i. Of course this particular example won’t be any faster, because f * g * i doesn’t use any control flow, so you’ll want a better MWE that does.
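A minimal sketch of what a rule on the body (rather than on the outer function) could look like, assuming the body is pulled out into a named function so ChainRulesCore can dispatch on it. prod_i and many_prods are hypothetical names standing in for the anonymous body and for compute_many_prods. The comprehension stays and Zygote still differentiates it; only the per-element work goes through the hand-written pullback.

```julia
using ChainRulesCore, Zygote

# Hypothetical named version of the comprehension body (f, g) -> f * g * i
prod_i(f, g, i) = f * g * i

# Rule for the body, not for the function that contains the comprehension
function ChainRulesCore.rrule(::typeof(prod_i), f::Real, g::Real, i::Integer)
    y = prod_i(f, g, i)
    function prod_i_pullback(Δ)
        # ∂y/∂f = g * i and ∂y/∂g = f * i; the index i gets no tangent
        return NoTangent(), Δ * g * i, Δ * f * i, NoTangent()
    end
    return y, prod_i_pullback
end

# Hypothetical stand-in for compute_many_prods, with N passed explicitly
many_prods(f, g, n) = [prod_i(f, g, i) for i = 1:n]

f, g, n = 9.0, 12.0, 100
y, pb = Zygote.pullback(many_prods, f, g, n)
df, dg, dn = pb(fill(7.0, n))   # dn is nothing: n is an integer count
```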

OK, I’m afraid it may be impossible to remove the control flow. If I were to implement a pullback myself, wouldn’t I just be writing exactly what Zygote is already doing? Here is a much less minimal example:

using DynamicPolynomials
using Zygote
@polyvar x y;
N = 10;
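# For each i: differentiate f i times in x and n-i times in y, and g i times
# in y and n-i times in x, then sum the products
function transvect(f::AbstractPolynomialLike, g::AbstractPolynomialLike, n)
    sum([differentiate(differentiate(f, x, i), y, n-i) *
        differentiate(differentiate(g, y, i), x, n-i) for i = 1:n])
end
# One transvect result per order n = 1:N
function compute_many_diffs(f::AbstractPolynomialLike, g::AbstractPolynomialLike)
    return [transvect(f, g, n)  for n = 1:N]
end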
f = (x + y)^2; g = (x + 2*y)^2;
pb = Zygote.pullback(compute_many_diffs, f, g)[2]

It is a little confusing because “differentiate” in DynamicPolynomials takes derivatives with respect to the variables x and y, but I’m interested in optimizing over the coefficients of the polynomials. So I think of the polynomial coefficients as the parameters.
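One way to make that split explicit, sketched with a fixed quadratic monomial basis chosen only for illustration (basis and poly_from_coeffs are hypothetical helpers, not DynamicPolynomials API): the coefficient vector c is the parameter, while differentiate still acts on x and y.

```julia
using DynamicPolynomials

@polyvar x y

# Assumed fixed monomial basis for quadratic forms in x and y
basis = [x^2, x * y, y^2]

# Hypothetical helper: the coefficient vector c is what gets optimized
poly_from_coeffs(c) = sum(c .* basis)

f = poly_from_coeffs([1, 2, 1])   # (x + y)^2
g = poly_from_coeffs([1, 4, 4])   # (x + 2y)^2

# differentiate is still with respect to the variable x, not the coefficients
dfdx = differentiate(f, x)
```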

Yes and no. Zygote has to be very conservative when it encounters branching control flow, because it doesn’t know which branch will be taken or how many times a loop will run. It also has no knowledge of things such as function return types that could help it generate optimal data structures for the backward pass, so it’s forced to use a type-unstable, allocation-heavy, lowest-common-denominator approach.

Assuming differentiate refers to the functions defined here, I’m sure you could come up with more specific custom rules than the general-purpose ones Zygote generates. Custom rules help not because they get rid of control flow, but because they hide it from Zygote, which avoids the suboptimal path described above.
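As an illustration of that last point, here is a sketch with a hypothetical function clipped_sumsq (not one of the DynamicPolynomials functions) whose forward pass and hand-written pullback both contain a loop with a branch. Zygote calls the rrule as an opaque unit and never traces either loop, so the control flow is still there, just hidden from it. The same pattern would apply to a rule for transvect or differentiate, but that version is not worked out here.

```julia
using ChainRulesCore, Zygote

# Hypothetical function whose body is a loop with a branch inside
function clipped_sumsq(xs::AbstractVector{<:Real}, c::Real)
    s = zero(float(eltype(xs)))
    for x in xs
        s += x > c ? c^2 : x^2
    end
    return s
end

# Hand-written rule: the pullback repeats the loop and branch, but Zygote only
# sees a call to clipped_sumsq_pullback and does not differentiate through it
function ChainRulesCore.rrule(::typeof(clipped_sumsq), xs::AbstractVector{<:Real}, c::Real)
    y = clipped_sumsq(xs, c)
    function clipped_sumsq_pullback(Δ)
        T = float(eltype(xs))
        dxs = zeros(T, length(xs))
        dc = zero(T)
        for (k, x) in enumerate(xs)
            if x > c
                dc += 2c       # clipped element contributes c^2
            else
                dxs[k] = 2x    # unclipped element contributes x^2
            end
        end
        return NoTangent(), Δ .* dxs, Δ * dc
    end
    return y, clipped_sumsq_pullback
end

xs = [0.5, 2.0, -1.0]
gx, gc = Zygote.gradient(clipped_sumsq, xs, 1.0)
```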