I am experimenting with custom adjoints in Zygote. In the following code, I defined an adjoint for the function gauss. Surprisingly, benchmarking gradient on gauss and on a “wrapper” function f around gauss shows that the gradient of f is 3-4 times faster than that of gauss. Before defining the custom adjoint, the two gradients have the same speed. After the definition, f’s gradient is 50% faster than before, while gauss’s gradient is twice as slow.
Why does this happen, and how can I get optimal performance with custom adjoints consistently?
using Zygote
using BenchmarkTools
function gauss(x, μ, σ)
    y = (x-μ)/σ
    exp(-y^2/2) / (sqrt(2π)*σ)
end
using Zygote: @adjoint
@adjoint function gauss(x, μ, σ)
    y = (x-μ)/σ
    e = exp(-y^2/2) / (sqrt(2π)*σ)
    function back(Δ)
        ey = e*y/σ # could pool e*Δ too
        # ∂/∂x = -e*y/σ, ∂/∂μ = e*y/σ, ∂/∂σ = e*(y^2 - 1)/σ
        (-ey * Δ, ey * Δ, e*(y^2 - 1)/σ * Δ)
    end
    return e, back
end
x = randn()
μ = randn()
σ = exp(randn())
f(x, μ, σ) = gauss(x, μ, σ)
@btime gradient(f, $x, $μ, $σ)
@btime gradient(gauss, $x, $μ, $σ) # 3-4 times slower
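
Separately from the timing question, it can be worth checking a hand-written pullback numerically before benchmarking it. The following is a minimal sketch, not part of the original code: gauss_grad and fd_grad are hypothetical helper names, the step size h and the tolerance are arbitrary choices, and it uses only Base Julia so it runs without Zygote. It compares the analytic partials (for Δ = 1) against central finite differences. Differentiating by hand gives e*(y^2 - 1)/σ for the σ component, and a check like this would flag a sign slip in that term.

```julia
# Primal function, same formula as gauss in the post.
gauss(x, μ, σ) = exp(-((x - μ) / σ)^2 / 2) / (sqrt(2π) * σ)

# Analytic gradient, i.e. what the pullback should return for Δ = 1
# (hypothetical helper, not a Zygote API).
function gauss_grad(x, μ, σ)
    y = (x - μ) / σ
    e = exp(-y^2 / 2) / (sqrt(2π) * σ)
    ey = e * y / σ
    return (-ey, ey, e * (y^2 - 1) / σ)
end

# Central finite differences in each argument (hypothetical helper).
function fd_grad(f, x, μ, σ; h = 1e-6)
    dx = (f(x + h, μ, σ) - f(x - h, μ, σ)) / (2h)
    dμ = (f(x, μ + h, σ) - f(x, μ - h, σ)) / (2h)
    dσ = (f(x, μ, σ + h) - f(x, μ, σ - h)) / (2h)
    return (dx, dμ, dσ)
end

# Fixed test point chosen arbitrarily; σ must be positive.
x, μ, σ = 0.3, -0.5, 1.7
analytic = gauss_grad(x, μ, σ)
numeric = fd_grad(gauss, x, μ, σ)

# Compare component-wise within an absolute tolerance.
println(all(isapprox.(analytic, numeric; atol = 1e-6)))
```

With Zygote loaded, the same numeric tuple could be compared against gradient(gauss, x, μ, σ) to confirm that the custom adjoint and the finite-difference estimate agree.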