In the next release, DifferentiationInterface.jl will accept constant (non-differentiated) arguments `c` in addition to the active (differentiated) argument `x`. I'm wondering how to implement this mechanism optimally for each backend.
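For reference, here is a rough sketch of the user-facing syntax I have in mind (the `Constant` wrapper name is tentative and not part of any released API yet):

```julia
using DifferentiationInterface
import Zygote

f(x, c) = sum(sin.(x) .* cos.(c))
x, c = rand(10), rand(10)

# x is active, c is wrapped to mark it as constant (tentative syntax)
grad = gradient(f, AutoZygote(), x, Constant(c))
```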
For Enzyme, it's easy: I can just use `Const(c)`.
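For instance, a minimal sketch of the Enzyme call I have in mind (with a toy `f`):

```julia
using Enzyme

f(x, c) = sum(sin.(x) .* cos.(c))
x, c = rand(10), rand(10)
dx = zero(x)  # shadow storage for the x-cotangent

# only x is active; Const(c) tells Enzyme not to differentiate with respect to c
autodiff(Reverse, f, Active, Duplicated(x, dx), Const(c))
# dx now holds df/dx, and no cotangent was ever computed for c
```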
For Zygote, there is no built-in solution, but I have found 3 possible approaches. To compute the partial pullback of a multi-argument `y = f(x, c)` with respect to `x` only, I can either:

1. Compute the full pullback of `f` with respect to `(x, c)` and discard the second part
2. Compute the full pullback of the single-argument `Base.Fix2(f, c)`
3. Define a closure `g` which manually drops the gradient (using @CarloLucibello's `Zygote.@nograd` trick)
Demo code
```julia
using Zygote: Zygote

# identity function whose derivative Zygote is told to ignore
dropgrad(x) = x
Zygote.@nograd dropgrad

# option 1: full pullback with respect to (x, c), keep only the x part
function mypullback_full(f, x, c)
    _, pb = Zygote.pullback(f, x, c)
    return first ∘ pb
end

# option 2: hide c inside a callable object, pull back with respect to x only
function mypullback_fix(f, x, c)
    _, pb = Zygote.pullback(Base.Fix2(f, c), x)
    return only ∘ pb
end

# option 3: closure whose cotangent for c is dropped at the end
function mypullback_dropgrad(f, x, c)
    g(x, c) = f(x, dropgrad(c))
    _, pb = Zygote.pullback(g, x, c)
    return first ∘ pb
end
```
My question is: which is preferable in general? And in the specific case of neural networks like Lux.jl (ping @avikpal) or Flux.jl?
Note that I am interested in both the time it takes to compute the pullback, and the time it takes to apply the resulting closure.
Ideally, option 1 would be sufficient, because the thunking mechanism means that useless cotangents (here `dy/dc`) never get materialized. In practice, @oxinabox once told me that `Thunk`s aren't actually used, so maybe this solution is wasteful?
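For context, here is a minimal sketch of how thunks are supposed to make that work in an `rrule` (with a made-up function `myscale`): a thunked cotangent only pays its cost if someone actually unthunks it.

```julia
using ChainRulesCore

myscale(x, c) = x .* c

function ChainRulesCore.rrule(::typeof(myscale), x, c)
    y = myscale(x, c)
    function myscale_pullback(dy)
        dx = @thunk(dy .* c)
        dc = @thunk(dy .* x)  # never computed if discarded before unthunking
        return NoTangent(), dx, dc
    end
    return y, myscale_pullback
end
```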
Intuitively, I would go for option 2 and use `Base.Fix2`, because `Zygote.pullback` doesn't compute the cotangent of the function object itself (unlike `ChainRulesCore.rrule_via_ad`). Thus, hiding `c` inside the function object seems like a good option?
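To illustrate what I mean (a small sketch):

```julia
using Zygote: Zygote

f(x, c) = sum(sin.(x) .* cos.(c))
x, c = rand(10), rand(10)

g = Base.Fix2(f, c)            # c is stored as a field of the callable g
y, pb = Zygote.pullback(g, x)
pb(1.0)                        # 1-tuple (dx,): no cotangent returned for g itself
```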
In any case, option 3 seems suboptimal due to the additional closure, and the fact that `dy/dc` is computed anyway and only dropped at the end. But I'd love to have some confirmation from people familiar with Zygote.
Here's some benchmarking code if you want to play around. I didn't observe any meaningful differences, but my test case is probably too simple to matter.
Benchmarking code
```julia
using BenchmarkTools

f(x, c) = sum(sin.(x) .* cos.(c))  # change this
x, c = rand(10), rand(10);  # change this

# time the construction of each pullback closure
pb_full = @btime mypullback_full($f, $x, $c);
pb_fix = @btime mypullback_fix($f, $x, $c);
pb_dropgrad = @btime mypullback_dropgrad($f, $x, $c);

# time the application of each pullback closure to the seed cotangent 1
@btime ($pb_full)(1);
@btime ($pb_fix)(1);
@btime ($pb_dropgrad)(1);
```