I am trying to implement nested automatic differentiation (AD) with Enzyme for use in physics-informed neural networks (PINNs). Here is a minimal working example (MWE):
```julia
using Enzyme, Lux, Random, ComponentArrays, LinearAlgebra

n = 100
x_batch = randn(2, n)
y_batch = randn(2, n)

model = Chain(Parallel(vcat, Dense(2, 1, tanh), Dense(2, 1, tanh)), Dense(2, 1, tanh))
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(Xoshiro(0), model);

# Scalar network output for one sample: x feeds the first Parallel branch, y the second
nnfunc(x, y, psarray, st) = first(model((x, y), ComponentArray(psarray), st))[1]

function batcherror(xb, yb, psarray, st)
    val = zeros(n)
    for k = 1:n
        z = xb[:, k]
        dz = [0.0, 0.0]
        # Inner reverse pass: accumulates d(nnfunc)/dz into the shadow dz
        Enzyme.autodiff(Enzyme.Reverse, nnfunc, Active, Duplicated(z, dz), Duplicated(yb[:, k], zeros(2)), Const(psarray), Const(st))
        val[k] = norm(dz)
    end
    return sum(val)
end

psarr = getdata(ps)
# Outer reverse pass over batcherror with respect to the parameters
psarrnew = Enzyme.autodiff(Enzyme.Reverse, batcherror, Active, Const(x_batch), Const(y_batch), Active(psarr), Const(st))
```
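However, this code is very slow: the loop makes a separate reverse-mode `Enzyme.autodiff` call for each of the `n` samples, and the outer `autodiff` then has to differentiate through all of those inner calls. Is there an alternative, faster way to compute this?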
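For reference, here is a plain-Julia sketch (no Enzyme or Lux) of what the inner `autodiff` call computes for this particular architecture. Because the first `Parallel` branch only sees `x`, the chain rule gives ∂out/∂x = (1 − out²) · w3[1] · (1 − h1²) · w1 in closed form, which can be used to check the Enzyme result. The `TinyNet` struct and the names `w1`, `b1`, `w2`, `b2`, `w3`, `b3`, `grad_x`, `batch_error` and `fd_grad_x` are hypothetical stand-ins, not Lux's parameter layout, and the sketch assumes every `Dense` layer has a bias.

```julia
using LinearAlgebra, Random

# Hypothetical plain-Julia stand-in for the Lux model above (not Lux's API):
#   h1  = tanh(w1 ⋅ x + b1)          # first Parallel branch, sees only x
#   h2  = tanh(w2 ⋅ y + b2)          # second Parallel branch, sees only y
#   out = tanh(w3 ⋅ [h1, h2] + b3)   # final Dense layer
struct TinyNet
    w1::Vector{Float64}
    b1::Float64
    w2::Vector{Float64}
    b2::Float64
    w3::Vector{Float64}
    b3::Float64
end

function forward(p::TinyNet, x, y)
    h1 = tanh(dot(p.w1, x) + p.b1)
    h2 = tanh(dot(p.w2, y) + p.b2)
    return tanh(p.w3[1] * h1 + p.w3[2] * h2 + p.b3)
end

# Chain rule: d out/dx = (1 - out^2) * w3[1] * (1 - h1^2) * w1
function grad_x(p::TinyNet, x, y)
    h1 = tanh(dot(p.w1, x) + p.b1)
    out = forward(p, x, y)
    return ((1 - out^2) * p.w3[1] * (1 - h1^2)) .* p.w1
end

# Same quantity as `batcherror`: sum over samples (columns) of ‖d out/dx‖
batch_error(p::TinyNet, xb, yb) =
    sum(norm(grad_x(p, view(xb, :, k), view(yb, :, k))) for k in axes(xb, 2))

# Central finite differences, used only to check grad_x
function fd_grad_x(p::TinyNet, x, y; h = 1e-6)
    g = zeros(length(x))
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (forward(p, x .+ e, y) - forward(p, x .- e, y)) / (2h)
    end
    return g
end

rng = Xoshiro(0)
p = TinyNet(randn(rng, 2), randn(rng), randn(rng, 2), randn(rng), randn(rng, 2), randn(rng))
xb, yb = randn(rng, 2, 100), randn(rng, 2, 100)
x, y = xb[:, 1], yb[:, 1]
println(maximum(abs, grad_x(p, x, y) .- fd_grad_x(p, x, y)))  # should be close to zero
println(batch_error(p, xb, yb))
```

This only covers the inner gradient for this specific two-branch model; the derivative of the summed norms with respect to the parameters still has to be taken separately.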
Thank you