In the next release, DifferentiationInterface.jl will accept constant (non-differentiated) arguments `c` in addition to the active (differentiated) argument `x`, and I’m wondering how to implement this mechanism optimally for each backend.

For Enzyme, it’s easy, I can just use `Const(c)`. For Zygote, there is no built-in solution, but I have found 3 possible approaches. To compute the partial pullback of a multi-argument `y = f(x, c)` with respect to `x` only, I can either

1. Compute the full pullback of `f` with respect to `(x, c)` and discard the second part
2. Compute the full pullback of the single-argument `Base.Fix2(f, c)`
3. Define a closure `g` which manually drops the gradient (using @CarloLucibello’s `Zygote.@nograd` trick)

## Demo code

```
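using Zygote: Zygote

# Identity function whose gradient is dropped by Zygote
dropgrad(x) = x
Zygote.@nograd dropgrad

# Approach 1: full pullback w.r.t. (x, c), keep only the x part
function mypullback_full(f, x, c)
    _, pb = Zygote.pullback(f, x, c)
    return first ∘ pb
end

# Approach 2: hide c inside the function object with Base.Fix2
function mypullback_fix(f, x, c)
    _, pb = Zygote.pullback(Base.Fix2(f, c), x)
    return only ∘ pb
end

# Approach 3: closure that drops the gradient of c explicitly
function mypullback_dropgrad(f, x, c)
    g(x, c) = f(x, dropgrad(c))
    _, pb = Zygote.pullback(g, x, c)
    return first ∘ pb
end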
```

My question is: which is preferable in general? And in the specific case of neural networks like Lux.jl (ping @avikpal) or Flux.jl?

Note that I am interested in both the time it takes to compute the pullback, and the time it takes to apply the resulting closure.

Ideally, 1 would be sufficient because the thunking mechanism means useless cotangents (here `dy/dc`) never get materialized. In practice, @oxinabox once told me that `Thunk`s aren’t actually used, so maybe this solution is wasteful?

Intuitively, I would go for 2 and use `Base.Fix2` because `Zygote.pullback` doesn’t compute the cotangent of the function object itself (unlike `ChainRulesCore.rrule_via_ad`). Thus, hiding the `c` inside the function seems like a good option?
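To make the `Base.Fix2` idea concrete, here is a minimal sketch. The function `f` and the inputs are just the toy example from the benchmarking code, not a realistic workload. `Base.Fix2(f, c)` is a callable that behaves like `x -> f(x, c)`, so `Zygote.pullback` only sees one argument and the pullback returns a one-element tuple of cotangents:

```julia
using Zygote: Zygote

f(x, c) = sum(sin.(x) .* cos.(c))  # toy function, same as in the benchmark
x, c = rand(10), rand(10)

fc = Base.Fix2(f, c)               # fc(x) is the same as f(x, c)
y, pb = Zygote.pullback(fc, x)     # only x is passed as an argument
(dx,) = pb(1.0)                    # one-element tuple: cotangent w.r.t. x

# Analytic gradient of f w.r.t. x, for comparison
dx_expected = cos.(x) .* cos.(c)
```

This only shows that the result is correct, it says nothing about how much work Zygote does internally for the `c` field of the `Fix2` object.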

In any case, 3 seems suboptimal due to the additional closure, and the fact that `dy/dc` is computed anyway, and only dropped at the end. But I’d love to have some confirmation from people familiar with Zygote.
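One way to at least see what each pullback returns for `c` is to apply approaches 1 and 3 side by side and inspect the second cotangent. This is only a sketch on the toy function from the benchmarking code, and the type of the returned value does not prove whether intermediate work for `c` was skipped or not:

```julia
using Zygote: Zygote

dropgrad(x) = x
Zygote.@nograd dropgrad

f(x, c) = sum(sin.(x) .* cos.(c))  # toy function, same as in the benchmark
g(x, c) = f(x, dropgrad(c))        # approach 3: gradient of c is dropped
x, c = rand(10), rand(10)

_, pb_full = Zygote.pullback(f, x, c)
_, pb_drop = Zygote.pullback(g, x, c)

dx_full, dc_full = pb_full(1.0)
dx_drop, dc_drop = pb_drop(1.0)

# Inspect what comes back for the constant argument in each case
@show typeof(dc_full) typeof(dc_drop)
```

Both approaches should agree on `dx`. Comparing the types printed for `dc` is a cheap first check before profiling.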

Here’s some benchmarking code if you want to play around. I didn’t observe any meaningful differences but my test case is too simple to matter.

## Benchmarking code

```
using BenchmarkTools
# Assumes the mypullback_* definitions from the demo code are loaded

f(x, c) = sum(sin.(x) .* cos.(c)) # change this
x, c = rand(10), rand(10); # change this

# Time to compute the pullback
pb_full = @btime mypullback_full($f, $x, $c);
pb_fix = @btime mypullback_fix($f, $x, $c);
pb_dropgrad = @btime mypullback_dropgrad($f, $x, $c);

# Time to apply the resulting closure
@btime ($pb_full)(1);
@btime ($pb_fix)(1);
@btime ($pb_dropgrad)(1);
```
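Since the test case above is too small to show differences, here is a sketch of a slightly heavier one with a dense-layer-like function on plain arrays. The function `layer`, the matrix sizes, and the choice of the weights `W` as the active argument are all assumptions made for illustration. It does not use Lux.jl or Flux.jl, so it is only a rough proxy for the neural network case:

```julia
using BenchmarkTools
using Zygote: Zygote

dropgrad(x) = x
Zygote.@nograd dropgrad

# Same three approaches as in the demo code, written compactly
mypullback_full(f, x, c) = first ∘ Zygote.pullback(f, x, c)[2]
mypullback_fix(f, x, c) = only ∘ Zygote.pullback(Base.Fix2(f, c), x)[2]
function mypullback_dropgrad(f, x, c)
    g(x, c) = f(x, dropgrad(c))
    return first ∘ Zygote.pullback(g, x, c)[2]
end

# Hypothetical workload: weights W are active, the data batch X is constant
layer(W, X) = sum(abs2, tanh.(W * X))
W, X = randn(64, 128), randn(128, 256)  # arbitrary sizes, change as needed

for mk in (mypullback_full, mypullback_fix, mypullback_dropgrad)
    t_build = @belapsed $mk($layer, $W, $X)  # time to compute the pullback
    pb = mk(layer, W, X)
    t_apply = @belapsed $pb(1.0)             # time to apply the closure
    println(rpad(string(mk), 22), " build: ", t_build, " s, apply: ", t_apply, " s")
end
```

With a matrix product in the forward pass, the cotangent with respect to `X` costs a full matrix multiplication, so any approach that skips it should show up in the apply time.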