Suggestions to improve Zygote performance for simple vector map/broadcast/comprehension?

I've got some code that builds some moderate-sized vectors, which I'd like to differentiate through. Here's an example that builds these vectors with a comprehension (the issue is the same if I use map or broadcast):

```julia
using Zygote, BenchmarkTools

build_vector(x) = [i<500 ? x : 0 for i=1:1000]
@btime gradient(x -> sum(build_vector(x)), 1) # ~2 ms
```

This is in a performance-critical inner loop, and it turns out to be ~1000 times slower than writing the adjoint by hand:

```julia
using Zygote: @adjoint

build_vector_with_adjoint(x) = build_vector(x)
@adjoint function build_vector_with_adjoint(x)
    y = build_vector(x)
    function back(Δ)
        # ∂y[i]/∂x is 1 for i < 500 and 0 otherwise, so the cotangent for x is b ⋅ Δ
        b = [i<500 ? 1 : 0 for i=1:1000]
        (b'Δ,)
    end
    y, back
end
@btime gradient(x->sum(build_vector_with_adjoint(x)), 1) # ~2 μs
```

I don't think I'm cheating too badly with this custom adjoint; it seems like it's basically what Zygote should be writing for me. Profiling does show some dynamic dispatch deep in the Zygote call tree, but I'm not familiar enough with the internals to make sense of it. The comments in Zygote's broadcast.jl source allude to performance hits and generic fallbacks; maybe I'm inadvertently hitting one of those here? Any other suggestions for gaining some performance without writing custom adjoints (which in my real, non-MWE code would be far more painful than here)? Thanks.

Here are a few ideas, depending on how close this is to your non-MWE:

```julia
julia> using Zygote, ForwardDiff, BenchmarkTools

julia> build_vector(x) = [i<500 ? x : zero(x) for i=1:1000];

julia> @btime Zygote.gradient(x -> sum(build_vector(x)), 1)
  2.490 ms (15573 allocations: 644.61 KiB)
(499,)

julia> @btime ForwardDiff.derivative(x -> sum(build_vector(x)), 1)
  1.470 μs (1 allocation: 15.75 KiB)
499

julia> @btime Zygote.gradient(x -> sum(Zygote.forwarddiff(build_vector,x)), 1)
  6.590 μs (28 allocations: 40.34 KiB)
(499,)
```

Thanks, that's helpful to see. In my non-MWE, x is ~10-dimensional, so I wanted to use reverse mode, but maybe forward mode still wins out. I can try it.
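For the ~10-dimensional case, here is a minimal sketch of the three routes side by side, assuming a hypothetical `build_vector_nd` that stands in for the real function (10 parameters in, 1000 values out, written branch-free with `ifelse`). ForwardDiff's cost grows with the number of inputs, so whether forward mode still wins at ~10 inputs is something to benchmark on the real code rather than assume.

```julia
using ForwardDiff, Zygote

# Hypothetical stand-in for the real function: 10 parameters in, 1000 values out.
# Elements with i < 500 equal sum(p); the rest are zero (no branch, via ifelse).
function build_vector_nd(p::AbstractVector)
    s = sum(p)
    return [ifelse(i < 500, s, zero(s)) for i in 1:1000]
end

loss(p) = sum(build_vector_nd(p))
p0 = rand(10)

# Pure forward mode: every number carries partials for all 10 inputs.
g_fd = ForwardDiff.gradient(loss, p0)

# Mixed: Zygote drives the outer gradient, but build_vector_nd itself is
# differentiated in forward mode via Zygote.forwarddiff.
g_mix = Zygote.gradient(p -> sum(Zygote.forwarddiff(build_vector_nd, p)), p0)[1]

# Pure reverse mode, for comparison.
g_rev = Zygote.gradient(loss, p0)[1]
```

Each parameter contributes once to each of the 499 nonzero elements, so all three gradients should be `fill(499.0, 10)`; timing them with `@btime` on the real function is what decides between them.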

Is there a way to rewrite the original MWE (a replacement for the list comprehension) so that reverse mode is still fast? I have a similar problem with the same structure but a much more complicated list comprehension that has many parameters.

It depends on what code is running inside the comprehension, but probably. The key performance pitfall in the OP is that it uses control flow (conditionals and loops). Zygote can't generate efficient code for functions that use control flow, so you'll see both slower speeds and more allocations. Array comprehensions/map/broadcast over such functions are a worst-case scenario, because the overhead is multiplied by the number of elements processed.

We can show the impact of removing control flow by using a branchless conditional (ifelse) instead of the ternary:

```julia
build_vector2(x) = [ifelse(i<500, x, 0) for i=1:1000]

julia> @btime gradient(x -> sum(build_vector(x)), 1);
  830.159 μs (6559 allocations: 285.09 KiB)

julia> @btime gradient(x -> sum(build_vector2(x)), 1);
  17.263 μs (44 allocations: 119.25 KiB)
```
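Taking the branchless idea one step further, a sketch (not benchmarked here) that replaces the per-element comprehension with a precomputed 0/1 mask and a single broadcast. `MASK` and `build_vector3` are hypothetical names; the point is that Zygote then sees a couple of array-level operations instead of 1000 scalar ones.

```julia
using Zygote

# Hypothetical precomputed mask: 1.0 where i < 500, else 0.0.
# It does not depend on x, so it is built once, outside the differentiated code.
const MASK = [ifelse(i < 500, 1.0, 0.0) for i in 1:1000]

# x scales the mask; there is no per-element branch for Zygote to trace.
build_vector3(x) = x .* MASK

g3 = gradient(x -> sum(build_vector3(x)), 1.0)[1]
```

This only works when the branch condition does not depend on `x`; when it does, `ifelse` inside the comprehension (as above) is the more general tool.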

However, some functions must use control flow. In that case, you have a few options:

  1. Use the non-differentiable helpers documented in API · ChainRules around functions/code blocks that use control flow but don't need to be differentiated.
  2. Define your own rrule(s) for functions that use control flow. The advice in Writing good rules · ChainRules applies as always, but one additional concern here is making sure the type of the returned pullback function is stable. If it isn't, you'll run into many of the same issues Zygote does.
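A minimal sketch of both options applied to the OP's function, assuming ChainRulesCore is installed (Zygote picks up rrules defined with it). `cutoff_mask`, `build_vector_masked`, and `build_vector_rr` are hypothetical names. Option 1 marks the branchy mask builder with `@non_differentiable` so AD never traces into it; option 2 writes an rrule whose pullback closes over a single concretely typed vector, so the pullback's type is inferable.

```julia
using ChainRulesCore, Zygote

# Builds a 0/1 mask with a ternary branch; the result does not depend on any
# differentiated input.
cutoff_mask(n::Integer, cutoff::Integer) = [i < cutoff ? 1.0 : 0.0 for i in 1:n]

# Option 1: tell AD not to trace into cutoff_mask at all.
@non_differentiable cutoff_mask(::Integer, ::Integer)

build_vector_masked(x) = x .* cutoff_mask(1000, 500)

# Option 2: a hand-written rrule for the original, branchy comprehension.
build_vector_rr(x) = [i < 500 ? x : zero(x) for i in 1:1000]

function ChainRulesCore.rrule(::typeof(build_vector_rr), x::Real)
    y = build_vector_rr(x)
    # Vector{Float64}, captured and never reassigned, so the closure's type is concrete
    mask = cutoff_mask(1000, 500)
    function build_vector_rr_pullback(ȳ)
        Δ = unthunk(ȳ)
        # ∂y[i]/∂x is mask[i], so the cotangent for x is the dot product of mask and Δ
        return NoTangent(), sum(mask .* Δ)
    end
    return y, build_vector_rr_pullback
end

g1 = gradient(x -> sum(build_vector_masked(x)), 1.0)[1]
g2 = gradient(x -> sum(build_vector_rr(x)), 1.0)[1]
```

Both gradients should come out as 499.0. `@code_warntype` on the pullback returned by `rrule(build_vector_rr, 1.0)` is one way to check that its type is stable.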

Is there a good reference for code that Zygote can generate efficient code for?

That's difficult to quantify because it's mostly subtractive (if you do X, things get slower). Aside from things that just aren't supported, the biggest performance pitfall I've seen outside of control flow is repeatedly indexing into, or taking a view of, a small section or a single element of a large array. Each such access allocates O(length(array)) on the backward pass. That's less a codegen issue than a runtime performance one, however. Accessing and setting properties of mutable structs within sight of Zygote will also be type-unstable, though there you only pay for the dynamic dispatch (whereas the pullbacks generated for control flow can allocate quite a bit more besides).
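A small illustration of the indexing pitfall, assuming only Zygote. `sum_by_indexing` and `sum_whole` are hypothetical functions; the first also contains a loop, so it pays the control-flow cost too. The two agree on the gradient, and comparing `@btime` or `@allocated` on a large input is how you would see the gap. No numbers are claimed here, and the exact allocation behaviour of indexing rules has changed across Zygote/ChainRules versions.

```julia
using Zygote

# Reads one element at a time. Per the post, each x[i] inside differentiated
# code can cost O(length(x)) on the backward pass, so this loop can approach
# O(length(x)^2) work and allocation in total.
function sum_by_indexing(x::AbstractVector)
    s = zero(eltype(x))
    for i in eachindex(x)
        s += x[i]
    end
    return s
end

# Whole-array alternative: one call, handled by one array-level rule.
sum_whole(x::AbstractVector) = sum(x)

x = [1.0, 2.0, 3.0]
g_idx = gradient(sum_by_indexing, x)[1]
g_whole = gradient(sum_whole, x)[1]
```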
