Zygote custom adjoint has surprising performance effects

I am experimenting with custom adjoints in Zygote. In the code below I define an adjoint for the function gauss. Surprisingly, when I benchmark gradient on gauss and on a thin wrapper function f that just calls gauss, the gradient of f is 3-4 times faster than the gradient of gauss. Before I defined the custom adjoint, the two gradients ran at the same speed. After the definition, f’s gradient is about 50% faster than before, while gauss’s gradient is about twice as slow.

Why does this happen, and how can I get optimal performance with custom adjoints consistently?

using Zygote
using BenchmarkTools

function gauss(x, μ, σ)
    y = (x-μ)/σ
    exp(-y^2/2) / (sqrt(2π)*σ)
end

using Zygote: @adjoint
@adjoint function gauss(x, μ, σ) 
    y = (x-μ)/σ
    e = exp(-y^2/2) / (sqrt(2π)*σ)
    function back(Δ)
        ey = e*y/σ # could pool e*Δ too
        # ∂/∂σ is e*(y^2 - 1)/σ: at y = 0 widening σ lowers the peak, so the sign is negative
        (-ey * Δ, ey * Δ, e*(y^2 - 1)/σ * Δ)
    end
    return e, back
end

x = randn()
μ = randn()
σ = exp(randn())
f(x, μ, σ) = gauss(x, μ, σ)
@btime gradient(f, $x, $μ, $σ)
@btime gradient(gauss, $x, $μ, $σ) # 3-4 times slower
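
Separately from the timing question, a hand-written pullback is easy to get subtly wrong, so it may be worth checking the three partial derivatives against central finite differences before benchmarking. The sketch below does not use Zygote at all. The names gauss_ref, gauss_partials and fd_gradient are hypothetical helpers made up for this check, and the test point and step size h are arbitrary choices.

```julia
# Same formula as `gauss` in the question, restated so the snippet is self-contained.
gauss_ref(x, μ, σ) = exp(-((x - μ) / σ)^2 / 2) / (sqrt(2π) * σ)

# Hand-derived partials with respect to (x, μ, σ), i.e. what the pullback
# should return for Δ = 1.
function gauss_partials(x, μ, σ)
    y = (x - μ) / σ
    e = gauss_ref(x, μ, σ)
    ey = e * y / σ
    return (-ey, ey, e * (y^2 - 1) / σ)
end

# Central finite differences in each argument.
# `fd_gradient` is a hypothetical helper, not part of Zygote.
function fd_gradient(g, x, μ, σ; h = 1e-6)
    dx = (g(x + h, μ, σ) - g(x - h, μ, σ)) / (2h)
    dμ = (g(x, μ + h, σ) - g(x, μ - h, σ)) / (2h)
    dσ = (g(x, μ, σ + h) - g(x, μ, σ - h)) / (2h)
    return (dx, dμ, dσ)
end

analytic = gauss_partials(0.3, -0.2, 1.5)
numeric  = fd_gradient(gauss_ref, 0.3, -0.2, 1.5)

# Compare component-wise; broadcasting works over tuples.
println(all(isapprox.(analytic, numeric; rtol = 1e-6)))
```

The same fd_gradient output can be compared against whatever gradient(gauss, x, μ, σ) returns to confirm that the custom adjoint and the default one agree.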