Correct me if I am wrong here, but it seems like you are trying to compute the norm of the "batched" Jacobian. Looping over the batch dimension will inevitably be slow, and its cost scales with the batch size. Instead, there are two options:
- Use `BatchDuplicated` from Enzyme.
- For structured cases like the one above (i.e., cases where the NN doesn't contain batch-mixing ops like BatchNorm), you can use `batched_jacobian`. Lux has it implemented for Zygote and ForwardDiff; Enzyme support is WIP here: Lux.jl/ext/LuxEnzymeExt/batched_autodiff.jl at ap/ho-enzyme · LuxDL/Lux.jl · GitHub
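To illustrate why the structured case avoids the batch loop, here is a minimal, language-agnostic sketch in Python/NumPy. It is not Lux's or Enzyme's API. It assumes a hypothetical toy "network" `f` (one dense layer plus `tanh`) that processes each column of `X` independently, and it uses complex-step differentiation as a stand-in for forward-mode AD so that only NumPy is needed. Because samples do not interact, seeding the same basis direction `e_i` in every column at once gives column `i` of every per-sample Jacobian in a single evaluation. The total cost is therefore `n` evaluations (input features), independent of the batch size `B`.

```python
import numpy as np

# Hypothetical toy "network": dense layer + tanh, applied to each column
# (sample) of X independently, i.e. no batch-mixing ops like BatchNorm.
def f(W, b, X):
    return np.tanh(W @ X + b[:, None])

def batched_jacobian_cs(fun, X, h=1e-20):
    """Per-sample Jacobians J[:, :, k] = d fun(X)[:, k] / d X[:, k].

    Complex-step differentiation stands in for forward-mode AD here.
    Only valid when fun has no cross-sample interactions.
    """
    n, B = X.shape
    cols = []
    for i in range(n):  # loops over input features, NOT over the batch
        V = np.zeros((n, B))
        V[i, :] = 1.0  # same tangent e_i seeded in every sample at once
        cols.append(fun(X + 1j * h * V).imag / h)  # shape (m, B)
    return np.stack(cols, axis=1)  # shape (m, n, B)

rng = np.random.default_rng(0)
m, n, B = 3, 4, 5
W = rng.standard_normal((m, n))
b = rng.standard_normal(m)
X = rng.standard_normal((n, B))

J = batched_jacobian_cs(lambda Z: f(W, b, Z), X)
# Frobenius norm of each per-sample Jacobian, shape (B,)
norms = np.linalg.norm(J, axis=(0, 1))
```

If the network mixed samples (e.g. BatchNorm), the seeded outputs would pick up cross-sample terms and this trick would no longer recover the per-sample blocks. The remaining loop over the `n` seeds is the part that pushing several tangents per call (as `BatchDuplicated` does in Enzyme's forward mode) can collapse further.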