At the moment I implement gradients for all my models, because I have not yet found that Julia’s AD packages come anywhere close to my fairly naive hand-optimised gradients. (The only exception is ForwardDiff on very small problems when used with StaticArrays!)
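For example, on a tiny problem something like the following is hard to beat (the toy function here is mine, just to illustrate the ForwardDiff + StaticArrays combination):

using ForwardDiff, StaticArrays

# on a small SVector input, ForwardDiff returns the gradient as an SVector,
# so the whole computation stays on the stack
f(y) = sum(abs2, y)
y = @SVector rand(8)
ForwardDiff.gradient(f, y)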
But every few months I explore what is new and try out some basic benchmarks. This time, I decided to give Zygote.jl a go; here are the results for taking the gradient of a map f : R^100 → R.
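(Concretely, the test map is f(x) = F(ρ) with ρ = ∑_n ρfun(x_n), where F(ρ) = ρ² + ρ³ + 0.25 ρ⁴ and ρfun(r) = exp(-3(r+1)); this is the eam function in the code below.)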
# function evaluation:
884.212 ns (0 allocations: 0 bytes)
# manual gradient:
1.928 μs (1 allocation: 896 bytes)
# manual gradient, with pre-allocated gradient storage:
1.771 μs (0 allocations: 0 bytes)
# naive Zygote: gradient(f, x)
7.850 μs (333 allocations: 13.66 KiB)
# Zygote gradient with a "hack" (see main code below):
2.738 μs (8 allocations: 1.11 KiB)
The full code for this hugely oversimplified, but still somewhat representative (I think) test is below. The good news: Zygote is only about a factor of 4 slower than manual gradients in this example; in the past it was more like a factor of 20-30. I am really pleased with this and will start running more extensive tests now with some more realistic models. (This will take some time…) Even better: there is a hack which avoids some allocations by improving the adjoint for sum, and with it the cost drops to within about a factor of 1.5 of the manual implementation.
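The hack itself is nothing deep: for s = sum(op, x) the pullback should simply map Δ to Δ .* op′.(x), and the workaround spells this out for Zygote. A quick sanity check of that identity (my own, separate from the benchmark):

using Zygote
op(r) = exp(-3*(r+1))
dop(r) = -3 * op(r)               # analytic derivative of op
x = rand(5)
s, back = Zygote.pullback(v -> sum(op, v), x)
back(1.0)[1] ≈ dop.(x)            # true: the pullback matches the analytic derivative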
This raises some questions:
- Is this “hack” indicative of the performance improvements that I can still expect?
- Are there any “official” ways to maybe avoid memory allocations altogether? I’m happy managing the pre-allocation and using a less elegant interface (see the sketch after this list for the kind of interface I mean).
- Any other suggestions of what I should look at if I want performant gradients with Zygote?
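To make the pre-allocation question concrete, here is the kind of interface I would be happy with. To be clear, gradient! is my own hypothetical name, not an existing Zygote function, and this naive wrapper still allocates internally; it only illustrates the interface:

using Zygote

# hypothetical convenience wrapper, NOT an existing Zygote API:
# write the gradient of f at x into the pre-allocated buffer g
function gradient!(g, f, x)
    copyto!(g, Zygote.gradient(f, x)[1])   # the inner call still allocates
    return g
end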
using Zygote, BenchmarkTools
using Zygote: @adjoint

# density function and its analytic derivative
ρfun(r) = exp(-3*(r+1))
dρfun(r) = -3 * ρfun(r)

# toy EAM-type energy: F(ρ) with ρ = ∑_n ρfun(x_n)
function eam(x)
    ρ = sum(ρfun, x)
    return ρ^2 + ρ^3 + 0.25 * ρ^4
end

# manual gradient via the chain rule: g[n] = F'(ρ) * ρfun'(x[n])
function deam!(g, x)
    fill!(g, 0.0)
    N = length(x)
    ρ = sum(ρfun, x)
    dF = 2*ρ + 3*ρ^2 + ρ^3   # F'(ρ)
    for n = 1:N
        g[n] = dρfun(x[n]) * dF
    end
    return g
end
# --------------------------------------------------
# Using the workaround from
# https://github.com/FluxML/Zygote.jl/issues/292
function sum2(op, arr)
    return sum(op, arr)
end

# pullback for sum2: the gradient w.r.t. arr is Δ * op'(x), elementwise
function sum2adj(Δ, op, arr)
    g = x -> Δ * Zygote.gradient(op, x)[1]
    return (nothing, map(g, arr))
end

@adjoint function sum2(op, arr)
    return sum2(op, arr), Δ -> sum2adj(Δ, op, arr)
end

# same energy as eam, but via sum2 so the custom adjoint kicks in
function eam2(x)
    ρ = sum2(ρfun, x)
    return ρ^2 + ρ^3 + 0.25 * ρ^4
end
# --------------------------------------------------
# benchmark script
# ----------------
deam(x) = deam!(zeros(length(x)), x)   # allocating version of the manual gradient
zeam(x) = gradient(eam, x)[1]          # naive Zygote
zeam2(x) = gradient(eam2, x)[1]        # Zygote with the sum2 adjoint

x = rand(100)
g = rand(100)

# sanity check: manual and Zygote gradients agree
@show sqrt(sum((zeam(x) - deam(x)).^2))

@btime eam($x);
@btime deam($x);
@btime deam!($g, $x);
@btime zeam($x);
@btime zeam2($x);