Speed of Nested AD in Enzyme

I am trying to implement nested AD in Enzyme for use in PINNs (physics-informed neural networks). The following is an MWE.

using Enzyme, Lux, Random, ComponentArrays, LinearAlgebra
n = 100
x_batch = randn(2, n)
y_batch = randn(2, n)
model = Chain(Parallel(vcat, Dense(2, 1, tanh), Dense(2, 1, tanh)), Dense(2, 1, tanh))
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model);

nnfunc(x, y, psarray, st) = first(model((x, y), ComponentArray(psarray), st))[1]
function batcherror(xb, yb, psarray, st)
    val = zeros(size(xb, 2))
    for k in axes(xb, 2)
        z = xb[:, k]
        dz = [0.0, 0.0]
        Enzyme.autodiff(Enzyme.Reverse, nnfunc, Active, Duplicated(z, dz), Duplicated(yb[:, k], zeros(2)), Const(psarray), Const(st))
        val[k] = norm(dz)
    end
    return sum(val)
end
psarr = getdata(ComponentArray(ps))
dpsarr = zero(psarr)  # shadow buffer; the parameter gradient accumulates here
Enzyme.autodiff(Enzyme.Reverse, batcherror, Active, Const(x_batch), Const(y_batch), Duplicated(psarr, dpsarr), Const(st))

However, the code is very slow. I am wondering if there is an alternative, faster way to accomplish this.

Thank you


Correct me if I am wrong here, but it seems like you are trying to compute the norm of the “batched” Jacobian. Looping over the batch dimension will inevitably be slow, and the cost scales linearly with the batch size. Instead, there are two options:

  1. Use BatchDuplicated from Enzyme
  2. For structured cases like the one above (i.e. cases where the NN doesn’t contain batch-mixing ops like BatchNorm), you can use batched_jacobian (Lux has it implemented for Zygote and ForwardDiff; Enzyme support is WIP here: Lux.jl/ext/LuxEnzymeExt/batched_autodiff.jl at ap/ho-enzyme · LuxDL/Lux.jl · GitHub)
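For option 1, a minimal sketch of how the inner reverse-mode call in the MWE could be replaced with a single batched forward-mode call (assuming the names `nnfunc`, `z`, `yb`, `k`, `psarray`, and `st` from the MWE above; the exact return shape of batched forward mode may differ between Enzyme versions):

```julia
# Sketch: compute the full gradient of nnfunc's scalar output with respect
# to the 2-dim input z in ONE forward-mode call, by seeding both coordinate
# directions as a batch of tangents.
seeds = ([1.0, 0.0], [0.0, 1.0])        # one shadow per input direction
derivs = Enzyme.autodiff(Enzyme.Forward, nnfunc,
                         BatchDuplicated(z, seeds),
                         Const(yb[:, k]), Const(psarray), Const(st))
# first(derivs) holds the two directional derivatives, i.e. the gradient of
# the scalar output with respect to z
g = norm(collect(first(derivs)))
```

Differentiating this with an outer reverse pass gives reverse-over-forward instead of reverse-over-reverse, which often compiles and runs better in Enzyme, though that depends on the model.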

Isn’t BatchDuplicated most efficient for small “batch sizes” on the order of 10–20? Can it really scale to 100 and more without a performance impact?
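For intuition, one common pattern is chunking: the number of simultaneous tangents stays fixed no matter how large the total batch gets. A toy sketch on a standalone function (the function, names, and chunk size here are illustrative, not from the thread; the return shape of batched forward mode may differ between Enzyme versions):

```julia
using Enzyme, LinearAlgebra

f(x) = sum(abs2, x)          # toy scalar function on a 100-dim input
x = randn(100)
chunk = 10                   # tangents propagated per autodiff call
grad = zeros(length(x))
for k0 in 1:chunk:length(x)
    ks = k0:min(k0 + chunk - 1, length(x))
    # one basis-vector seed per direction in this chunk
    seeds = ntuple(i -> (e = zeros(length(x)); e[ks[i]] = 1.0; e), length(ks))
    derivs = Enzyme.autodiff(Enzyme.Forward, f, BatchDuplicated(x, seeds))
    for (i, d) in enumerate(first(derivs))
        grad[ks[i]] = d      # directional derivative along basis vector ks[i]
    end
end
```

Each call only materializes `chunk` shadow buffers, so memory stays bounded while the number of calls grows with the problem size.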