Speed of Nested AD in Enzyme

I am trying to implement nested AD in Enzyme for an application in PINNs. The following is an MWE.

using Enzyme, Lux, Random, ComponentArrays, LinearAlgebra
n = 100
x_batch = randn(2, n)
y_batch = randn(2, n)
model = Chain(Parallel(vcat, Dense(2, 1, tanh), Dense(2, 1, tanh)), Dense(2, 1, tanh))
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(Xoshiro(0), model);

nnfunc(x, y, psarray, st) = first(model((x, y), ComponentArray(psarray), st))[1]
function batcherror(xb, yb, psarray, st)
    val = zeros(n)
    for k = 1:n
        z = xb[:, k]
        dz = [0.0, 0.0]
        Enzyme.autodiff(Enzyme.Reverse, nnfunc, Active, Duplicated(z, dz), Duplicated(yb[:, k], zeros(2)), Const(psarray), Const(st))
        val[k] = norm(dz)
    end
    return sum(val)
end
psarr = getdata(ps)
psarrnew = Enzyme.autodiff(Enzyme.Reverse,batcherror,Active,Const(x_batch),Const(y_batch),Active(psarr),Const(st))

However, the code is very slow. I am wondering whether there is an alternative, faster way to accomplish this.

Thank you


Correct me if I am wrong here, but it seems like you are trying to compute the norm of the “batched” Jacobian. Looping over the batch dim will inevitably be slow and scale with the batch size. Instead, there are two options:

  1. Use BatchDuplicated from Enzyme
  2. For structured cases like the one above (i.e. cases where the NN doesn’t contain batch-mixing ops like BatchNorm) you can use batched_jacobian (Lux has it implemented for Zygote and ForwardDiff; Enzyme support is WIP here: Lux.jl/ext/LuxEnzymeExt/batched_autodiff.jl at ap/ho-enzyme · LuxDL/Lux.jl · GitHub)
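To illustrate option 2, here is a minimal sketch with the ForwardDiff backend (the one the post above says is already implemented; the Enzyme backend is the WIP part). `StatefulLuxLayer` wraps the model into a plain `x -> y` callable, and the batch runs along the last dimension. The layer sizes here are illustrative, not from the MWE:

```julia
using Lux, Random
using ADTypes: AutoForwardDiff
import ForwardDiff

# Illustrative model: 2 inputs -> 2 outputs.
model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))
ps, st = Lux.setup(Xoshiro(0), model)
x = randn(Xoshiro(1), Float32, 2, 16)          # batch of 16 points, one per column

# Wrap the model so it is a plain x -> y callable with ps/st captured.
smodel = StatefulLuxLayer{true}(model, ps, st)

# Per-sample Jacobians in one call: a 2 (out) x 2 (in) x 16 (batch) array.
J = batched_jacobian(smodel, AutoForwardDiff(), x)
```

No per-sample Julia loop is needed; the batching happens inside `batched_jacobian`.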

Isn’t BatchDuplicated most efficient for small “batch sizes” of order 10–20? Can it really scale to 100 and more without a performance impact?

It needs to be chunked.
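To sketch what chunking looks like (a hypothetical helper, not an Enzyme API): sweep the one-hot tangent seeds through `BatchDuplicated` a few directions at a time, so the batch width stays small regardless of the input dimension. In forward mode the derivative comes first in the returned tuple:

```julia
using Enzyme

# Hypothetical helper: forward-mode gradient of a scalar function f at x,
# pushing the one-hot seeds through in chunks of at most `chunk` directions.
function chunked_gradient(f, x::Vector{Float64}; chunk::Int = 2)
    d = length(x)
    grad = similar(x)
    for lo in 1:chunk:d
        hi = min(lo + chunk - 1, d)
        width = hi - lo + 1
        if width == 1
            # Leftover single direction: plain Duplicated.
            e = zero(x); e[lo] = 1.0
            grad[lo] = first(Enzyme.autodiff(Enzyme.Forward, f,
                                             Enzyme.Duplicated,
                                             Enzyme.Duplicated(x, e)))
        else
            # One one-hot seed per direction in this chunk.
            seeds = ntuple(width) do j
                e = zero(x); e[lo + j - 1] = 1.0; e
            end
            derivs = first(Enzyme.autodiff(Enzyme.Forward, f,
                                           Enzyme.BatchDuplicated,
                                           Enzyme.BatchDuplicated(x, seeds)))
            for j in 1:width
                grad[lo + j - 1] = derivs[j]
            end
        end
    end
    return grad
end

f(x) = x[1]^2 + 2x[2]^2 + 3x[3]^2 + 4x[4]^2
g = chunked_gradient(f, ones(4); chunk = 3)   # gradient at (1,1,1,1)
```

The chunk width trades memory for the number of sweeps; each sweep still shares the primal computation across its directions.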


It depends on the function. Enzyme’s batching can give asymptotic improvements on top of the manual for loop, depending on the code.

Thank you. I tried to play around with BatchDuplicated and noticed that it is used to calculate Jacobian-vector products for multiple vectors at the same x. E.g.

f(x) = x[1]^2 + 2*x[2]^2
x = [1.0,2.0]
∂x = ([1.0,0.0],[0.0,1.0])
autodiff(Forward, f, BatchDuplicated, BatchDuplicated(x, ∂x))

The above code results in [2.0, 8.0], which is just the gradient/Jacobian. However, I am interested in gradients at multiple x-values. Is this doable with BatchDuplicated? Actually, I am also not sure whether forward-mode or reverse-mode AD can theoretically give a speed-up on batch data in the MWE I posted above.
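One way to get gradients at multiple x-values with a single batched call (a sketch, using a toy function rather than the NN): make the function act columnwise on the whole batch and seed one batched tangent per input coordinate. Because a columnwise function has a block-diagonal Jacobian, the i-th batched derivative is exactly ∂f/∂xᵢ evaluated at every column simultaneously:

```julia
using Enzyme

# Columnwise version of f(x) = x[1]^2 + 2*x[2]^2:
# one scalar output per column of the d×n batch.
colfunc(X) = vec(X[1, :] .^ 2 .+ 2 .* X[2, :] .^ 2)

X = [1.0 3.0;
     2.0 4.0]                    # two points, one per column

# One batched tangent per input coordinate: row i seeded with ones.
seeds = ([1.0 1.0; 0.0 0.0],
         [0.0 0.0; 1.0 1.0])

# Derivatives come first in the returned tuple; entry i holds
# [∂f/∂x_i(x_1), ..., ∂f/∂x_i(x_n)].
derivs = first(Enzyme.autodiff(Enzyme.Forward, colfunc, Enzyme.BatchDuplicated,
                               Enzyme.BatchDuplicated(X, seeds)))

derivs[1]   # ∂f/∂x₁ at both points: [2.0, 6.0]
derivs[2]   # ∂f/∂x₂ at both points: [8.0, 16.0]
```

So the number of batched tangents scales with the input dimension (here 2), not with the batch size, which is what makes this cheaper than looping over the batch.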

There was a mistake in the first code I posted. After a few changes, the code is as below (it first differentiates w.r.t. one of the inputs of the network, does something with the result, and then differentiates w.r.t. the parameters of the network).

using Enzyme, Lux, Random, ComponentArrays, LinearAlgebra
n = 2
x_batch = randn(2, n)
y_batch = randn(2, n)
model = Chain(Parallel(vcat, Dense(2, 1, tanh), Dense(2, 1, tanh)), Dense(2, 1, tanh))
rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(Xoshiro(0), model);
psaxes = getaxes(ComponentArray(ps))

nnfunc(x, y, psarray) = first(model((x, y), ComponentArray(psarray, psaxes), st))[1]
psarray = getdata(ComponentArray(ps))

function batch_error(xb, yb, psarray)
    val = zeros(n)
    for i = 1 : n
        dx = zeros(2)
        Enzyme.autodiff(Enzyme.Reverse, nnfunc, Active, Duplicated(xb[:, i], dx), Duplicated(yb[:, i], zeros(2)), Duplicated(psarray, zeros(Float32, size(psarray))))
        val[i] = sum(dx.^2)
    end
    return sum(val)
end

dpsarray = zeros(Float32, size(psarray))
Enzyme.autodiff(Enzyme.Reverse, batch_error, Active, Duplicated(x_batch,zeros(size(x_batch))), Duplicated(y_batch,zeros(size(y_batch))), Duplicated(psarray, dpsarray))

However, I get a compilation error

Enzyme compilation failed.
Current scope: 
; Function Attrs: mustprogress willreturn

…

Did not have return index set when differentiating function
 call  %6 = call { { {} addrspace(10)* }, { {} addrspace(10)* } } %5({} addrspace(10)* %.fca.0.extract5, {} addrspace(10)* %.fca.0.0.extract9, {} addrspace(10)* %.fca.0.1.extract) #6, !dbg !20, !noalias !9
...

Oh that’s Julia 1.11 continuing to cause more problems.

I presume you don’t hit that issue with Julia 1.10.

File this as an issue on Enzyme.jl, listing your package versions (and double-check that you’re using the latest Enzyme.jl).

The following error appears with Julia 1.10.6 and Enzyme v0.13.17:

Enzyme compilation failed.
Current scope: 
; Function Attrs: mustprogress willreturn
define "enzyme_type"="{[-1]:Pointer}" "enzymejl_parmtype"="128669676925648" "enzymejl_parmtype_ref"="1" [3 x {} addrspace(10)*] @preprocess_julia_runtime_generic_augfwd_5694_inner.1({} addrspace(10)* nocapture nofree noundef nonnull readnone "enzyme_inactive" "enzyme_type"="{[-1]:Pointer}" "enzymejl_parmtype"="128669386720768" "enzymejl_parmtype_ref"="2" %0, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="128668521679056" "enzymejl_parmtype_ref"="2" %1, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="128668521679056" "enzymejl_parmtype_ref"="2" %2) local_unnamed_addr #5 !dbg !47 {
entry:
  %3 = call {}*** @julia.get_pgcstack() #9, !noalias !48
  %current_task1.i6 = getelementptr inbounds {}**, {}*** %3, i64 -14
  %current_task1.i = bitcast {}*** %current_task1.i6 to {}**
  %ptls_field.i7 = getelementptr inbounds {}**, {}*** %3, i64 2
  %4 = bitcast {}*** %ptls_field.i7 to i64***
  %ptls_load.i89 = load i64**, i64*** %4, align 8, !tbaa !11, !noalias !48
  %5 = getelementptr inbounds i64*, i64** %ptls_load.i89, i64 2
  %safepoint.i = load i64*, i64** %5, align 8, !tbaa !15, !noalias !48
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint.i) #9, !dbg !51, !noalias !48
  fence syncscope("singlethread") seq_cst
  %6 = call { { {} addrspace(10)* }, { {} addrspace(10)* } } inttoptr (i64 128668911573376 to { { {} addrspace(10)* }, { {} addrspace(10)* } } ({} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*)*)({} addrspace(10)* addrspacecast ({}* inttoptr (i64 128669386720864 to {}*) to {} addrspace(10)*), {} addrspace(10)* nonnull %1, {} addrspace(10)* nonnull %2) #9, !dbg !53
  %7 = extractvalue { { {} addrspace(10)* }, { {} addrspace(10)* } } %6, 0, !dbg !57
  %8 = extractvalue { { {} addrspace(10)* }, { {} addrspace(10)* } } %6, 1, !dbg !57
  %box.i = call noalias nonnull dereferenceable(8) "enzyme_type"="{[-1]:Pointer, [-1,-1]:Pointer, [-1,-1,0]:Pointer, [-1,-1,0,-1]:Float@float, [-1,-1,8]:Integer, [-1,-1,9]:Integer, [-1,-1,10]:Integer, [-1,-1,11]:Integer, [-1,-1,12]:Integer, [-1,-1,13]:Integer, [-1,-1,14]:Integer, [-1,-1,15]:Integer, [-1,-1,16]:Integer, [-1,-1,17]:Integer, [-1,-1,18]:Integer, [-1,-1,19]:Integer, [-1,-1,20]:Integer, [-1,-1,21]:Integer, [-1,-1,22]:Integer, [-1,-1,23]:Integer, [-1,-1,24]:Integer, [-1,-1,25]:Integer, [-1,-1,26]:Integer, [-1,-1,27]:Integer, [-1,-1,28]:Integer, [-1,-1,29]:Integer, [-1,-1,30]:Integer, [-1,-1,31]:Integer, [-1,-1,32]:Integer, [-1,-1,33]:Integer, [-1,-1,34]:Integer, [-1,-1,35]:Integer, [-1,-1,36]:Integer, [-1,-1,37]:Integer, [-1,-1,38]:Integer, [-1,-1,39]:Integer}" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1.i, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 128669675304080 to {}*) to {} addrspace(10)*)) #10, !dbg !58
  %9 = bitcast {} addrspace(10)* %box.i to { {} addrspace(10)* } addrspace(10)*, !dbg !58
  %10 = extractvalue { {} addrspace(10)* } %7, 0, !dbg !58
  %11 = getelementptr { {} addrspace(10)* }, { {} addrspace(10)* } addrspace(10)* %9, i64 0, i32 0, !dbg !58
  store {} addrspace(10)* %10, {} addrspace(10)* addrspace(10)* %11, align 8, !dbg !58, !tbaa !32, !alias.scope !36, !noalias !60
  %box4.i = call noalias nonnull dereferenceable(8) "enzyme_type"="{[-1]:Pointer, [-1,-1]:Pointer, [-1,-1,0]:Pointer, [-1,-1,0,-1]:Float@float, [-1,-1,8]:Integer, [-1,-1,9]:Integer, [-1,-1,10]:Integer, [-1,-1,11]:Integer, [-1,-1,12]:Integer, [-1,-1,13]:Integer, [-1,-1,14]:Integer, [-1,-1,15]:Integer, [-1,-1,16]:Integer, [-1,-1,17]:Integer, [-1,-1,18]:Integer, [-1,-1,19]:Integer, [-1,-1,20]:Integer, [-1,-1,21]:Integer, [-1,-1,22]:Integer, [-1,-1,23]:Integer, [-1,-1,24]:Integer, [-1,-1,25]:Integer, [-1,-1,26]:Integer, [-1,-1,27]:Integer, [-1,-1,28]:Integer, [-1,-1,29]:Integer, [-1,-1,30]:Integer, [-1,-1,31]:Integer, [-1,-1,32]:Integer, [-1,-1,33]:Integer, [-1,-1,34]:Integer, [-1,-1,35]:Integer, [-1,-1,36]:Integer, [-1,-1,37]:Integer, [-1,-1,38]:Integer, [-1,-1,39]:Integer}" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1.i, i64 noundef 8, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 128669675304080 to {}*) to {} addrspace(10)*)) #10, !dbg !58
...
    @ ~/.julia/packages/Enzyme/fpA3W/src/Enzyme.jl:396 [inlined]
 [26] autodiff(::ReverseMode{false, false, FFIABI, false, false}, ::typeof(batch_error), ::Type{Active}, ::Duplicated{Matrix{Float64}}, ::Duplicated{Matrix{Float64}}, ::Duplicated{Vector{Float32}}, ::Const{@NamedTuple{layer_1::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, layer_2::@NamedTuple{}}})
    @ Enzyme ~/.julia/packages/Enzyme/fpA3W/src/Enzyme.jl:524
 [27] top-level scope