In the next release, DifferentiationInterface.jl will accept constant (non-differentiated) arguments `c` in addition to the active (differentiated) argument `x`, and I'm wondering how to implement this mechanism optimally for each backend.
For Enzyme it's easy: I can just use `Const(c)`.
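For example, here is a rough sketch of what that could look like with a toy `f` (the exact wiring inside DifferentiationInterface.jl may differ):

```julia
using Enzyme

f(x, c) = sum(sin.(x) .* cos.(c))  # toy example
x, c = rand(10), rand(10)
dx = zero(x)  # shadow vector that will accumulate dy/dx

# Const(c) marks c as non-differentiated; only the Duplicated argument gets a gradient
autodiff(Reverse, f, Active, Duplicated(x, dx), Const(c))
dx  # ≈ cos.(x) .* cos.(c)
```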
For Zygote, there is no built-in solution, but I have found 3 possible approaches. To compute the partial pullback of a multi-argument `y = f(x, c)` with respect to `x` only, I can either:
- Compute the full pullback of `f` with respect to `(x, c)` and discard the second part
- Compute the full pullback of the single-argument `Base.Fix2(f, c)`
- Define a closure `g` which manually drops the gradient (using @CarloLucibello's `Zygote.@nograd` trick)
Demo code
```julia
using Zygote: Zygote

# Identity function whose gradient Zygote is told to ignore
dropgrad(x) = x
Zygote.@nograd dropgrad

# Approach 1: full pullback w.r.t. (x, c), keep only the x part
function mypullback_full(f, x, c)
    _, pb = Zygote.pullback(f, x, c)
    return first ∘ pb
end

# Approach 2: hide c inside the function object via Base.Fix2
function mypullback_fix(f, x, c)
    _, pb = Zygote.pullback(Base.Fix2(f, c), x)
    return only ∘ pb
end

# Approach 3: closure that routes c through dropgrad
function mypullback_dropgrad(f, x, c)
    g(x, c) = f(x, dropgrad(c))
    _, pb = Zygote.pullback(g, x, c)
    return first ∘ pb
end
```
My question is: which is preferable in general? And in the specific case of neural networks like Lux.jl (ping @avikpal) or Flux.jl?
Note that I am interested both in the time it takes to compute the pullback and in the time it takes to apply the resulting closure.
Ideally, approach 1 would be sufficient, because the thunking mechanism means useless cotangents (here `dy/dc`) never get materialized. In practice, @oxinabox once told me that `Thunk`s aren't actually used, so maybe this solution is wasteful?
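For context, here is a minimal sketch of what thunking looks like in a ChainRulesCore `rrule` (with a made-up function `mymul`): if downstream code respected thunks, the `dc` thunk would simply never be forced when only `dx` is requested.

```julia
using ChainRulesCore

mymul(x, c) = x * c  # made-up function, just for illustration

function ChainRulesCore.rrule(::typeof(mymul), x, c)
    y = mymul(x, c)
    function mymul_pullback(dy)
        dx = @thunk(dy * c)
        dc = @thunk(dy * x)  # only materialized if someone calls unthunk on it
        return (NoTangent(), dx, dc)
    end
    return y, mymul_pullback
end
```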
Intuitively, I would go for approach 2 and use `Base.Fix2`, because `Zygote.pullback` doesn't compute the cotangent of the function object itself (unlike `ChainRulesCore.rrule_via_ad`). Thus, hiding `c` inside the function seems like a good option?
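A small illustration of that point (same toy `f` as in the benchmark below): `Base.Fix2(f, c)` is just a callable struct that stores `c`, so `x` is the only explicit argument Zygote sees.

```julia
using Zygote: Zygote

f(x, c) = sum(sin.(x) .* cos.(c))
x, c = rand(10), rand(10)

h = Base.Fix2(f, c)
h(x) == f(x, c)  # true: c lives inside the callable struct, not in the argument list

_, pb = Zygote.pullback(h, x)
pb(1.0)  # 1-tuple (dx,): no cotangent for the stored c shows up in the output
```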
In any case, approach 3 seems suboptimal because of the additional closure, and because `dy/dc` is computed anyway and only dropped at the end. But I'd love some confirmation from people familiar with Zygote.
Here's some benchmarking code if you want to play around. I didn't observe any meaningful differences, but my test case is too simple to be conclusive.
Benchmarking code
```julia
using BenchmarkTools

f(x, c) = sum(sin.(x) .* cos.(c))  # change this
x, c = rand(10), rand(10);         # change this

# Time the construction of each pullback closure
pb_full = @btime mypullback_full($f, $x, $c);
pb_fix = @btime mypullback_fix($f, $x, $c);
pb_dropgrad = @btime mypullback_dropgrad($f, $x, $c);

# Time the application of each closure to the output cotangent
@btime ($pb_full)(1);
@btime ($pb_fix)(1);
@btime ($pb_dropgrad)(1);
```