[Zygote] Derivative of derivative compilation time

In the following code I calculate the derivative fx, then the derivative of the derivative fxy, and then fxyz.
The code works and gives the correct result, but the compilation time for the fxyz derivative is huge. Is there a way to improve this behavior?

using Zygote

# the function of interest
f(x, y, z) = x*y*z

# zygote gradients
fx(x, y, z) = gradient(x->f(x, y, z), x)[1]
fxy(x, y, z) = gradient(y->fx(x, y, z), y)[1]
fxyz(x, y, z) = gradient(z->fxy(x, y, z), z)[1]

@time fx(1.0f0, 2.0f0, 3.0f0)
@time fxy(1.0f0, 2.0f0, 3.0f0)
@time fxyz(1.0f0, 2.0f0, 3.0f0)

@code_warntype fx(1.0f0, 2.0f0, 3.0f0)
@code_warntype fxy(1.0f0, 2.0f0, 3.0f0)
@code_warntype fxyz(1.0f0, 2.0f0, 3.0f0)

Timing output:

0.000003 seconds
10.598413 seconds (21.78 M allocations: 1.099 GiB, 3.89% gc time, 99.75% compilation time)
183.757771 seconds (274.62 M allocations: 13.275 GiB, 2.83% gc time, 98.93% compilation time)

If we look at the @code_warntype outputs we can see:

for fx:   Main.gradient(%6, x)::Tuple{Float32}
for fxy:  Main.gradient(%6, y)::Union{Nothing, Tuple{Any}}
for fxyz: Main.gradient(%6, y)::Union{Nothing, Tuple{Any}}

For fx the gradient call is correctly inferred as Tuple{Float32}, but for the others we get Union{Nothing, Tuple{Any}}. Can we make fxy return Tuple{Float32} as well?


Maybe try JuliaDiff/ForwardDiff.jl (forward-mode automatic differentiation) or JuliaDiff/TaylorDiff.jl (Taylor-mode automatic differentiation for higher-order derivatives) … you don’t really need reverse-mode AD for tiny numbers of parameters, and forward mode is much less demanding of the compiler.

In particular, I tried your code adapted to ForwardDiff.derivative and it is basically instantaneous (@time reports 0.000000 seconds even on the first call).
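The adaptation I tried looks roughly like this (a sketch of the same nesting pattern, swapping each Zygote.gradient for ForwardDiff.derivative; for f(x, y, z) = x*y*z the third mixed derivative is the constant 1):

```julia
using ForwardDiff

# the function of interest
f(x, y, z) = x*y*z

# nested forward-mode derivatives in place of nested Zygote gradients
fx(x, y, z)   = ForwardDiff.derivative(x -> f(x, y, z), x)
fxy(x, y, z)  = ForwardDiff.derivative(y -> fx(x, y, z), y)
fxyz(x, y, z) = ForwardDiff.derivative(z -> fxy(x, y, z), z)

fxyz(1.0f0, 2.0f0, 3.0f0)  # returns 1.0f0
```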


Note that you might need @time @eval ... because otherwise things might compile before the timing starts.
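A minimal sketch of that pattern (square is a made-up example function, not from the code above):

```julia
# A fresh function, so nothing has been compiled before we time it
square(x) = x^2

# @eval defers lowering the call until @time is already running,
# so the first-call compilation is included in the measurement
@time @eval square(2.0)
```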

I think I chose an oversimplified example. I have a Flux NN, which is why Zygote is of interest here. I will create another MWE.

Still, I’m wondering whether we can make the above functions compile and run faster using only Zygote.

In the end I used the recommended ForwardDiff-over-Zygote approach. Here is an oversimplified MWE. This solution also works nicely for my Flux network. Zygote over Zygote does not compile for my Flux network due to mutation errors, which is to be expected according to the documentation: Limitations · Zygote.

using ForwardDiff
using Zygote

# the function of interest
f(x, y, z) = x*y*z

# inner derivative via reverse mode (Zygote)
fx(x, y, z) = Zygote.gradient(x -> f(x, y, z), x)[1]

a = Float32[1, 2, 3]

fx(a...)

# outer gradient via forward mode (ForwardDiff) over the Zygote derivative
fx_yz(x, y, z) = ForwardDiff.gradient(yz -> fx(x, yz...), [y, z])
fx_yz(a...)
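For reference, the value fx_yz returns can be checked by hand; a minimal sketch (fx_yz_manual is a hypothetical helper I'm introducing here, not part of the MWE):

```julia
# Analytic sanity check: for f(x, y, z) = x*y*z we have ∂f/∂x = y*z,
# and the gradient of y*z with respect to (y, z) is (z, y)
fx_yz_manual(x, y, z) = Float32[z, y]

fx_yz_manual(1.0f0, 2.0f0, 3.0f0)  # returns Float32[3.0, 2.0]
```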