I am experimenting with custom adjoints in Zygote. In the following code, I defined an adjoint for the function gauss. Surprisingly, benchmarking gauss and a "wrapper" closure f around it shows that the gradient of f is 3-4 times faster than that of gauss. Before defining the custom adjoint, the two gradients have the same speed. After the definition, f's gradient is 50% faster than before, while gauss's gradient is 2 times slower.
Why does this happen, and how can I get optimal performance with custom adjoints consistently?

    using Zygote
    using BenchmarkTools
    using Zygote: @adjoint

    # Gaussian density with mean μ and standard deviation σ
    function gauss(x, μ, σ)
        y = (x-μ)/σ
        exp(-y^2/2) / (sqrt(2π)*σ)
    end

    # Custom adjoint: returns the value and a pullback for (x, μ, σ)
    @adjoint function gauss(x, μ, σ)
        y = (x-μ)/σ
        e = exp(-y^2/2) / (sqrt(2π)*σ)
        function back(Δ)
            ey = e*y/σ
            # could pool e*Δ too
            # partials: ∂/∂x = -e*y/σ, ∂/∂μ = e*y/σ, ∂/∂σ = e*(y^2 - 1)/σ
            (-ey * Δ, ey * Δ, e*(y^2 - 1)/σ * Δ)
        end
        return e, back
    end

    x = randn()
    μ = randn()
    σ = exp(randn())

    f(x, μ, σ) = gauss(x, μ, σ)

    @btime gradient(f, $x, $μ, $σ)
    @btime gradient(gauss, $x, $μ, $σ) # 3-4 times slower
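Separately from the timing question, a hand-written pullback like this one can be sanity-checked against central finite differences. The following is only a minimal sketch in plain Julia (no Zygote needed). The helpers gauss_partials and fd_partials, the test point (x0, μ0, σ0), the step h, and the tolerance are all hypothetical choices of mine. gauss is repeated so the snippet runs by itself.

```julia
# Gaussian density, repeated here so the snippet is self-contained.
gauss(x, μ, σ) = exp(-((x - μ) / σ)^2 / 2) / (sqrt(2π) * σ)

# Hypothetical helper: hand-derived partials of gauss w.r.t. (x, μ, σ).
function gauss_partials(x, μ, σ)
    y = (x - μ) / σ
    e = gauss(x, μ, σ)
    ey = e * y / σ
    return (-ey, ey, e * (y^2 - 1) / σ)
end

# Hypothetical helper: central finite differences with step h (an assumed value).
function fd_partials(x, μ, σ; h = 1e-6)
    dx = (gauss(x + h, μ, σ) - gauss(x - h, μ, σ)) / (2h)
    dμ = (gauss(x, μ + h, σ) - gauss(x, μ - h, σ)) / (2h)
    dσ = (gauss(x, μ, σ + h) - gauss(x, μ, σ - h)) / (2h)
    return (dx, dμ, dσ)
end

# Arbitrary test point; names chosen so they do not clobber x, μ, σ above.
x0, μ0, σ0 = 0.3, -0.5, 1.7
analytic = gauss_partials(x0, μ0, σ0)
numeric = fd_partials(x0, μ0, σ0)

# Elementwise comparison with an assumed absolute tolerance.
println(all(isapprox.(analytic, numeric; atol = 1e-6)))
```

If the printed value is false, the analytic partials disagree with the numerical ones at that point, and the signs in the pullback are the first thing worth re-deriving.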