Zygote terribly slow. What am I doing wrong?

Hi all,

I have been using ReverseDiff.jl for a project for quite some time. Now I have switched to Zygote because I need to compute derivatives with respect to custom types, and this seems straightforward with Zygote (see Evaluate gradient (backpropagation) on a user defined type).

The problem is that Zygote seems to be at least an order of magnitude slower than ReverseDiff.jl. Here is an MWE:

import Zygote
import ReverseDiff
using BenchmarkTools

# Prepare some random data
ns = 20
model(x,p) = p[1] + p[2]*x
xs = randn(ns)
ys = 1.0 .+ 2.0.*xs .+ 0.1 .* randn(ns)

# Simple loss function 
function loss(p, x, y)
    # Sum of squared residuals, each scaled by 1/0.01
    f = (y[1] - model(x[1], p))^2/0.01
    for i in 2:length(x)
        f = f + (y[i] - model(x[i], p))^2/0.01
    end
    return f
end

p = randn(2)
res1 = similar(p)
res2 = similar(p)

println("## ")
println("# ReverseDiff.jl Using optimized ReverseDiff")
println("## ")
cf_tape = ReverseDiff.compile(ReverseDiff.GradientTape(p -> loss(p,xs,ys), (p)))
@btime ReverseDiff.gradient!(res1, cf_tape, p)

println("## ")
println("# Zygote.jl ")
println("## ")
@btime res2 .= Zygote.gradient(p -> loss(p,xs,ys), p)[1]

println("## ")
print("# Check (should be zero): ")
println(maximum(abs.(res1 .- res2)))

Output:
# ReverseDiff.jl Using optimized ReverseDiff
  2.452 μs (0 allocations: 0 bytes)
# Zygote.jl 
  35.200 μs (919 allocations: 42.83 KiB)
# Check (should be zero): 0.0

Is there any way in Zygote to record the computational graph for the gradient of a specific function (similar to the tape of ReverseDiff.jl)? I suspect I am doing something wrong, but I cannot find any information about this in the documentation.


This is Zygote’s least favourite type of function, sadly. Indexing inside a loop means that, for every element accessed, it has to build and accumulate a gradient the size of the whole array x.
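To make that cost concrete, here is a minimal sketch of what a pullback for scalar indexing has to do. `getindex_pullback` and `accumulate_index_grads` are hypothetical names and much simpler than Zygote’s real rule; the point is only that the gradient of `x[i]` has the shape of the whole array, so a loop over `n` indexing operations allocates `n` full-length arrays.

```julia
# Hypothetical, simplified pullback for scalar indexing `x[i]`
# (an illustration, not Zygote's actual implementation).
function getindex_pullback(x::AbstractVector, i::Integer)
    y = x[i]
    function back(dy)
        dx = zero(x)   # a full-length array, allocated on every call
        dx[i] = dy     # only one entry is nonzero
        return dx
    end
    return y, back
end

# Accumulating the gradients of x[1], ..., x[n] this way allocates n arrays
# of length n, i.e. O(n^2) work for what is really an O(n) gradient.
function accumulate_index_grads(x::AbstractVector)
    total = zero(x)
    for i in eachindex(x)
        _, back = getindex_pullback(x, i)
        total .+= back(1.0)
    end
    return total
end
```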

There is no special GradientTape concept in Zygote. I don’t know ReverseDiff that well; it may be able to avoid some of this even without compiling the tape because, by tracing forwards, it knows that you only care about the gradient with respect to p, not x. Zygote starts at the end and does not know this.

For this problem, ForwardDiff is even faster: under 400 ns for me.
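A minimal sketch of the ForwardDiff version, assuming ForwardDiff.jl is installed. `GradientConfig` and `gradient!` are ForwardDiff’s own API; building the config once outside the timed call means repeated calls reuse the same work buffers. The closure name `lossp` is only for illustration, and no timings are claimed here.

```julia
import ForwardDiff

model(x, p) = p[1] + p[2]*x

function loss(p, x, y)
    f = (y[1] - model(x[1], p))^2 / 0.01
    for i in 2:length(x)
        f += (y[i] - model(x[i], p))^2 / 0.01
    end
    return f
end

xs = randn(20)
ys = 1.0 .+ 2.0 .* xs .+ 0.1 .* randn(20)
p = randn(2)
res = similar(p)

lossp = q -> loss(q, xs, ys)                 # only the parameters are differentiated
cfg = ForwardDiff.GradientConfig(lossp, p)   # preallocate dual-number buffers once
ForwardDiff.gradient!(res, lossp, p, cfg)    # writes the gradient into res
# To time it with BenchmarkTools: @btime ForwardDiff.gradient!($res, $lossp, $p, $cfg)
```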



Thanks for the comments. Unfortunately, the real case has O(500) parameters, and I suspect that forward-mode AD will then also be significantly slower (it will require O(500) passes). Moreover, with ForwardDiff.jl it is also difficult to evaluate the derivatives on custom types…

I am new to Julia for these things, and I have to say that I am surprised. Given how widely Zygote.jl is used in ML applications, this is unexpected: this kind of loss function (with a loop over the data) is basically the cornerstone of any least-squares/Bayesian inference problem. Do you happen to know what approach Flux.jl takes? Does it just live with the allocations and the slower speed?


Maybe O(500/12) passes: ForwardDiff propagates a whole chunk of different sensitivities forwards in one pass. 500 parameters is not a hopeless size; it is worth trying.
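To see why the pass count is roughly 500/12 rather than 500, here is a toy forward-mode sketch in plain Julia. `MultiDual` and `chunked_gradient` are hypothetical names, not ForwardDiff’s API: each dual number carries `N` partial derivatives, so one pass through the loss yields `N` gradient entries and `n` parameters need `cld(n, N)` passes. If I recall correctly, 12 is ForwardDiff’s default maximum chunk size, which would explain the number in the reply.

```julia
# Toy forward-mode AD with chunking (illustration only; not ForwardDiff's API).
# Each MultiDual carries N partial derivatives, so one forward pass through the
# loss yields N entries of the gradient.
struct MultiDual{N}
    val::Float64
    partials::NTuple{N,Float64}
end

# Only the operations that the linear-model loss below needs.
Base.:+(a::MultiDual{N}, b::MultiDual{N}) where {N} =
    MultiDual{N}(a.val + b.val, a.partials .+ b.partials)
Base.:*(a::MultiDual{N}, c::Real) where {N} = MultiDual{N}(a.val * c, a.partials .* c)
Base.:-(c::Real, a::MultiDual{N}) where {N} = MultiDual{N}(c - a.val, map(-, a.partials))
Base.:/(a::MultiDual{N}, c::Real) where {N} = MultiDual{N}(a.val / c, a.partials ./ c)
Base.:^(a::MultiDual{N}, n::Integer) where {N} =
    MultiDual{N}(a.val^n, (n * a.val^(n - 1)) .* a.partials)

# Gradient of f at p using cld(length(p), N) forward passes of N partials each.
function chunked_gradient(f, p::Vector{Float64}, ::Val{N}) where {N}
    n = length(p)
    g = zeros(n)
    npasses = 0
    for start in 1:N:n
        # Seed parameter j = start + k - 1 with a unit partial in slot k;
        # all other parameters get zero partials in this pass.
        pd = [MultiDual{N}(p[j], ntuple(k -> j == start + k - 1 ? 1.0 : 0.0, N)) for j in 1:n]
        out = f(pd)
        for k in 1:N
            j = start + k - 1
            j <= n && (g[j] = out.partials[k])
        end
        npasses += 1
    end
    return g, npasses
end

model(x, p) = p[1] + p[2]*x

function loss(p, x, y)
    f = (y[1] - model(x[1], p))^2 / 0.01
    for i in 2:length(x)
        f += (y[i] - model(x[i], p))^2 / 0.01
    end
    return f
end

xs = collect(range(-1.0, 1.0; length = 20))
ys = 1.0 .+ 2.0 .* xs
p = [0.5, -0.3]
g1, passes1 = chunked_gradient(q -> loss(q, xs, ys), p, Val(1))  # 2 passes
g2, passes2 = chunked_gradient(q -> loss(q, xs, ys), p, Val(2))  # 1 pass
```

With 500 parameters and N = 12 the outer loop would run cld(500, 12) = 42 times.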

I’m not quite sure I know what you have in mind. The classic Flux thing is to loop over the data (in batches), and do one gradient evaluation per batch, then update the parameters. So the loop over data is outside the AD.
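The batch loop described above can be sketched as follows, assuming plain gradient descent. `train_batches!` and `batch_grad` are hypothetical names; `batch_grad` is an analytic stand-in for an AD call such as `Zygote.gradient(q -> loss(q, xb, yb), p)[1]`, so the sketch runs without packages. The loop over the data stays outside whatever computes the gradient.

```julia
# Loop over the data in batches; one gradient evaluation per batch, then update.
function train_batches!(p, gradfn, xs, ys; batchsize = 5, lr = 0.05, epochs = 1000)
    n = length(xs)
    for _ in 1:epochs
        for start in 1:batchsize:n
            idx = start:min(start + batchsize - 1, n)
            g = gradfn(p, view(xs, idx), view(ys, idx))  # the AD only sees one batch
            p .-= lr .* g                                 # gradient-descent update
        end
    end
    return p
end

# Stand-in for an AD call: analytic gradient of sum((y - p[1] - p[2]*x)^2) over a batch.
function batch_grad(p, xb, yb)
    r = yb .- (p[1] .+ p[2] .* xb)
    return [-2 * sum(r), -2 * sum(r .* xb)]
end

xs = collect(range(-1.0, 1.0; length = 20))
ys = 1.0 .+ 2.0 .* xs
p = train_batches!(zeros(2), batch_grad, xs, ys)
```

Swapping `batch_grad` for an AD call leaves the loop itself unchanged.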

There are many other things you can do, but it’s hard to guess which apply. If you sketch what the function looks like, maybe there is more to say.

The function is very similar to the loss function I wrote above, except that the model function is a much more complicated function of O(500) parameters. Of course I could do part of the derivatives by hand… I could even compute the gradient analytically, but this is exactly what AD is supposed to do for us.

Moreover, the real challenge is that I need to evaluate these derivatives when p is a custom type, with my own definitions of *, +, -, and so on. This last part is tricky with both ForwardDiff and ReverseDiff. I guess I will need to understand what ReverseDiff.jl does internally so that it accepts my own data type.

The more you can write it using whole-array operations, the better. For instance, you could unpack p into local variables once and use those inside the loop; that will already be better. I hope you don’t have 500 completely distinct scalar parameters, so keep whatever structure there is together (in arrays) rather than splitting it into scalars. You could probably also broadcast over x instead of iterating.
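Both suggestions can be sketched for the two-parameter model from the MWE; the names `loss_unpacked` and `loss_broadcast` are hypothetical, and the model `p[1] + p[2]*x` is inlined. The first indexes into `p` only once; the second replaces the loop with a single broadcast and reduction.

```julia
# Unpack the parameters once, then loop: p is no longer indexed on every iteration.
function loss_unpacked(p, x, y)
    a, b = p[1], p[2]
    f = (y[1] - (a + b*x[1]))^2 / 0.01
    for i in 2:length(x)
        f += (y[i] - (a + b*x[i]))^2 / 0.01
    end
    return f
end

# Whole-array version: broadcast over x and y instead of iterating.
loss_broadcast(p, x, y) = sum(abs2, y .- (p[1] .+ p[2] .* x)) / 0.01

xs = collect(range(-1.0, 1.0; length = 20))
ys = 1.0 .+ 2.0 .* xs .+ 0.1 .* sin.(3 .* xs)
p = [0.3, 1.7]
l1 = loss_unpacked(p, xs, ys)
l2 = loss_broadcast(p, xs, ys)   # same value up to floating-point rounding
```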

Enzyme is pretty fast for these kinds of loss functions.

julia> import Enzyme

julia> res3 = zero(p)
2-element Vector{Float64}:
 0.0
 0.0

julia> Enzyme.autodiff(loss, Enzyme.Duplicated(p, res3), xs, ys)

julia> res3 ≈ res2

julia> @btime Enzyme.autodiff($loss, Enzyme.Duplicated($p, $res3), $xs, $ys)
  648.038 ns (1 allocation: 16 bytes)
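The calls above use the `autodiff(f, args...)` form of the Enzyme version available at the time. As far as I know, recent Enzyme releases require the mode and the activity of the return value to be passed explicitly; the following is a sketch assuming such a release, so check the Enzyme documentation for your version. Variable names are illustrative.

```julia
import Enzyme

model(x, p) = p[1] + p[2]*x

function loss(p, x, y)
    f = (y[1] - model(x[1], p))^2 / 0.01
    for i in 2:length(x)
        f += (y[i] - model(x[i], p))^2 / 0.01
    end
    return f
end

xs = randn(20)
ys = 1.0 .+ 2.0 .* xs .+ 0.1 .* randn(20)
p = randn(2)
dp = zero(p)   # shadow of p: Enzyme accumulates the gradient into it

# Reverse mode; the scalar return value is Active, p is Duplicated (differentiated),
# and the data are Const (not differentiated).
Enzyme.autodiff(Enzyme.Reverse, loss, Enzyme.Active,
                Enzyme.Duplicated(p, dp), Enzyme.Const(xs), Enzyme.Const(ys))
```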


I will have a look. Is there a way to compute derivatives with respect to custom types?


Do you mean something like

mutable struct X
    a::Float64
    b::Float64
end
Base.zero(::X) = X(0, 0)
p = X(0, 0)
model(x, p) = p.a + p.b*x
res3 = zero(p)
Enzyme.autodiff(loss, Enzyme.Duplicated(p, res3), xs, ys)



I mean

model(x, p) = p[1]*p[2] + sin(p[2]*x)

and its gradient with respect to p,
grad_model(x, p) = [p[2], p[1] + x*cos(p[2]*x)]

I want grad_model to be evaluated when eltype(p) == MyType, using operations on MyType that I define myself.
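Before plugging in a custom element type, the hand-derived `grad_model` can be sanity-checked against central finite differences with ordinary `Float64` parameters. `fd_grad` is a hypothetical helper; this only checks the formula, not its evaluation on `MyType`.

```julia
model(x, p) = p[1]*p[2] + sin(p[2]*x)
grad_model(x, p) = [p[2], p[1] + x*cos(p[2]*x)]

# Central finite differences with respect to each component of p.
function fd_grad(f, p; h = 1e-6)
    g = similar(p)
    for j in eachindex(p)
        e = zero(p)
        e[j] = h
        g[j] = (f(p .+ e) - f(p .- e)) / (2h)
    end
    return g
end

x = 0.7
p = [1.3, -0.4]
analytic = grad_model(x, p)
numeric = fd_grad(q -> model(x, q), p)
```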


I’m not sure I understand what you want to achieve. Can you maybe elaborate?

Is MyType a special kind of real number that you want to define yourself?

Or are you looking for an easy way to switch between different models, like

struct Linear end
struct NonLinear end
model(::Linear, x, p) = p[1] + p[2]*x
model(::NonLinear, x, p) = p[1]*p[2] + sin(p[2]*x)



What I am trying to do is evaluate the expression for a gradient (obtained with the usual rules of differentiation) using my own struct.

For example, if I have

struct foo
    x::Float64
    y::Float64
end
Base.:*(a::foo, b::foo) = foo(a.x^2 + b.x^2, a.y + b.y)
Base.:*(a::Number, b::foo) = foo(a^2*b.x^2, a + b.y)

f(x) = x*x

I want to compute df(x) for x = foo(1.0, 2.3). Since df(x) = 2*x, the result should be foo(4.0, 4.3). To do this, my approach is to take a package that does AD, cheat by declaring foo <: Real, and then evaluate the backpropagation graph using my own definitions. Zygote.jl can do this, but unfortunately it is slow. ReverseDiff.jl, which has served me well up to now, is giving me some problems here, because one is forced to include foo in several places…
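Written out by hand, the target computation is just the symbolic derivative evaluated with the custom `*`. A minimal check of the numbers quoted above, where `df` is a hypothetical name for the hand-derived derivative of `f`:

```julia
struct foo
    x::Float64
    y::Float64
end
Base.:*(a::foo, b::foo) = foo(a.x^2 + b.x^2, a.y + b.y)
Base.:*(a::Number, b::foo) = foo(a^2*b.x^2, a + b.y)

f(x) = x*x
df(x) = 2*x   # derivative derived by hand: d/dx (x*x) = 2x

x0 = foo(1.0, 2.3)
d = df(x0)    # uses *(::Number, ::foo): foo(2^2 * 1.0^2, 2 + 2.3)
```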

Ah, I see, sorry for the dumb question and thanks for the clarification! Unfortunately I don’t know if Enzyme will work with this.