Emulate Enzyme.Const with Zygote

In the next release, DifferentiationInterface.jl will accept constant (non-differentiated) arguments c in addition to the active (differentiated) argument x, and I’m wondering how to implement this mechanism optimally for each backend.

For Enzyme, it’s easy: I can just use Const(c). For Zygote, there is no built-in solution, but I have found three possible approaches. To compute the partial pullback of a multi-argument function y = f(x, c) with respect to x only, I can either:

  1. Compute the full pullback of f with respect to (x, c) and discard the cotangent of c
  2. Compute the full pullback of the single-argument Base.Fix2(f, c)
  3. Define a closure g which manually drops the gradient of c (using @CarloLucibello’s Zygote.@nograd trick)
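For comparison, the Enzyme route mentioned above can be sketched as follows. This assumes Enzyme.jl is installed and uses a loop-based rewrite of the test function from the benchmarking code (the name f_loop is hypothetical). Const(c) tells Enzyme not to differentiate with respect to c, so only x gets a shadow in which the gradient accumulates.

```julia
using Enzyme: autodiff, Reverse, Active, Duplicated, Const

# Hypothetical loop-based equivalent of f(x, c) = sum(sin.(x) .* cos.(c))
function f_loop(x, c)
    s = zero(eltype(x))
    for i in eachindex(x, c)
        s += sin(x[i]) * cos(c[i])
    end
    return s
end

x, c = rand(10), rand(10)
dx = zero(x)  # shadow of x: Enzyme accumulates the gradient of x here

# Active scalar return, Duplicated x (differentiated), Const c (not differentiated)
autodiff(Reverse, f_loop, Active, Duplicated(x, dx), Const(c))

@assert dx ≈ cos.(x) .* cos.(c)
```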
Demo code
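```julia
using Zygote: Zygote

# Identity function whose derivative Zygote treats as zero
dropgrad(x) = x
Zygote.@nograd dropgrad

# 1. Full pullback with respect to (x, c), keep only the cotangent of x
function mypullback_full(f, x, c)
    _, pb = Zygote.pullback(f, x, c)
    return first ∘ pb
end

# 2. Pullback of the single-argument callable x -> f(x, c)
function mypullback_fix(f, x, c)
    _, pb = Zygote.pullback(Base.Fix2(f, c), x)
    return only ∘ pb
end

# 3. Closure that hides c behind dropgrad, so its cotangent is dropped
function mypullback_dropgrad(f, x, c)
    g(x, c) = f(x, dropgrad(c))
    _, pb = Zygote.pullback(g, x, c)
    return first ∘ pb
end
```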
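A quick sanity check that the three strategies agree, comparing each against the hand-derived gradient cos.(x) .* cos.(c) of the test function used in the benchmarking code. To keep the snippet self-contained it inlines the pullback calls instead of calling the helpers. For strategy 3 it uses ChainRulesCore.@non_differentiable in place of Zygote.@nograd; this assumes ChainRulesCore is in the active environment, and the helper name nodiff is hypothetical.

```julia
using Zygote: Zygote
using ChainRulesCore

f(x, c) = sum(sin.(x) .* cos.(c))
x, c = rand(10), rand(10)

# Hand-derived gradient: d/dx [sin(x) * cos(c)] = cos(x) * cos(c)
dx_exact = cos.(x) .* cos.(c)

# Strategy 1: full pullback, keep the cotangent of x
_, pb_full = Zygote.pullback(f, x, c)
dx_full = first(pb_full(1.0))

# Strategy 2: pullback of the single-argument Base.Fix2(f, c)
_, pb_fix = Zygote.pullback(Base.Fix2(f, c), x)
dx_fix = only(pb_fix(1.0))

# Strategy 3: hypothetical `nodiff`, an identity marked non-differentiable
# through ChainRulesCore rather than Zygote.@nograd
nodiff(c) = c
@non_differentiable nodiff(::Any)
_, pb_drop = Zygote.pullback((x, c) -> f(x, nodiff(c)), x, c)
dx_drop, dc_drop = pb_drop(1.0)

@assert dx_full ≈ dx_exact
@assert dx_fix ≈ dx_exact
@assert dx_drop ≈ dx_exact
@assert dc_drop === nothing  # the cotangent of c comes back as `nothing`
```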

My question is: which approach is preferable in general? And in the specific case of neural network libraries like Lux.jl (ping @avikpal) or Flux.jl?
Note that I care about both the time it takes to compute the pullback and the time it takes to apply the resulting closure.

Ideally, approach 1 would be sufficient, because the thunking mechanism should ensure that useless cotangents (here dy/dc) never get materialized. In practice, @oxinabox once told me that thunks aren’t actually used, so maybe this approach is wasteful?
Intuitively, I would go for approach 2 and use Base.Fix2, because Zygote.pullback doesn’t compute the cotangent of the function object itself (unlike ChainRulesCore.rrule_via_ad). Thus, hiding c inside the function seems like a good option?
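To make the contrast with rrule_via_ad concrete, here is a sketch comparing what each interface returns for the Base.Fix2 wrapper, assuming ChainRulesCore is in the environment. It only shows the shape of the outputs; it does not measure what Zygote computes internally.

```julia
using Zygote: Zygote
using ChainRulesCore: rrule_via_ad

f(x, c) = sum(sin.(x) .* cos.(c))
x, c = rand(10), rand(10)
fix = Base.Fix2(f, c)  # callable equivalent to x -> f(x, c)

# Zygote.pullback: one cotangent per argument, no entry for the callable
_, pb = Zygote.pullback(fix, x)
@assert length(pb(1.0)) == 1

# rrule_via_ad: the first cotangent belongs to the callable `fix` itself
_, back = rrule_via_ad(Zygote.ZygoteRuleConfig(), fix, x)
fbar, xbar = back(1.0)
@assert xbar ≈ cos.(x) .* cos.(c)
```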
In any case, approach 3 seems suboptimal due to the additional closure and the fact that dy/dc is computed anyway, only to be dropped at the end. But I’d love some confirmation from people familiar with Zygote.

Here’s some benchmarking code if you want to play around. I didn’t observe any meaningful differences, but my test case is too simple to matter.

Benchmarking code
```julia
using BenchmarkTools

# Requires the definitions from the demo code above
f(x, c) = sum(sin.(x) .* cos.(c))  # change this
x, c = rand(10), rand(10);  # change this

# Time to compute each pullback
pb_full = @btime mypullback_full($f, $x, $c);
pb_fix = @btime mypullback_fix($f, $x, $c);
pb_dropgrad = @btime mypullback_dropgrad($f, $x, $c);

# Time to apply each resulting closure
@btime ($pb_full)(1);
@btime ($pb_fix)(1);
@btime ($pb_dropgrad)(1);
```
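Since the toy case is too small to separate the approaches, one option is a heavier, hypothetical test function where the constant is a large matrix, so the cotangent of c is an n×n array and any wasted work on it should be more visible in the timings. The name f_heavy and the size n are arbitrary choices, and the pullback calls are inlined to keep the snippet self-contained; no particular result is implied.

```julia
using Zygote: Zygote
using BenchmarkTools

# Hypothetical heavier test case: the constant C is an n×n matrix
f_heavy(x, C) = sum(tanh.(C * x))
n = 300
x, C = rand(n), rand(n, n)

# Time to compute each pullback
@btime Zygote.pullback($f_heavy, $x, $C);
@btime Zygote.pullback(Base.Fix2($f_heavy, $C), $x);

# Time to apply each resulting closure
_, pb_full = Zygote.pullback(f_heavy, x, C)
_, pb_fix = Zygote.pullback(Base.Fix2(f_heavy, C), x)
@btime first($pb_full(1.0));
@btime only($pb_fix(1.0));

# Both strategies should yield the same cotangent for x
@assert first(pb_full(1.0)) ≈ only(pb_fix(1.0))
```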

I generally use a Fix or a closure. I haven’t seen a meaningful difference either.
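The closure variant could look like the following sketch; mypullback_closure is a hypothetical name mirroring the helpers in the demo code. In Julia an anonymous closure is itself a callable struct that stores the captured c, so it is structurally close to Base.Fix2.

```julia
using Zygote: Zygote

# Hypothetical variant: capture c in an anonymous closure instead of Base.Fix2
function mypullback_closure(f, x, c)
    _, pb = Zygote.pullback(x_ -> f(x_, c), x)
    return only ∘ pb
end

f(x, c) = sum(sin.(x) .* cos.(c))
x, c = rand(10), rand(10)
pb_closure = mypullback_closure(f, x, c)
@assert pb_closure(1.0) ≈ cos.(x) .* cos.(c)
```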

Zygote unthunks everything, AFAICT: see Zygote.jl/src/compiler/chainrules.jl (master branch of FluxML/Zygote.jl on GitHub).
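One way to probe this on the running example is to check that the cotangent of c comes back as a plain array rather than a ChainRulesCore.AbstractThunk. This assumes ChainRulesCore is in the environment, and it only tests this particular function; it says nothing about where inside Zygote the unthunking happens.

```julia
using Zygote: Zygote
using ChainRulesCore: AbstractThunk

f(x, c) = sum(sin.(x) .* cos.(c))
x, c = rand(10), rand(10)

_, pb = Zygote.pullback(f, x, c)
dx, dc = pb(1.0)

# The cotangent of c arrives already materialized, not wrapped in a thunk
@assert !(dc isa AbstractThunk)
@assert dc ≈ -sin.(x) .* sin.(c)  # d/dc [sin(x) * cos(c)] = -sin(x) * sin(c)
```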