Zygote Performance (Again...)

On simple reduction tasks, Zygote seems to perform exceptionally poorly. I would be grateful to hear why, and what alternative patterns I should use to avoid these problems. I’ve started implementing custom adjoints via rrule, but if I do this for every little piece of code, I might as well implement all gradients myself and forget about Zygote.

julia> using BenchmarkTools, Zygote
       f(x) = mapreduce(xi -> xi^2, +, x)
       x = rand(100)
       Zygote.gradient(f, x)
       @btime f($x)
       @btime Zygote.gradient($f, $x)
  11.128 ns (0 allocations: 0 bytes)
  480.792 μs (3369 allocations: 191.02 KiB)

We can do a little better like this:

julia> using BenchmarkTools, Zygote
       _sq(xi) = xi^2
       f(x) = sum(_sq, x)
       x = rand(100)
       Zygote.gradient(f, x)
       @btime f($x)
       @btime Zygote.gradient($f, $x)
  14.658 ns (0 allocations: 0 bytes)
  18.477 μs (392 allocations: 12.30 KiB)

but why!?!

And even better like this:

julia> using BenchmarkTools, Zygote
       f(x) = sum(x.^2)
       x = rand(100)
       Zygote.gradient(f, x)
       @btime f($x)
       @btime Zygote.gradient($f, $x)
  117.833 ns (1 allocation: 896 bytes)
  239.328 ns (3 allocations: 1.77 KiB)

But even that third example doesn’t come anywhere near the 20-40ns that I would expect here.
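For reference, the gradient of f(x) = Σᵢ xᵢ² is just 2x, so a hand-written gradient is a single broadcast; presumably that is the kind of baseline behind the 20-40 ns expectation. A minimal sketch (the names `grad_f` and `grad_f!` are illustrative, and timings will vary by machine):

```julia
using BenchmarkTools

# Analytic gradient of f(x) = sum(x.^2): ∂f/∂xᵢ = 2xᵢ
grad_f(x) = 2 .* x

# In-place variant: the fused broadcast writes into a preallocated buffer,
# so it does not allocate
function grad_f!(g, x)
    g .= 2 .* x
    return g
end

x = rand(100)
g = similar(x)
@btime grad_f!($g, $x)
```

The in-place version avoids even the one output allocation that any `gradient` call returning a fresh array has to make.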

I’m a little confused: it seems like your second example is both the fastest and under 20 ns?

The Zygote timing is in microseconds, not nanoseconds.

It is 18 μs, not 18 ns.

Are you saying that Zygote is meant for the ms regime and not the ns regime? That’s OK!

But what if my code is complex enough that evaluating the model takes milliseconds to seconds, yet it goes many layers deep into inner functions that each evaluate in around 100 ns? That is where I currently lose a factor of 1000 or more in performance.

I’m working a lot with custom adjoints. Here, I really just want to understand what is going on.

My mistake, read it multiple times and somehow still missed that. Blame it on the early morning…

Back to the topic at hand: the long and short of it seems to be that there aren’t optimized adjoints in place for mapreduce-like functions (including sum(f, ...)) yet. This becomes pretty clear from looking at the currently defined adjoints. Higher-order functions in general are not well optimized for AD right now, so it’s no surprise that they perform terribly. You can verify this for yourself by pulling up Cthulhu.jl (JuliaDebug/Cthulhu.jl on GitHub) and stepping into the pullback.
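One way to see where the time goes is to split the call into its forward and reverse halves with `Zygote.pullback`, which returns the value and a closure for the backward pass. A sketch; the Cthulhu.jl line at the end is interactive and left commented out:

```julia
using BenchmarkTools, Zygote

f(x) = mapreduce(xi -> xi^2, +, x)
x = rand(100)

# pullback runs the forward pass and returns the value plus a closure `back`
y, back = Zygote.pullback(f, x)

# Time the two halves separately to see where the cost lives
@btime Zygote.pullback($f, $x)   # forward pass + recording of the pullback
@btime $back(1.0)                # reverse pass only

# back returns a tuple with one gradient per argument of f
g = back(1.0)[1]

# To step into what the pullback actually does (interactive, needs Cthulhu.jl):
# using Cthulhu; @descend back(1.0)
```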

This is usually not a problem in ML code because people are used to poor performance or no support in other AD systems. Hence workarounds like materializing the input before calling sum (which is what your third example does) so that it hits an existing rule, such as the one in ChainRules.jl’s mapreduce.jl (JuliaDiff/ChainRules.jl, commit 65833a19629ee890d30edc96b61bbdbc4e0da72f) or in Zygote.jl’s array.jl (FluxML/Zygote.jl, commit 12f5c1d75eeaa8c7a818f2db7f8d082956c00cac).
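The materialize-first workaround is not limited to `x.^2`: for any scalar callback you can broadcast first and then reduce, so Zygote differentiates one array-level broadcast plus a plain `sum` rather than stepping through the generic `mapreduce` code. A sketch with a hypothetical callback `h`; both versions should agree with the analytic derivative, and only their cost differs:

```julia
using Zygote

h(xi) = sin(xi) * xi           # hypothetical cheap scalar callback

f_map(x) = mapreduce(h, +, x)  # higher-order: no specialized rule
f_mat(x) = sum(h.(x))          # materialize first, then reduce a plain array

x = rand(100)
g_map = Zygote.gradient(f_map, x)[1]
g_mat = Zygote.gradient(f_mat, x)[1]

# Analytic derivative of h, for checking: h'(x) = cos(x) * x + sin(x)
g_ref = cos.(x) .* x .+ sin.(x)
```

The price is an intermediate array of length(x) in the forward pass, which is the single 896-byte allocation visible in the third benchmark.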

On the bright side, things are slowly changing for the better. ChainRules.jl pull request #441, “Rule for sum(f, xs)” by oxinabox, is the first step toward better AD of higher-order functions. I don’t remember whether Diffractor (the next-generation AD system) will help with this, but it’s worth keeping an eye on as well. In the meantime, your options are to a) avoid higher-order functions, or b) look into AD systems other than Zygote (e.g. ForwardDiff).
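For option b), ForwardDiff handles the original higher-order version as ordinary Julia code, because it pushes dual numbers through `f` instead of needing a rule for `mapreduce`. A sketch; keep in mind that forward mode processes the inputs in chunks, so its cost grows with `length(x)`, which is fine here but not for models with very many parameters:

```julia
using ForwardDiff

f(x) = mapreduce(xi -> xi^2, +, x)   # the original higher-order version
x = rand(100)

# Dual numbers flow through mapreduce like any other number type
g = ForwardDiff.gradient(f, x)

# Reuse a precomputed configuration and an output buffer across calls
cfg = ForwardDiff.GradientConfig(f, x)
out = similar(x)
ForwardDiff.gradient!(out, f, x, cfg)
```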


Good question. What should one expect? I found this Computational Science Stack Exchange question: “How fast is automatic differentiation?”

Does anyone know a better answer to that?

Maybe I should explain why I’m hunting for performance in the first place: a collaborator has implemented a very similar model to mine (nothing to do with the model above) in TensorFlow and claims his AD gradients are only a factor of 1.5 slower than my hand-optimised ones, which I believe are within a factor of 2-3 of optimal.

I figured combining Zygote + ForwardDiff was my best chance at getting something like that in Julia. But maybe his claim is just bogus and I should test that myself.

To my knowledge, TF’s AD doesn’t support higher-order functions at all. If/when you get more info on good comparative examples, feel free to start a thread here and we can take a look at it.

You might also want to try ReverseDiff or Yota instead of Zygote in this case; they might be a bit faster.
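If you try ReverseDiff, its main speed lever is a compiled tape: record the operations once, compile, and then replay the tape for new inputs of the same size. A sketch using the original `mapreduce` version; this is only valid when the control flow of `f` does not depend on the values in `x`, since the tape replays the branches taken during recording:

```julia
using ReverseDiff

f(x) = mapreduce(xi -> xi^2, +, x)
x = rand(100)

# Record the operations f performs on an input of this shape, then compile
tape  = ReverseDiff.GradientTape(f, x)
ctape = ReverseDiff.compile(tape)

# Replay the compiled tape, writing the gradient into a preallocated buffer
g = similar(x)
ReverseDiff.gradient!(g, ctape, x)
```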

Sorry for the confusion. I was responding to the comment about the timing of the second one too tersely.

Reductions like this are one of the reasons Tullio exists. It’s not quite a factor 2 here, but at least it’s in ns:

julia> using BenchmarkTools, Tullio, ForwardDiff, Zygote

julia> ft(x) = @tullio tot := x[i]^2  threads=false avx=false;

julia> @btime ft($x);
  12.596 ns (0 allocations: 0 bytes)

julia> @btime Zygote.gradient(ft, $x);
  86.598 ns (1 allocation: 896 bytes)

julia> 86.598 / 12.596  # ≈ 6.9

julia> ftd(x) = @tullio tot := x[i]^2  grad=Dual threads=false avx=false;

julia> @btime Zygote.gradient(ftd, $x);  # sometimes more efficient
  86.863 ns (1 allocation: 896 bytes)

julia> @btime Zygote.gradient($f, $x);  # first version with mapreduce, my computer
  442.250 μs (3469 allocations: 213.92 KiB)

julia> @btime f($x);
  17.451 ns (0 allocations: 0 bytes)

Thanks for suggesting those alternative packages. I understand that there are many alternatives. I chose Zygote because it was supposed to be the future of AD, and I cannot afford to invest in something that will be replaced in a few months.

But then I just learned that Zygote will be replaced with something called Diffractor?

So what should I be investing in? Will I be relatively safe if I focus my development around ChainRules and use Zygote or any of these others only sparingly?

Thank you, that’s an important distinction that I just hadn’t made, because I’m so used to these in Julia that I don’t even think about them anymore. But would you say that avoiding higher-order functions will automatically lead to faster Zygote gradients?

You shouldn’t have to invest in any one AD system. The AD systems (will) share a rule system, ChainRules.jl, so the modifications you make (i.e. added rules) carry over to the other systems nicely.
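As an illustration of writing a rule once against ChainRulesCore, here is a sketch for a hypothetical `sumsq` function; Zygote picks up the `rrule` instead of tracing into the function body, and other ChainRules-aware systems can reuse the same definition:

```julia
using ChainRulesCore, Zygote

# Hypothetical example function: sum of squares
sumsq(x) = sum(abs2, x)

# Reverse rule: returns the primal value and a pullback closure
function ChainRulesCore.rrule(::typeof(sumsq), x::AbstractArray{<:Real})
    y = sumsq(x)
    function sumsq_pullback(ȳ)
        # d/dx sum(x.^2) = 2x, scaled by the incoming cotangent ȳ;
        # the function object itself has no tangent
        return NoTangent(), 2 .* unthunk(ȳ) .* x
    end
    return y, sumsq_pullback
end

x = rand(100)
g = Zygote.gradient(sumsq, x)[1]
```

ChainRulesTestUtils.jl’s `test_rrule(sumsq, x)` can check a rule like this against finite differences.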


I would say any function where a cheap-to-compute callback is called many, many times (e.g. sum(f, x)) doesn’t perform great right now relative to the forward pass. If f is expensive to compute or the number of elements is small, it’ll be less of an issue.
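For the cheap-callback case specifically, one pattern that combines Zygote with ForwardDiff is a rule that computes the per-element derivatives in forward mode, rather than letting Zygote build a pullback for every call of `f`. This is a sketch under stated assumptions: `fastsum` is a hypothetical helper, `f` must map a real number to a real number, and the rule returns no gradient for `f` itself, so any parameters captured inside `f` would be silently ignored:

```julia
using ChainRulesCore, ForwardDiff, Zygote

# Hypothetical helper: sum(f, x) with a custom reverse rule
fastsum(f, x) = sum(f, x)

# Assumption: f is Real -> Real and captures nothing that needs a gradient
function ChainRulesCore.rrule(::typeof(fastsum), f, x::AbstractArray{<:Real})
    y = fastsum(f, x)
    function fastsum_pullback(ȳ)
        # One forward-mode derivative per element (broadcast treats f as a scalar)
        dfdx = ForwardDiff.derivative.(f, x)
        # No tangent for fastsum or for f; cotangent for x
        return NoTangent(), NoTangent(), unthunk(ȳ) .* dfdx
    end
    return y, fastsum_pullback
end

_sq(xi) = xi^2
x = rand(100)
g = Zygote.gradient(v -> fastsum(_sq, v), x)[1]
```

As written, the pullback evaluates `f` again inside `ForwardDiff.derivative`; computing the value and derivative together in one dual-number pass during the forward call would avoid that extra work.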

Is there some kind of performance guide for AD in general and Zygote in particular, similar to the performance tips in the Julia manual?