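In the following code I compute the partial derivative fx = ∂f/∂x, then its derivative with respect to y, fxy = ∂²f/∂x∂y, and finally fxyz = ∂³f/∂x∂y∂z, by nesting Zygote gradient calls.
The code works and gives the correct result, but the first call to fxyz takes about three minutes, almost all of it compilation time (see the timings below). Is there a way to reduce this?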
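For reference, these are the exact derivatives of this particular f. They are worked out by hand here only to check the numbers; they are not part of the original code. At (x, y, z) = (1, 2, 3) they give fx = 6, fxy = 3 and fxyz = 1.

```latex
f = xyz, \qquad
f_x = \frac{\partial f}{\partial x} = yz, \qquad
f_{xy} = \frac{\partial^2 f}{\partial y\,\partial x} = z, \qquad
f_{xyz} = \frac{\partial^3 f}{\partial z\,\partial y\,\partial x} = 1
```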
using Zygote
# the function of interest
f(x, y, z) = x*y*z
# nested Zygote gradients: fx = ∂f/∂x, fxy = ∂²f/∂x∂y, fxyz = ∂³f/∂x∂y∂z
fx(x, y, z) = gradient(x->f(x, y, z), x)[1]
fxy(x, y, z) = gradient(y->fx(x, y, z), y)[1]
fxyz(x, y, z) = gradient(z->fxy(x, y, z), z)[1]
# time the first call of each function (includes compilation)
@time fx(1.f0, 2.f0, 3.f0)
@time fxy(1.f0, 2.f0, 3.f0)
@time fxyz(1.f0, 2.f0, 3.f0)
# inspect the inferred types
@code_warntype fx(1.f0, 2.f0, 3.f0)
@code_warntype fxy(1.f0, 2.f0, 3.f0)
@code_warntype fxyz(1.f0, 2.f0, 3.f0)
Timing output (fx, fxy, fxyz, in that order):
0.000003 seconds
10.598413 seconds (21.78 M allocations: 1.099 GiB, 3.89% gc time, 99.75% compilation time)
183.757771 seconds (274.62 M allocations: 13.275 GiB, 2.83% gc time, 98.93% compilation time)
In the @code_warntype output, the inner gradient call is inferred as:
for fx: Main.gradient(%6, x)::Tuple{Float32}
for fxy: Main.gradient(%6, y)::Union{Nothing, Tuple{Any}}
for fxyz: Main.gradient(%6, z)::Union{Nothing, Tuple{Any}}
For fx the gradient call is correctly inferred as Tuple{Float32}, but for fxy and fxyz it is inferred as Union{Nothing, Tuple{Any}}. Is there a way to make the gradient call in fxy infer as Tuple{Float32} as well?
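One commonly used alternative for scalar higher-order derivatives is to nest forward-mode AD instead of reverse-mode. The sketch below assumes ForwardDiff.jl is installed (it is not used in the original code), and the names fx_fd, fxy_fd and fxyz_fd are hypothetical, chosen only to mirror the Zygote versions. It has the same nesting structure as the question, with ForwardDiff.derivative in place of gradient; ForwardDiff tags its dual numbers so that the nested levels do not get mixed up.

```julia
# Sketch: nested forward-mode AD with ForwardDiff.jl
# (assumption: ForwardDiff is installed; it is not part of the original code).
using ForwardDiff

# the function of interest
f(x, y, z) = x * y * z

# hypothetical names; each level differentiates the previous one in one more variable
fx_fd(x, y, z)   = ForwardDiff.derivative(x -> f(x, y, z), x)      # ∂f/∂x = y*z
fxy_fd(x, y, z)  = ForwardDiff.derivative(y -> fx_fd(x, y, z), y)  # ∂²f/∂x∂y = z
fxyz_fd(x, y, z) = ForwardDiff.derivative(z -> fxy_fd(x, y, z), z) # ∂³f/∂x∂y∂z = 1

@time fxyz_fd(1.f0, 2.f0, 3.f0)  # first call includes compilation
@time fxyz_fd(1.f0, 2.f0, 3.f0)  # second call reuses the compiled code
```

With Float32 inputs each level should return a Float32 (6, 3 and 1 at (1, 2, 3)). How much compilation time this saves compared with nested Zygote on a given Julia and package version is worth measuring with @time rather than assuming. Also note that ForwardDiff.derivative works on scalar inputs only, and forward mode gets more expensive as the number of input variables grows.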