I have the following MWE:

```julia
using Enzyme, Lux, Random

n = 10
x_batch = randn(2, n)
y_batch = randn(2, n)
model = Chain(Parallel(vcat, Dense(2, 1, tanh), Dense(2, 1, tanh)), Dense(2, 1, tanh))
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(Xoshiro(0), model);

function f(xb, yb)
    for k = 1:n
        # scalar model output as a function of the k-th x sample only
        f1(x) = first(model((x, yb[:, k]), ps, st))[1]
        z = xb[:, k]
        dz = [0.0, 0.0]   # receives the gradient of f1 w.r.t. z
        Enzyme.autodiff(Enzyme.Reverse, f1, Active, Duplicated(z, dz))
    end
end

f(x_batch, y_batch)
```
This results in the following error:
```
Function argument passed to autodiff cannot be proven readonly.
If the the function argument cannot contain derivative data, instead call autodiff(Mode, Const(f), ...)
```
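The error message suggests wrapping the function in `Const`. A minimal, Lux-free sketch of that calling pattern, assuming a hypothetical helper `make_g` that returns a closure capturing a vector `y` it only reads (as `f1` captures `yb` in the MWE):

```julia
using Enzyme

# Hypothetical helper: returns a closure that captures the vector y and only reads it.
make_g(y) = x -> sum(x .* y)

g = make_g([1.0, 2.0])
x = [3.0, 4.0]
dx = zeros(2)   # shadow of x; reverse mode accumulates dg/dx into it

# Const(g) tells Enzyme that the closure's captured data (y) is not differentiated.
Enzyme.autodiff(Enzyme.Reverse, Const(g), Active, Duplicated(x, dx))

println(dx)   # the gradient of sum(x .* y) w.r.t. x is y
```

Whether `Const` is appropriate depends on the captured data really being constant for the derivative you want; here only `x` is differentiated.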
I have read the Enzyme docs but can't work out why this happens. Also, is this the best way to compute gradients w.r.t. one of the inputs over a batch?
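On the batch question, one possible sketch (not a confirmed recommendation) passes `model`, `ps`, `st` and the `y` sample to `autodiff` as `Const` arguments instead of capturing them in a closure. The helper names `sample_out`, `batch_out`, `input_grads_loop` and `input_grads_batch` are hypothetical, and the data is switched to `Float32` on the assumption that it should match Lux's default `Float32` parameters. The batched variant differentiates the sum of the 1×n outputs in one reverse pass; this gives per-sample gradients only under the assumption that no layer mixes batch columns, which holds for the `Dense`/`Parallel` layers here but would not for something like `BatchNorm`.

```julia
using Enzyme, Lux, Random

n = 10
# Assumption: Float32 data to match Lux's default Float32 parameters.
x_batch = randn(Float32, 2, n)
y_batch = randn(Float32, 2, n)
model = Chain(Parallel(vcat, Dense(2, 1, tanh), Dense(2, 1, tanh)), Dense(2, 1, tanh))
ps, st = Lux.setup(Xoshiro(0), model)

# Hypothetical helper: scalar output for one sample; no closure, everything is an argument.
sample_out(x, y, model, ps, st) = first(model((x, y), ps, st))[1]

# Per-sample loop: one reverse pass per column of xb.
function input_grads_loop(model, ps, st, xb, yb)
    G = zero(xb)
    for k in axes(xb, 2)
        z = xb[:, k]
        dz = zero(z)   # shadow of z, filled with d(sample_out)/dz
        Enzyme.autodiff(Enzyme.Reverse, sample_out, Active,
                        Duplicated(z, dz), Const(yb[:, k]),
                        Const(model), Const(ps), Const(st))
        G[:, k] .= dz
    end
    return G
end

# Hypothetical helper: sum of the 1×n batch outputs. Its gradient w.r.t. xb equals the
# per-sample gradients only because no layer in this model mixes columns.
batch_out(x, y, model, ps, st) = sum(first(model((x, y), ps, st)))

function input_grads_batch(model, ps, st, xb, yb)
    G = zero(xb)   # shadow of xb, filled in a single reverse pass
    Enzyme.autodiff(Enzyme.Reverse, batch_out, Active,
                    Duplicated(xb, G), Const(yb),
                    Const(model), Const(ps), Const(st))
    return G
end

G_loop = input_grads_loop(model, ps, st, x_batch, y_batch)
G_batch = input_grads_batch(model, ps, st, x_batch, y_batch)
```

Column `k` of either result is the gradient of the model output w.r.t. `x_batch[:, k]`; the batched version avoids `n` separate reverse passes.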
Thank you