I am trying to use Zygote to avoid having to pass gradients/Hessians manually. I started with the simplest possible example for my applications, namely a hydrogen atom. The Hamiltonian and initial point are defined as
I don’t think you’re doing anything wrong; Zygote is known not to be very fast for code that does lots of scalar indexing like yours. All the compiler magic going on often makes type inference fail on the gradient calculation, so the allocations you’re seeing are probably because the gradient code ends up being type-unstable. For gradients w.r.t. a small number of parameters like this, I would just recommend using ForwardDiff. You can also use Zygote.forwarddiff to temporarily switch to ForwardDiff:
using Zygote
H(X) = Zygote.forwarddiff(X) do X   # assumes X[1:3] = q (position), X[4:6] = p (momentum)
    X[4]^2/2.0 + X[5]^2/2.0 + X[6]^2/2.0 - ( X[1]^2 + X[2]^2 + X[3]^2 )^-0.5
end
Now you can embed H(X) in some bigger calculation that you still differentiate with Zygote, and this piece of it will use ForwardDiff and be fast.
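For comparison, here is a minimal sketch of the plain ForwardDiff route recommended above, without Zygote in the loop. It assumes the state is laid out as X = (q₁, q₂, q₃, p₁, p₂, p₃) and uses a hypothetical initial point X0 on a circular orbit; neither is taken from the original post. The second call shows ForwardDiff's preallocated `gradient!` with a fixed chunk size, which is the usual way to avoid per-call setup when the gradient is evaluated many times.

```julia
using ForwardDiff

# Hydrogen-atom Hamiltonian in atomic units.
# Assumed layout: X[1:3] = q (position), X[4:6] = p (momentum).
H(X) = X[4]^2/2 + X[5]^2/2 + X[6]^2/2 - (X[1]^2 + X[2]^2 + X[3]^2)^-0.5

X0 = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]   # hypothetical initial point (r = 1, |p| = 1)

# One-off gradient: allocates a config internally on every call
g = ForwardDiff.gradient(H, X0)

# Preallocated variant for repeated calls, e.g. inside an integrator loop
out = similar(X0)
cfg = ForwardDiff.GradientConfig(H, X0, ForwardDiff.Chunk{6}())
ForwardDiff.gradient!(out, H, X0, cfg)
```

With a 6-element input, a chunk size of 6 lets ForwardDiff compute the whole gradient in a single dual-number pass.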
Why did you choose Zygote in the first place? If I were working with Flux, I’d be using Zygote, obviously. Otherwise I’d check AD, finite differences, Symbolics.jl, or deriving by hand, and decide what is appropriate for my use case (I think I’ve seen technical limits for most of the mentioned systems in one way or another, especially w.r.t. complex numbers, for example).
Edit: and (obviously?) you can use all of these to check each other for correctness.
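As a sketch of that kind of cross-check, the block below compares a ForwardDiff gradient against a central finite-difference gradient. The helper `fd_gradient`, the step size h, the test point X0, and the index layout X[1:3] = q, X[4:6] = p are all illustrative assumptions, not code from the thread.

```julia
using ForwardDiff

# Assumed layout: X[1:3] = q (position), X[4:6] = p (momentum)
H(X) = X[4]^2/2 + X[5]^2/2 + X[6]^2/2 - (X[1]^2 + X[2]^2 + X[3]^2)^-0.5

# Hypothetical helper: central finite differences, O(h^2) truncation error per component
function fd_gradient(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

X0 = [0.3, -0.7, 1.1, 0.2, 0.5, -0.4]   # arbitrary point away from r = 0
g_ad = ForwardDiff.gradient(H, X0)
g_fd = fd_gradient(H, X0)
maximum(abs, g_ad - g_fd)   # expected to be small, well below 1e-6
```

Finite differences are only accurate to a limited number of digits, so this checks for gross errors such as a wrong sign or index, not for last-bit agreement.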
The 1.7.2> prompts are really annoying, as they prevent me from copy/pasting code.
The default julia> prompts would automatically be stripped upon pasting into a REPL.
Also, adding @inline to G reduces the gradient time.
Also, starting Julia with --math-mode=fast helps a lot.
Starting Julia this way is not recommended, except to help identify optimization opportunities, e.g. that it’d be worth adding @fastmath support to ForwardDiff.
Currently, @fastmath with ForwardDiff does not actually set any fast-math flags. All it does is prevent code from inlining, pessimizing it.
I’ll keep passing gradients, apparently. I still wonder, however, how people like @ChrisRackauckas were able to use ForwardDiff in their super-efficient libraries. In DifferentialEquations, I know that for symplectic integrators one can provide the solver with only the Hamiltonian, and the gradients are calculated by automatic differentiation. This is a mystery to me, since even for the function provided by DNF, ForwardDiff is still 8x slower than a hand-written gradient. Maybe the people behind DifferentialEquations chose to sacrifice efficiency to spare the user from writing gradients, but the efficiency loss when propagating an ensemble of particles seems relevant (although, of course, the bottleneck of ODE solving definitely lies somewhere else).
I think for many physical functions it’s best to derive the derivatives by hand, if possible. Then optimize them for the computer (as DNF did: avoid divisions, repeated calculations, and non-integer exponents) and write a ChainRules rrule. That way Zygote will be able to work with it, and it will be relatively fast.
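Applied to the hydrogen Hamiltonian, and assuming the state splits into a position q and a momentum p (the layout is an assumption, not stated in the thread), the hand derivation is short:

```latex
H(\mathbf{q},\mathbf{p}) = \tfrac{1}{2}\,\mathbf{p}\cdot\mathbf{p} - (\mathbf{q}\cdot\mathbf{q})^{-1/2},
\qquad r^2 = \mathbf{q}\cdot\mathbf{q},
\\[4pt]
\frac{\partial H}{\partial p_i} = p_i,
\qquad
\frac{\partial H}{\partial q_i}
  = -\left(-\tfrac{1}{2}\right)(\mathbf{q}\cdot\mathbf{q})^{-3/2}\,2q_i
  = \frac{q_i}{r^3}.
```

In code, computing s = 1/(r²·√r²) once and reusing it for all three position components replaces the non-integer power with a square root and needs only a single division.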
I’m not sure the following is 100% correct, but it improves the timing from 4.3 µs to 735 ns (gradient(K, $x) gives 920 ns).
using ChainRulesCore

# define before first call to gradient; ∇G(x) is the hand-written gradient of G
function ChainRulesCore.rrule(::typeof(G), x)
    pullback(Δy) = (NoTangent(), ∇G(x) * Δy)   # NoTangent() for G itself
    return G(x), pullback
end
Maybe someone more familiar with Zygote + ChainRules can optimize this even more.
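For a self-contained version of the same pattern, here is a sketch with the hydrogen Hamiltonian standing in for G. The hand-written gradient `∇H`, the index layout X[1:3] = q, X[4:6] = p, and the test point X0 are illustrative assumptions; `∇H` uses the q/r³ form with a single division and no non-integer power.

```julia
using Zygote, ChainRulesCore

# Assumed layout: X[1:3] = q (position), X[4:6] = p (momentum)
H(X) = X[4]^2/2 + X[5]^2/2 + X[6]^2/2 - (X[1]^2 + X[2]^2 + X[3]^2)^-0.5

# Hypothetical hand-written gradient: ∂H/∂q = q/r^3, ∂H/∂p = p
function ∇H(X)
    r2 = X[1]^2 + X[2]^2 + X[3]^2
    s = 1 / (r2 * sqrt(r2))          # 1/r^3, computed once
    return [X[1]*s, X[2]*s, X[3]*s, X[4], X[5], X[6]]
end

# Custom reverse rule so Zygote calls ∇H instead of tracing H
function ChainRulesCore.rrule(::typeof(H), X)
    H_pullback(Δy) = (NoTangent(), ∇H(X) .* Δy)   # Δy is the scalar output cotangent
    return H(X), H_pullback
end

X0 = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]   # hypothetical point: r = 1, |p| = 1
g = Zygote.gradient(H, X0)[1]
```

Because the rule returns the full gradient vector in one call, Zygote never has to differentiate through the scalar indexing inside H.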
Oh, and while marius311’s answer above involving Zygote.forwarddiff is super fast, it changes the definition of H. If that is not an option and you still want to use ForwardDiff explicitly, you can use the freshly announced ForwardDiffPullbacks package:
using Zygote, ForwardDiffPullbacks
gradient(fwddiff(G), x) # 133 ns vs. 730 ns on my machine, 0 allocs
You cannot always apply it automatically, but it will outperform AD in a lot of cases when you can. So we document the ModelingToolkit/Symbolics tools necessary to achieve top-notch performance, and leave it to the user to choose the right path.
Julia strips the prompt and output so that you can copy/paste multiple lines.
So, no, it’s not easy, as the point is to copy multiple lines. E.g., if these were julia> prompts, I could copy/paste all the lines at once:
There’s still some room for improvement there, since ForwardDiffPullbacks currently calculates the primal result and the pullback for every argument separately. I’ve been thinking about adding a mode that does everything in one go, for use cases that will evaluate all Thunks anyway.
But as @ChrisRackauckas noted, analytical derivatives via ModelingToolkit/Symbolics (if applicable for your problem) will typically outperform any AD.
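A minimal sketch of the Symbolics route for this problem, assuming the same layout (x1, x2, x3 = position, x4, x5, x6 = momentum) and a hypothetical test point: the gradient is derived symbolically and then compiled into ordinary Julia functions with `build_function`, which for an array output returns an out-of-place and an in-place version.

```julia
using Symbolics

# Assumed layout: q = (x1, x2, x3), p = (x4, x5, x6)
@variables x1 x2 x3 x4 x5 x6
vars = [x1, x2, x3, x4, x5, x6]
Hsym = (x4^2 + x5^2 + x6^2) / 2 - (x1^2 + x2^2 + x3^2)^(-1/2)

# Symbolic gradient, then compile it to plain Julia functions (no AD at runtime)
∇Hsym = Symbolics.gradient(Hsym, vars)
∇H_oop, ∇H_ip! = build_function(∇Hsym, vars; expression = Val{false})

X0 = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]   # hypothetical point: r = 1, |p| = 1
g = ∇H_oop(X0)                        # allocating version
out = similar(X0)
∇H_ip!(out, X0)                       # in-place version, writes into out
```

The compiled in-place function can then be wrapped in an rrule, as in the ChainRules example earlier in the thread, if Zygote still needs to differentiate through a larger computation.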