Zygote Performance

At the moment I implement gradients for all my models, because I have not yet found that Julia’s AD packages come anywhere close to my fairly naive hand-optimised gradients. (The only exception is ForwardDiff on very small problems when used with StaticArrays!)

But every few months I explore what is new and try out some basic benchmarks. This time, I decided to give Zygote.jl a go; here are the results for taking the gradient of a map f : R^100 -> R.

# function evaluation:
  884.212 ns (0 allocations: 0 bytes)
# manual gradient
  1.928 μs (1 allocation: 896 bytes)
# manual gradient, with pre-allocated gradient storage
  1.771 μs (0 allocations: 0 bytes)
# naive Zygote : gradient(f, x)
  7.850 μs (333 allocations: 13.66 KiB)
# Zygote gradient with a "hack" (see main code below)
  2.738 μs (8 allocations: 1.11 KiB)

The full code for this hugely oversimplified, but (I think) still somewhat representative, test is below. The good news: Zygote is only about a factor of 4 slower than manual gradients in this example; in the past it was more like a factor of 20-30. I am really pleased with this and will now start running more extensive tests with some more realistic models (this will take some time…). Even better: there is a hack that avoids some allocations (it improves the adjoint for sum), and with it the cost drops to about 1.5 times that of the manual implementation.
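
For reference, the slowdown factors quoted here can be read off the timings listed above, taking the manual gradient with pre-allocated storage as the baseline (the subscripts are labels introduced only for this note):

```latex
\frac{t_{\text{Zygote}}}{t_{\text{manual}}}
  = \frac{7.850\,\mu\text{s}}{1.771\,\mu\text{s}} \approx 4.4,
\qquad
\frac{t_{\text{hack}}}{t_{\text{manual}}}
  = \frac{2.738\,\mu\text{s}}{1.771\,\mu\text{s}} \approx 1.5
```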

This raises some questions:

  • Is this “hack” indicative of the performance improvements that I can still expect?
  • Are there any “official” ways to maybe avoid memory allocations altogether? I’m happy managing the pre-allocation and using a less elegant interface.
  • Any other suggestions of what I should look at if I want performant gradients with Zygote?

using Zygote, BenchmarkTools
using Zygote: @adjoint

ρfun(r) = exp(-3*(r+1))
dρfun(r) = -3 * ρfun(r)

function eam(x)
   ρ = sum(ρfun, x)
   return ρ^2 + ρ^3 + 0.25 * ρ^4
end

# manual gradient, written into the pre-allocated array g
function deam!(g, x)
   fill!(g, 0.0)
   N = length(x)
   ρ = sum(ρfun, x)
   dF = 2*ρ + 3*ρ^2 + ρ^3
   for n = 1:N
      g[n] = dρfun(x[n]) * dF
   end
   return g
end

# --------------------------------------------------
# Using the workaround from
#   https://github.com/FluxML/Zygote.jl/issues/292
function sum2(op,arr)
	return sum(op,arr)
end
function sum2adj( Δ, op, arr )
	n = length(arr)
	g = x->Δ*Zygote.gradient(op,x)[1]
	return ( nothing, map(g,arr))
end
@adjoint function sum2(op,arr)
	return sum2(op,arr),Δ->sum2adj(Δ,op,arr)
end
function eam2(x)
   ρ = sum2(ρfun, x)
   return ρ^2 + ρ^3 + 0.25 * ρ^4
end
# --------------------------------------------------

# benchmark script
# ----------------

deam(x) = deam!(zeros(length(x)), x)
zeam(x) = gradient( eam, x )[1]
zeam2(x) = gradient( eam2, x )[1]

x = rand(100)
g = rand(100)
@show sqrt(sum((zeam(x) - deam(x)).^2))

@btime eam($x);
@btime deam($x);
@btime deam!($g, $x);
@btime zeam($x);
@btime zeam2($x);
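
For readers checking deam! against eam, the hand-written gradient follows from the chain rule. In the sketch below, \varphi stands for ρfun and F for the outer polynomial; both symbols are introduced here for illustration and do not appear in the code.

```latex
\rho(x) = \sum_{n=1}^{N} \varphi(x_n), \qquad
\varphi(r) = e^{-3(r+1)}, \qquad
F(\rho) = \rho^2 + \rho^3 + \tfrac{1}{4}\rho^4

\varphi'(r) = -3\,\varphi(r), \qquad
F'(\rho) = 2\rho + 3\rho^2 + \rho^3

\frac{\partial}{\partial x_n} F(\rho(x))
  = F'(\rho)\,\varphi'(x_n)
  = -3\,\bigl(2\rho + 3\rho^2 + \rho^3\bigr)\,\varphi(x_n)
```

Here F'(\rho) is the quantity dF in deam!, and \varphi'(x_n) is dρfun(x[n]).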

I should also like to add that - although my model function has nothing to do with ANNs - its structure is so reminiscent of a NN, that I was very surprised that Zygote wasn’t fully optimised for it. In general, I’d be interested to learn what the difficulties are, if anybody can point me to the most important issues or code?

If I understand correctly, the AD ecosystem is still very much in flux (pun intended). There is an ongoing effort to switch to ChainRules (at https://github.com/FluxML/Zygote.jl/pull/291), which should improve adjoints for some computations (at least, that’s what @tkf says in https://github.com/JuliaNLSolvers/NLsolve.jl/issues/205), so it might be interesting to revisit this test once that’s done. But it’s definitely starting to work for non-trivial code!


I’m not sure about that. NNs are always really big O(n^3) operations. This looks like code which uses scalar operations instead of BLAS calls. It’s really easy to produce a fast adjoint for code that is one giant matmul, but it’s hard to generate good adjoint code for scalarized code.

So FWIW, 4x is really really good if I’m interpreting your code correctly: things like PyTorch and Tensorflow Eager have a per-op overhead of like 500ns, which corresponds to about a 100x-500x overhead on scalar operations. I saw 10x with Zygote last time I was looking, so if it’s down to 4 we should be happy. That last 4x might be due to something difficult like heap allocation of the forward pass.


As I said - I’m really pleased about 4x - but it’s actually nowhere near enough for my application domain.

I see your point, though: the memory access is missing entirely.


I just wouldn’t ever expect it to match a handcoded adjoint on a scalar code. It caches the values from the forward pass onto the heap, which will kill the speed of something that only does cheap operations.

(I do know of someone who is looking into a way around this though, but not necessarily with Zygote)


Hopefully I’m not going too far off topic by focusing just on optimization rather than auto diff, and I’m still working on getting my libraries to a point where they are well tested and documented enough that I’d feel good about registering and promoting them (probably at least half a year away), but…

using LoopVectorization
@generated function deamloop!(g::AbstractVector{T}, x::AbstractVector{T}) where {T}
    quote
        p = zero($T)
        # first pass: store ρfun(x[n]) in g and accumulate the sum p
        @vvectorize $T for n in 1:length(g)
            pn = exp($T(-3.0)x[n]-$T(3.0))
            p += pn
            g[n] = pn
        end
        eam = p^2 * muladd(muladd(0.25, p, 1), p, 1)
        dFn3 = -3p * muladd(p + 3, p, 2)
        # second pass: scale the stored values by -3 * dF
        @vvectorize $T for n in 1:length(g)
            g[n] = dFn3 * g[n]
        end
    end
end

This yields

julia> x = rand(100);

julia> g = rand(100);

julia> eam(x)

julia> deamloop!(g, x)

julia> g'
1×100 LinearAlgebra.Adjoint{Float64,Array{Float64,1}}:
 -0.223236  -1.43648  -0.336083  -0.437743  -0.302808  -1.24111  -0.410472  -0.130023  -1.91381  -1.22319  -0.868343  -0.148723  -0.65901  -0.120098  -0.692188  -0.235971  -0.133548  -0.124561  -0.572754  -1.06415  …  -0.574718  -0.231109  -0.375294  -0.843012  -0.476337  -0.129101  -0.660548  -0.895864  -0.399504  -1.14285  -0.232456  -1.11698  -0.308696  -0.484962  -2.13588  -0.252342  -0.230118  -0.47731  -0.433896

julia> deam!(g, x)'
1×100 LinearAlgebra.Adjoint{Float64,Array{Float64,1}}:
 -0.223236  -1.43648  -0.336083  -0.437743  -0.302808  -1.24111  -0.410472  -0.130023  -1.91381  -1.22319  -0.868343  -0.148723  -0.65901  -0.120098  -0.692188  -0.235971  -0.133548  -0.124561  -0.572754  -1.06415  …  -0.574718  -0.231109  -0.375294  -0.843012  -0.476337  -0.129101  -0.660548  -0.895864  -0.399504  -1.14285  -0.232456  -1.11698  -0.308696  -0.484962  -2.13588  -0.252342  -0.230118  -0.47731  -0.433896


julia> using BenchmarkTools

julia> @benchmark deam!($g, $x)
  memory estimate:  0 bytes
  allocs estimate:  0
  minimum time:     1.637 μs (0.00% GC)
  median time:      1.652 μs (0.00% GC)
  mean time:        1.726 μs (0.00% GC)
  maximum time:     4.660 μs (0.00% GC)
  samples:          10000
  evals/sample:     10

julia> @benchmark deamloop!($g, $x)
  memory estimate:  0 bytes
  allocs estimate:  0
  minimum time:     108.497 ns (0.00% GC)
  median time:      111.417 ns (0.00% GC)
  mean time:        113.874 ns (0.00% GC)
  maximum time:     278.487 ns (0.00% GC)
  samples:          10000
  evals/sample:     930

Note, this is on a cpu with avx512. Expect about half the performance advantage with avx2.


Then how does the “hacked version” work around it?

Thanks for sharing this code - very interesting to see this, even if it doesn’t help with AD.

@Elrod - can you point me to where I can learn more about what you are using here?

Edit: sorry, saw the using statement (which I missed before)

Maybe I should single out a specific question for now: is there a way, now or in the future, to prevent all memory allocation in a call to an AD gradient (by providing work arrays)?

The “hack” is just defining the analytical solution to the gradient of a sum:

@adjoint function sum2(op,arr)
	return sum2(op,arr),Δ->sum2adj(Δ,op,arr)
end

so that way it’s just doing adjoints on array operations, increasing the operation cost and reducing the allocations.
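
Spelled out, the analytical rule behind this adjoint is the following; the symbols s, a_i and \bar{a}_i are notation introduced here and are not part of the thread's code.

```latex
s = \sum_{j} \mathrm{op}(a_j)
\quad\Longrightarrow\quad
\frac{\partial s}{\partial a_i} = \mathrm{op}'(a_i),
\qquad
\bar{a}_i = \Delta \cdot \mathrm{op}'(a_i)
```

This corresponds to sum2adj, which maps x -> Δ * gradient(op, x)[1] over the array and returns nothing for the op argument.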



@adjoint function sum(op,arr)
	return sum(op,arr),Δ->sum2adj(Δ,op,arr)
end

should be added to Flux.

Thanks - I understand what it does. I don’t understand what fundamental limitation prevents this being already implemented and other similar optimisations.

Essentially what’s going on is that Flux is generating a function that can be called at any time in the future, which encloses values from the forward pass because they will be reused in the backward pass. The question is then whether those values from the forward pass can be stack-allocated if they are simple isbits types. Since the adjoint function is a function (source-to-source is input value independent), it cannot build the values of the trace into the generated code for the backward pass (since there is no trace). But if you build the source translation via a standard Julia closure, it will heap-allocate the enclosed values since that’s what it does with references. And that’s really the only safe way to do it, since a stack is function-local, so when you do the next function call you need some way to pull the values in.

But if you are going to difference right after you do the forward pass, you could in theory do something that is memory-unsafe and check the part of memory where you’d expect the value from the other stack to have been (IIUC), and that is something you could do at the LLVM level. But that kind of trick isn’t generally accessible in Julia, and requires some extra assumptions on how it’s being used.
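
To make the closure picture concrete, here is a hand-written analogue for the eam example. It is only a sketch of the general pattern and not what Zygote actually generates; eam_with_pullback is a hypothetical name.

```julia
ρfun(r) = exp(-3 * (r + 1))

# Forward pass that returns the value together with a pullback closure.
function eam_with_pullback(x::AbstractVector{<:Real})
    ρs = map(ρfun, x)        # forward values, kept for the reverse pass
    ρ = sum(ρs)
    dF = 2ρ + 3ρ^2 + ρ^3     # derivative of the outer polynomial at ρ
    # The closure captures ρs and dF. It may be called at any later time,
    # so the captured values must outlive this function call.
    pullback(Δ) = (Δ * dF) .* (-3 .* ρs)
    return ρ^2 + ρ^3 + 0.25 * ρ^4, pullback
end

x = rand(100)
y, back = eam_with_pullback(x)
g = back(1.0)                # gradient of eam at x
```

The forward pass cannot know when, or whether, back will be called, which is why the stored values cannot simply live in its stack frame.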


One optimization opportunity deam! is not using is that it evaluates exp on the same value twice: once in ρfun and once in dρfun. I think this is something AD libraries are really good at. Every rule is written/reviewed by experts, so it embeds knowledge like this which you may accidentally forget.
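
As a rough sketch of what avoiding the duplicate exp could look like by hand (deam_reuse! is a hypothetical name, and this variant has not been benchmarked here):

```julia
ρfun(r) = exp(-3 * (r + 1))

# Hand-written gradient that calls exp only once per entry of x.
function deam_reuse!(g::AbstractVector, x::AbstractVector)
    ρ = zero(eltype(g))
    for n in eachindex(g, x)
        ρn = ρfun(x[n])      # the single exp evaluation for this entry
        g[n] = ρn            # stash the value for the second pass
        ρ += ρn
    end
    scale = -3 * (2ρ + 3ρ^2 + ρ^3)   # -3 * dF, using dρfun(r) = -3 * ρfun(r)
    for n in eachindex(g)
        g[n] *= scale
    end
    return g
end

x = rand(100)
g = similar(x)
deam_reuse!(g, x)
```

The output array g doubles as scratch storage for the ρfun values, so no extra allocation is needed.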

So maybe Zygote can be fast? I suspected the inefficiency in the adjoint definition of sum(f, xs) so the first version I tried was this:

function eam3(x)
    y = exp.(.- 3 .* (x .+ 1))
    ρ = sum(y)
    return ρ^2 + ρ^3 + 0.25 * ρ^4
end

zeam3(x) = gradient(eam3, x)[1]

… which didn’t work (see benchmarks below). I think this is because Zygote turns lazy broadcasting into eager broadcasting, so that memory allocation dominates.

So I tried using a ForwardDiff.jl-based definition of the @adjoint of broadcast. This is available through my library ChainCutters.jl, but only through BroadcastableStructs.jl at the moment:

using BroadcastableStructs: BroadcastableCallable
import ChainCutters

struct RhoFun <: BroadcastableCallable end

(::RhoFun)(x) = ρfun(x)

function eam4(x)
    ρ = sum(RhoFun().(x))
    return ρ^2 + ρ^3 + 0.25 * ρ^4
end

zeam4(x) = gradient(eam4, x)[1]

With this, the Zygote version is comparable to the hand-written version:

@btime eam($x);          # 825.692 ns (0 allocations: 0 bytes)
@btime deam($x);         # 1.810 μs (1 allocation: 896 bytes)
@btime deam!($g, $x);    # 1.686 μs (0 allocations: 0 bytes)
@btime zeam($x);         # 7.599 μs (333 allocations: 13.66 KiB)
@btime zeam2($x);        # 2.443 μs (8 allocations: 1.11 KiB)
@btime zeam3($x);        # 7.568 μs (345 allocations: 14.36 KiB)
@btime zeam4($x);        # 1.945 μs (22 allocations: 4.02 KiB)

But it still does not show the improvement over deam due to duplicated evaluation of exp. To see this, I needed to bump up the input size:

x = rand(1000)
g = rand(1000)

@btime eam($x);          # 7.938 μs (0 allocations: 0 bytes)
@btime deam($x);         # 17.662 μs (1 allocation: 7.94 KiB)
@btime deam!($g, $x);    # 16.274 μs (0 allocations: 0 bytes)
@btime zeam($x);         # 62.225 μs (3035 allocations: 126.13 KiB)
@btime zeam2($x);        # 23.422 μs (8 allocations: 8.17 KiB)
@btime zeam3($x);        # 63.060 μs (3045 allocations: 127.03 KiB)
@btime zeam4($x);        # 11.130 μs (22 allocations: 32.13 KiB)

Now zeam4 is faster than deam! :)


Actually, don’t use my libraries. The best solution probably is

@adjoint function Broadcast.broadcasted(::typeof(ρfun), x::Union{T,AbstractArray{<:T}}) where {T<:Real}
    y, back = Zygote.broadcast_forward(ρfun, x)
    y, ȳ -> (nothing, back(ȳ)...)
end

Then zeam is faster than deam (when length(x) == 1000):

@btime deam($x);  # 17.746 μs (1 allocation: 7.94 KiB)
@btime zeam($x);  # 11.919 μs (33 allocations: 32.33 KiB)

This works because sum(f, xs) is defined as sum(f.(xs)) in Zygote, while the generic broadcasted definition is known to have a performance problem (see the comments in the source code: https://github.com/FluxML/Zygote.jl/blob/ea4d1e894af775a6783f683d851bcfb5c10e25b0/src/lib/broadcast.jl#L94-L105).

I think it might be better if Zygote switches to use broadcast_forward as the default fallback (see also https://github.com/JuliaDiff/ChainRules.jl/issues/12#issuecomment-483058007). This should be reasonable as long as you don’t have an output size much smaller than “effective” input size (which can happen, e.g., when you have Ref(huge_array) in the input).
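
To illustrate why a forward-mode broadcast adjoint can be cheap, here is a self-contained sketch that uses a hand-rolled dual number in place of ForwardDiff's. It does not reproduce Zygote.broadcast_forward; MiniDual and broadcast_with_pullback are hypothetical names, and only the operations needed by ρfun are defined.

```julia
# Minimal dual number: a value together with its derivative.
struct MiniDual
    val::Float64
    der::Float64
end

# Only the operations that ρfun needs.
Base.:+(a::MiniDual, b::Real) = MiniDual(a.val + b, a.der)
Base.:*(a::Real, b::MiniDual) = MiniDual(a * b.val, a * b.der)
Base.exp(a::MiniDual) = (e = exp(a.val); MiniDual(e, e * a.der))

ρfun(r) = exp(-3 * (r + 1))

# One broadcast over duals yields both f.(x) and the elementwise derivative,
# with a single exp per element.
function broadcast_with_pullback(f, x::AbstractVector{<:Real})
    duals = f.(MiniDual.(x, 1.0))
    y = [d.val for d in duals]
    dy = [d.der for d in duals]
    pullback(Δ) = Δ .* dy    # elementwise chain rule
    return y, pullback
end

x = rand(10)
y, back = broadcast_with_pullback(ρfun, x)
dx = back(ones(10))
```

The pullback only needs the derivative array dy, which has the same size as the output; that is consistent with the caveat above about outputs much smaller than the effective input.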


many thanks for the details - will take me a while to digest this :).

This is a very nice workaround; even on small input sizes it gets into the right ballpark. Unfortunately, I think it is only a solution for my one very specific case? For my next model, I’ll have to start from scratch producing workarounds like this.

Also, I should say that my input size of 100 was chosen quite deliberately. I don’t call this model (a constitutive law, if you wish) once but many times (millions to billions of calls). I should have clarified this earlier.

Thank you everybody for engaging with my question. I just want to try to summarise what I think the outcome is:

Is this “hack” indicative of the performance improvements that I can still expect?

Possibly, but not clear due to technical limitations / design choices. But wait for merge of #291

Are there any “official” ways to maybe avoid memory allocations altogether? I’m happy managing the pre-allocation and using a less elegant interface.

No, and it is unlikely that this is possible.

Incidentally, I just tried to use Zygote with Vector{SVector} input, and the performance completely falls apart. But I’ll post about that in a new thread, once I’ve tried for a while by myself.


I think this is a very generic “workaround”, and it would even make sense to build it into Zygote. If you want to have fast out-of-the-box broadcasting, maybe post it in Zygote’s issue tracker?

Until Zygote has it, maybe you can do something like

abstract type FastBroadcastable end

@adjoint function Broadcast.broadcasted(f::FastBroadcastable, args::Union{Real,AbstractArray{<:Real}}...)
    y, back = Zygote.broadcast_forward(f, args...)
    y, ȳ -> (nothing, back(ȳ)...)
end

(untested) and then define your functions as callables

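struct RhoFun <: FastBroadcastable end
const ρfun = RhoFun()

# make instances of RhoFun callable (a method cannot be added to the const name ρfun directly)
(::RhoFun)(r) = exp(-3*(r+1))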