Zygote dozens* of times slower than manually written function

I am trying to use Zygote in order to avoid the need of manually passing gradients/Hessians. I started by the simplest possible example for my applications, namely a hydrogen atom. The Hamiltonian and initial point are defined as

using StaticArrays, ForwardDiff, Zygote, BenchmarkTools

H(X) = X[1]^2/2.0 + X[2]^2/2.0 + X[3]^2/2.0 - ( X[4]^2 + X[5]^2 + X[6]^2 )^-0.5
x0 = @SVector rand(6);

I then get

∇1H(X) = Zygote.gradient(H, X)
@btime ∇1H($x0);
>> 3.451 μs (57 allocations: 2.28 KiB)

This is insanely slow and has unjustifiable allocations. Even if I use ForwardDiff directly, the result is far superior:

∇2H(X) = ForwardDiff.gradient(H, X)
@btime ∇2H($x0);
>> 171.613 ns (0 allocations: 0 bytes)

Still, that is more than twice what I get from a manual evaluation of the gradient:

∇H(X) = @SVector [
    X[1] ,
    X[2] ,
    X[3] ,
    X[4] * ( X[4]^2 + X[5]^2 + X[6]^2 )^-1.5 , 
    X[5] * ( X[4]^2 + X[5]^2 + X[6]^2 )^-1.5 ,
    X[6] * ( X[4]^2 + X[5]^2 + X[6]^2 )^-1.5 ,
]

@btime ∇H($x0);
>> 71.700 ns (0 allocations: 0 bytes)

Am I doing something wrong when using Zygote? Should I just keep passing gradients manually?

Well, 3 µs is not that slow… If I had to do it manually, I would probably need 3 hours…

I guess handwritten derivatives will always be faster to execute than machine-generated derivatives, if, and this is the big if, it is feasible to write them manually.

Forward-mode automatic differentiation uses dual numbers, which always causes some overhead.

I don’t think you’re doing anything wrong; Zygote is known to not be super fast for code that does lots of scalar indexing like yours. All the compiler magic going on often makes inference fail on the gradient calculation, so the allocations you’re seeing are probably because the gradient code ends up being type unstable. For gradients w.r.t. a small number of parameters like this I would just recommend using ForwardDiff. You can also use Zygote.forwarddiff to temporarily switch to ForwardDiff:

H(X) = Zygote.forwarddiff(X) do X
    X[1]^2/2.0 + X[2]^2/2.0 + X[3]^2/2.0 - ( X[4]^2 + X[5]^2 + X[6]^2 )^-0.5
end

Now you can embed H(X) in some bigger calculation if you still want to use Zygote on that, and this piece of it will use ForwardDiff and be fast; a sketch is below.
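For illustration, a minimal hypothetical sketch of embedding the forwarddiff-wrapped H in a larger Zygote computation; the name bigger and the surrounding expression are made up for the example:

# Zygote differentiates the outer expression; the inner H(X) is handled by
# ForwardDiff via the Zygote.forwarddiff wrapper defined above.
bigger(X) = exp(-H(X)) + sum(X)

Zygote.gradient(bigger, x0)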

4 Likes

Why did you choose Zygote in the first place? If I were working with Flux, I’d be working with Zygote, obviously. Otherwise I’d check the AD packages, finite differences, Symbolics.jl, or manual derivation, and decide what is appropriate for my use case (I think I’ve seen technical limits for most of the mentioned systems in one way or another, especially w.r.t. complex numbers, for example).

Edit: and (obviously?) you can use all these to validate correctness.

1 Like

Just to hammer this point home, Zygote will allocate at least 1 full-sized array for every indexing operation in that function.
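As a rough illustration (hypothetical snippet, not from the thread): the adjoint of getindex is a one-hot array of the same length as the input, so each scalar X[i] in H contributes one such full-length gradient piece that has to be built and accumulated:

Zygote.gradient(x -> x[1], x0)   # gradient is a one-hot array of the same length as x0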

1 Like

Contrary to the title, in the original post Zygote is only about 50x slower, not hundreds of times.

1 Like

Should we tone down the title of the post?

I don’t know about Zygote, but you should definitely clean up your function, first of all. X^-0.5 is a really bad operation to do.

Here’s a comparison:

H(X) = X[1]^2/2.0 + X[2]^2/2.0 + X[3]^2/2.0 - ( X[4]^2 + X[5]^2 + X[6]^2 )^-0.5
G(X) = (X[1]^2 + X[2]^2 + X[3]^2)/2 - 1/sqrt(X[4]^2 + X[5]^2 + X[6]^2)
∇2H(X) = ForwardDiff.gradient(H, X)
∇2G(X) = ForwardDiff.gradient(G, X)

Benchmarks:

1.7.2> x0 = @SVector rand(6);

1.7.2> @btime H($x0)
  30.071 ns (0 allocations: 0 bytes)
-0.2515642893020753

1.7.2> @btime G($x0)
  3.100 ns (0 allocations: 0 bytes)
-0.2515642893020753

1.7.2> @btime ∇2H($x0);
  91.802 ns (0 allocations: 0 bytes)

1.7.2> @btime ∇2G($x0);
  22.211 ns (0 allocations: 0 bytes)

This probably won’t help with Zygote, but it’s good to improve the basics first.

3 Likes

And then there’s the manual gradient:

function ∇G(X)
    x = 1/(X[4]^2 + X[5]^2 + X[6]^2)
    y = x * sqrt(x)
    @SVector [
    X[1] ,
    X[2] ,
    X[3] ,
    X[4] * y, 
    X[5] * y,
    X[6] * y]
end

Benchmark:

1.7.2> @btime ∇H($x0);
  34.884 ns (0 allocations: 0 bytes)
        
1.7.2> @btime ∇G($x0);
  3.300 ns (0 allocations: 0 bytes)
3 Likes

The 1.7.2> prompts are really annoying, as they prevent me from copy/pasting code.
The default julia> prompts would automatically be stripped upon pasting into a REPL.

Also, adding @inline to G helps the gradient time.
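A sketch of that change, assuming the same G as above:

# Hint the compiler to inline G at its call sites; this also helps the
# dual-number code generated by ForwardDiff.gradient fuse with the caller.
@inline G(X) = (X[1]^2 + X[2]^2 + X[3]^2)/2 - 1/sqrt(X[4]^2 + X[5]^2 + X[6]^2)
∇2G(X) = ForwardDiff.gradient(G, X)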

Starting Julia with --math-mode=fast also helps a lot.
Starting Julia this way is not recommended, except in helping to identify optimization opportunities, e.g. that it’d be worth adding @fastmath support to ForwardDiff.
Currently, @fastmath with ForwardDiff does not actually set any fast flags. All it does is prevent code from inlining, pessimizing it.

1 Like

Not sure what to do about that. I’m not planning to use julia>; it’s way too long without carrying useful information. I use either version> or jl>.

Prompt pasting never worked for me anyway (on any platform), so I didn’t really consider it.

Isn’t it easy enough to not copy the prompt? You wouldn’t want to copy the output anyway.

As far as I can see from the answers, I can summarize what I learned as follows:

I’ll keep passing gradients manually, apparently. I still wonder, however, how people like @ChrisRackauckas were able to use ForwardDiff in their super efficient libraries. In DifferentialEquations, I know that for symplectic integrators one can provide the solver with the Hamiltonian only, and the gradients are calculated by automatic differentiation. This is a mystery to me, since even for the gradient function provided by DNF, ForwardDiff is still roughly 7x slower than writing the gradient manually. Maybe the people behind DifferentialEquations chose to sacrifice efficiency in order to spare the user from writing the gradients, but the efficiency loss when one is propagating an ensemble of particles seems relevant (although, of course, the bottleneck of ODE solving is definitely happening somewhere else).
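To make the per-step cost concrete, here is a rough, hypothetical sketch (not DifferentialEquations internals) of how a symplectic step can obtain the force from the potential via ForwardDiff; V, force, and leapfrog_step are made-up names for the example:

using StaticArrays, ForwardDiff

V(q) = -1 / sqrt(q[1]^2 + q[2]^2 + q[3]^2)        # Coulomb potential part of H
force(q) = -ForwardDiff.gradient(V, q)            # AD gradient, evaluated every step

function leapfrog_step(q, p, dt)
    p_half = p + (dt / 2) * force(q)              # half kick
    q_new  = q + dt * p_half                      # drift
    p_new  = p_half + (dt / 2) * force(q_new)     # second half kick
    return q_new, p_new
end

q0 = @SVector rand(3); p0 = @SVector rand(3)
q1, p1 = leapfrog_step(q0, p0, 1e-3)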

julia> G(X) = (X[1]^2 + X[2]^2 + X[3]^2)/2 - 1/sqrt(X[4]^2 + X[5]^2 + X[6]^2)
G (generic function with 1 method)

julia> @btime Zygote.gradient(G, $x0)
  3.874 μs (57 allocations: 2.27 KiB)
([0.8312830657406003, 0.42241553817070676, 0.029793466603427188, 0.685730593635655, 0.23572364881902477, 0.540322706597109],)

julia> using LinearAlgebra  # needed for norm

julia> @views K(X) = sum(abs2, X[1:3])/2 - 1/norm(X[4:6])
K (generic function with 1 method)

julia> @btime Zygote.gradient(K, $x0)
  475.292 ns (10 allocations: 624 bytes)
([0.8312830657406003, 0.42241553817070676, 0.029793466603427188, 0.685730593635655, 0.23572364881902477, 0.540322706597109],)

I think for many physical functions it’s best to derive them by hand, if possible. Then optimize them for the computer (like DNF did: avoid divisions, repeated calculations, and non-integer exponents) and write a ChainRules rrule. That way Zygote will be able to work with it, and it will be relatively fast.

I’m not sure the following is 100% correct. It improves the time from 4.3 µs to 735 ns (gradient(K, $x) gives 920 ns).

using ChainRulesCore

# define before the first call to gradient
function ChainRulesCore.rrule(::typeof(G), x)
    pullback(Δy) = (NoTangent(), ∇G(x) * Δy)
    return G(x), pullback
end

Maybe someone more familiar with Zygote + ChainRules can optimize this even more.

Oh, and while marius311’s answer above involving Zygote.forwarddiff is super fast, it changes the definition of H. If that is not an option and you still want to use forward-mode AD explicitly, you can use the freshly announced ForwardDiffPullbacks package:

using ForwardDiffPullbacks
gradient(fwddiff(G), x)  # 133ns vs 730ns on my machine, 0 allocs

We allow for direct derivation via ModelingToolkit.jl, which is symbolic.

https://diffeq.sciml.ai/stable/tutorials/faster_ode_example/#Automatic-Derivation-of-Jacobian-Functions

https://mtk.sciml.ai/dev/mtkitize_tutorials/modelingtoolkitize/

It cannot always be applied automatically, but it will outperform AD in a lot of cases when it can. So we document the ModelingToolkit/Symbolics tools necessary to achieve top-notch performance, and leave it to the user to choose the right path.
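A rough sketch of that workflow, loosely following the linked FAQ/tutorial (the exact API may differ across ModelingToolkit versions; rober! and its parameters are just the standard stiff-ODE example used there):

using DifferentialEquations, ModelingToolkit

# Robertson stiff ODE written as an ordinary Julia function
function rober!(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁*y₁ + k₃*y₂*y₃
    du[2] =  k₁*y₁ - k₂*y₂^2 - k₃*y₂*y₃
    du[3] =  k₂*y₂^2
    nothing
end

prob = ODEProblem(rober!, [1.0, 0.0, 0.0], (0.0, 1e5), (0.04, 3e7, 1e4))

# Trace the numeric problem into a symbolic system, then rebuild it with an
# analytically derived Jacobian attached.
sys      = modelingtoolkitize(prob)
prob_jac = ODEProblem(sys, [], (0.0, 1e5), jac = true)

sol = solve(prob_jac, Rosenbrock23())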

2 Likes

Julia strips the prompt and output so that you can copy/paste multiple lines.
So, no, it’s not easy, as the point is to copy multiple lines. E.g., if these were julia> prompts, I could use a single copy/paste for all the lines:

1.7.2> x0 = @SVector rand(6);

1.7.2> @btime H($x0)
  30.071 ns (0 allocations: 0 bytes)
-0.2515642893020753

1.7.2> @btime G($x0)
  3.100 ns (0 allocations: 0 bytes)
-0.2515642893020753

1.7.2> @btime ∇2H($x0);
  91.802 ns (0 allocations: 0 bytes)

1.7.2> @btime ∇2G($x0);
  22.211 ns (0 allocations: 0 bytes)

This works on Linux and Mac.

That never worked for me, so I guess I didn’t consider it.

I guess I’ll have to manually edit the prompts each time then :face_with_diagonal_mouth:

There’s still some room for improvement there, since ForwardDiffPullbacks currently calculates the primal result and the pullback for every argument separately. I’ve been thinking about adding a mode that does everything in one go, for use cases that will evaluate all Thunks anyway.

But as @ChrisRackauckas noted, analytical derivatives via ModelingToolkit/Symbolics (if applicable for your problem) will typically outperform any AD.