How to use Lux with Enzyme

Hi,

I am trying to use Lux with Enzyme, and I cannot get it to work in a simple example. The only complete (and working) example that I have found is this: Compiling Lux Models using Reactant.jl | Lux.jl Docs, which focuses on Reactant compilation rather than on Enzyme itself. I copied this example as follows:


using Lux, Reactant, Enzyme, Random, Optimisers, Printf

model = Chain(
    Dense(2 => 4, gelu),
    Dense(4 => 4, gelu),
    Dense(4 => 2)
)
ps, st = Lux.setup(Random.default_rng(), model)

xdev = reactant_device()   # Reactant device, as in the docs example

x_ra = [randn(Float32, 2, 32) for _ in 1:32]
y_ra = [xᵢ .^ 2 for xᵢ in x_ra]
ps_ra = ps |> xdev
st_ra = st |> xdev

dataloader = DeviceIterator(xdev, zip(x_ra, y_ra))

function train_model(model, ps, st, dataloader)
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    for iteration in 1:1000
        for (i, (xᵢ, yᵢ)) in enumerate(dataloader)
            _, loss, _, train_state = Training.single_train_step!(
                AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
            if (iteration % 100 == 0 || iteration == 1) && i == 1
                @printf("Iter: [%4d/%4d]\tLoss: %.8f\n", iteration, 1000, loss)
            end
        end
    end

    return train_state
end

train_model(model, ps_ra, st_ra, dataloader)

This works as expected. However, if I remove Reactant from the code (essentially removing all calls to the Reactant device xdev), I obtain the following error:

ERROR: LoadError: MethodError: no method matching dparameters(::Nothing)
The function `dparameters` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  dparameters(::Lux.Training.TrainingBackendCache, ::Static.True)
   @ Lux ~/.julia/packages/Lux/HD428/src/helpers/training.jl:87
  dparameters(::Lux.Training.TrainingBackendCache, ::Static.False)
   @ Lux ~/.julia/packages/Lux/HD428/src/helpers/training.jl:84
  dparameters(::Lux.Training.TrainingBackendCache)
   @ Lux ~/.julia/packages/Lux/HD428/src/helpers/training.jl:83

Stacktrace:
 [1] compute_gradients_impl(ad::AutoEnzyme{…}, obj_fn::typeof(_loss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
   @ LuxEnzymeExt ~/.julia/packages/Lux/HD428/ext/LuxEnzymeExt/training.jl:3
 [2] compute_gradients
   @ ~/.julia/packages/Lux/HD428/src/helpers/training.jl:198 [inlined]
 [3] single_train_step_impl!(backend::AutoEnzyme{…}, obj_fn::typeof(_loss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
   @ Lux.Training ~/.julia/packages/Lux/HD428/src/helpers/training.jl:301
 [4] single_train_step!(backend::AutoEnzyme{…}, obj_fn::typeof(_loss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
   @ Lux.Training ~/.julia/packages/Lux/HD428/src/helpers/training.jl:276
 [5] train_model(model::Chain{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, dataloader::DeviceIterator{…})
   @ Main ~/Seafile/Nube/Facultad/Yjilioff/julia/ejemplo_enzyme.jl:27
 [6] top-level scope
   @ ~/Seafile/Nube/Facultad/Yjilioff/julia/ejemplo_enzyme.jl:38
 [7] include(fname::String)
   @ Main ./sysimg.jl:38
 [8] top-level scope
   @ REPL[7]:1
in expression starting at /Users/iojea/Seafile/Nube/Facultad/Yjilioff/julia/ejemplo_enzyme.jl:38
Some type information was truncated. Use `show(err)` to see complete types.

I tried to track down this error by following the call to dparameters. I am not sure exactly what is happening, but the problem seems to be that train_state.cache == nothing. However, this is also the case when the Reactant device is used, and in that case everything runs smoothly.
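For reference, the Reactant-free variant I am running is essentially the same script with the xdev calls dropped; roughly:

using Lux, Enzyme, Random, Optimisers

model = Chain(Dense(2 => 4, gelu), Dense(4 => 4, gelu), Dense(4 => 2))
ps, st = Lux.setup(Random.default_rng(), model)

x = [randn(Float32, 2, 32) for _ in 1:32]
y = [xᵢ .^ 2 for xᵢ in x]

function train_model(model, ps, st, data)
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
    for iteration in 1:1000
        for (xᵢ, yᵢ) in data
            # same call as above; this is the line where the dparameters(::Nothing) error is thrown
            _, loss, _, train_state = Training.single_train_step!(
                AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
        end
    end
    return train_state
end

train_model(model, ps, st, zip(x, y))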

Since the example is pretty simple, my question is: is it possible to use AutoEnzyme() as a direct replacement for AutoZygote(), or is Reactant also needed?

The big picture

My real problem is not exactly how to use Enzyme. I am trying to solve a PDE using a neural network. In particular, I have a loss function that depends on the Jacobian of the model with respect to the input. Following the docs (Nested Automatic Differentiation | Lux.jl Docs) I produced a working example using ForwardDiff for this Jacobian and AutoZygote() for the training. I copy a minimal working example below. Note that it is not meaningful in a mathematical or physical sense; it is just a simplified example to highlight some important details.

using Lux, Random, Optimisers, Zygote
import ForwardDiff

function _loss(model,ps,st,data)
    points,f = data
    smodel = StatefulLuxLayer{true}(model,ps,st)
    s = zero(eltype(points[1]))
    for p in points
        J = ForwardDiff.jacobian(smodel,p) 
        ∇u   = J[1,:] 
        divϕ = J[2,1] + J[3,2] 
        ϕ   = smodel(p)[2:3] 
        s   += sum((ϕ-∇u).^2) + (divϕ+f(p))^2
    end
    return s/length(points),st,nothing
end

create_points(N) = [rand(Float32,2) for _ in 1:N]

function create_points!(points) 
    for p in points
        p .= rand(Float32,2)
    end
end

function train_model(model,epochs;n_points=1000,step = 0.01f0)
    rng = Xoshiro(0)
    opt = Adam(step)
    ps, st = Lux.setup(rng, model)
    state  = Training.TrainState(model,ps,st,opt) 
    f(x) = -4.0f0 + 2sum(x.^2)
    points = create_points(n_points)
    for epoch in 1:epochs
        create_points!(points)
        gs,l,stats,state = Training.single_train_step!(AutoZygote(), _loss, (points,f), state)
    end
    return state
end


layers = 15
model = Chain(
            Dense(2=>layers,sigmoid),
            Dense(layers=>layers,sigmoid),
            Dense(layers=>layers,sigmoid),
            Dense(layers=>layers,sigmoid),
            Dense(layers=>3)
        )

state = train_model(model,5000)

The key point here is that at each iteration I sample many points in a domain and I need to compute the Jacobian at each of these points. This looks rather expensive, since every call to ForwardDiff.jacobian allocates a fresh matrix. It would be better to use ForwardDiff.jacobian!(J, smodel, p) with a cached J. However, this fails due to Zygote's lack of support for mutating arrays.
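For concreteness, this is roughly the variant I would like to write (the name _loss_cached is mine, just for illustration); it fails under AutoZygote() because jacobian! mutates J:

# hypothetical variant of _loss with a preallocated Jacobian buffer;
# Zygote rejects it because ForwardDiff.jacobian! writes into J in place
function _loss_cached(model, ps, st, data)
    points, f = data
    smodel = StatefulLuxLayer{true}(model, ps, st)
    J = zeros(Float32, 3, 2)                 # 3 outputs, 2 inputs
    s = zero(eltype(points[1]))
    for p in points
        ForwardDiff.jacobian!(J, smodel, p)  # fills J without allocating a new matrix
        ∇u   = J[1, :]
        divϕ = J[2, 1] + J[3, 2]
        ϕ    = smodel(p)[2:3]
        s   += sum((ϕ - ∇u) .^ 2) + (divϕ + f(p))^2
    end
    return s / length(points), st, nothing
end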

This leads me to Enzyme. As far as I can see, nested automatic differentiation is not supported with Enzyme yet. However, I could use finite differences for the computation of the Jacobian, provided that the performance gain from the reduced allocations is worth it.
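Something along these lines is what I have in mind for the finite-difference Jacobian (a rough, untuned sketch; fd_jacobian! and the step h are names I made up):

# central-difference Jacobian written into a preallocated buffer J;
# smodel is the StatefulLuxLayer used inside the loss
function fd_jacobian!(J, smodel, p; h = 1f-3)
    for j in eachindex(p)
        pj = p[j]
        p[j] = pj + h
        yp = smodel(p)                # evaluation at p + h*e_j
        p[j] = pj - h
        ym = smodel(p)                # evaluation at p - h*e_j
        p[j] = pj                     # restore the original point
        J[:, j] .= (yp .- ym) ./ (2h)
    end
    return J
end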

Of course, using Enzyme + Reactant should be even better. But if I try this I get other errors, essentially due to scalar indexing of GPU arrays. I am willing to explore this path and try to fix those errors, but I would like to begin at the beginning, with a simple working example using plain Enzyme first, and then build from it.

Sorry for the long post, and thanks in advance.


@avikpal


Can you try the latest release from last night? I fixed the bug with AutoEnzyme() in your first example there.

Enzyme natively supports nested AD (up to order 2, due to some complications with GPUCompiler, IIRC). The statement in that manual is mostly in reference to the automatic switching that happens with ForwardDiff and Zygote.
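For second order, the usual pattern is forward-over-reverse, e.g. a Hessian-vector product roughly along these lines (an untested sketch; on older Enzyme versions the inner call may need autodiff_deferred):

using Enzyme

f(x) = sum(abs2, x) / 2                  # toy scalar function; ∇f(x) = x, Hessian = I

# reverse-mode gradient written into a buffer, so it can itself be differentiated
function grad!(dx, x)
    dx .= 0
    autodiff(Reverse, f, Active, Duplicated(x, dx))
    return nothing
end

x  = [1.0, 2.0]
v  = [1.0, 0.0]                          # direction for the Hessian-vector product
dx = zeros(2)                            # receives ∇f(x)
dv = zeros(2)                            # receives H(x) * v
autodiff(Forward, grad!, Duplicated(dx, dv), Duplicated(x, v))
# here H = I, so dv ≈ [1.0, 0.0]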

Incidentally, I also just started looking into it: docs: add 3rd order AD example using Reactant by avik-pal · Pull Request #1097 · LuxDL/Lux.jl · GitHub.


Ok, excellent! The simple example with AutoEnzyme() is now working. I will try to integrate this into my main code, including the nested differentiation. I'll let you know what happens.
Thanks a lot!

Well… I was finally able to come back to this code, and I am facing new errors.

Essentially, I am trying to make Enzyme work in a simple example with nested differentiation. A simplification of my code looks like this:

using Lux
using Random
using Optimisers
using MLUtils
using Enzyme


function _loss(model,ps,st,data)
    pts,∇u,g = data
    dout = Float32[1,0,0]
    out  = Float32[0,0,0]
    s = zero(eltype(pts[1]))
    for p in pts
        ∇u .= Float32[0,0]
        autodiff(Reverse,Const(g),Const,Duplicated(out,dout),Duplicated(p,∇u))
        s += sum(∇u.^2)
    end
    s/length(pts),st,nothing
end


function train_model!(model,state,epochs;N=1000,step = 0.01f0)
    ∇u = Float32[0,0]
    f(x) = Float32(-4)
    pts  = [rand(Float32,2) for _ in 1:N]
    for epoch in 1:epochs
        smodel = StatefulLuxLayer{true}(model,state.parameters,state.states)
        function g(out,x) 
            out .= smodel(x)
            nothing
        end
        gs,l,stats,state = Training.single_train_step!(AutoEnzyme(), _loss, (pts,∇u,g), state)
    end
    model,state
end


layers = 3
model = Chain(
            Dense(2=>layers,sigmoid),
            Dense(layers=>layers,sigmoid),
            Dense(layers=>3)
        )
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model) 
opt = Adam()
state  = Training.TrainState(model,ps,st,opt) 
model,state = train_model!(model,state,1000)

The idea is as follows: ∇u is a cache vector where I store the gradient (with respect to the input) of the first output component of the network. pts is a vector of points. The loss function evaluates this gradient at each point in pts, storing it in ∇u, and accumulates its squared norm. For the definition of the function g I am following Enzyme's documentation on pullbacks (FAQ · Enzyme.jl).

If I run this code, I get the following error:

ERROR: LoadError: TypeError: in typeassert, expected LLVM.LoadInst, got a value of type LLVM.CallInst
Stacktrace:
  [1] check_ir!(job::GPUCompiler.CompilerJob, errors::Vector{…}, imported::Set{…}, f::LLVM.Function, deletedfns::Vector{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/ottqJ/src/compiler/validation.jl:462
  [2] check_ir!(job::GPUCompiler.CompilerJob, errors::Vector{Tuple{String, Vector{…}, Any}}, mod::LLVM.Module)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/ottqJ/src/compiler/validation.jl:400
  [3] check_ir
    @ ~/.julia/packages/Enzyme/ottqJ/src/compiler/validation.jl:181 [inlined]
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/ottqJ/src/compiler.jl:3290
  [5] codegen
    @ ~/.julia/packages/Enzyme/ottqJ/src/compiler.jl:3218 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/ottqJ/src/compiler.jl:5265
  [7] cached_compilation(job::GPUCompiler.CompilerJob)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/ottqJ/src/compiler.jl:5306
  [8] thunkbase(mi::Core.MethodInstance, World::UInt64, FA::Type{…}, A::Type{…}, TT::Type, Mode::Enzyme.API.CDerivativeMode, width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/ottqJ/src/compiler.jl:5410
  [9] thunk_generator(world::UInt64, source::LineNumberNode, FA::Type, A::Type, TT::Type, Mode::Enzyme.API.CDerivativeMode, Width::Int64, ModifiedBetween::NTuple{…} where N, ReturnPrimal::Bool, ShadowInit::Bool, ABI::Type, ErrIfFuncWritten::Bool, RuntimeActivity::Bool, self::Any, fakeworld::Any, fa::Type, a::Type, tt::Type, mode::Type, width::Type, modifiedbetween::Type, returnprimal::Type, shadowinit::Type, abi::Type, erriffuncwritten::Type, runtimeactivity::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/ottqJ/src/compiler.jl:5562
 [10] autodiff
    @ ~/.julia/packages/Enzyme/ottqJ/src/Enzyme.jl:485 [inlined]
 [11] compute_gradients_impl(ad::AutoEnzyme{…}, obj_fn::typeof(_loss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ LuxEnzymeExt ~/.julia/packages/Lux/H3WdN/ext/LuxEnzymeExt/training.jl:8
 [12] compute_gradients
    @ ~/.julia/packages/Lux/H3WdN/src/helpers/training.jl:198 [inlined]
 [13] single_train_step_impl!
    @ ~/.julia/packages/Lux/H3WdN/src/helpers/training.jl:301 [inlined]
 [14] single_train_step!
    @ ~/.julia/packages/Lux/H3WdN/src/helpers/training.jl:276 [inlined]
 [15] train_model!(model::Chain{…}, state::Lux.Training.TrainState{…}, epochs::Int64; N::Int64, step::Float32)
    @ Main ~/Seafile/Nube/Facultad/Yjilioff/julia/test_enzyme_basic.jl:32
 [16] train_model!(model::Chain{…}, state::Lux.Training.TrainState{…}, epochs::Int64)
    @ Main ~/Seafile/Nube/Facultad/Yjilioff/julia/test_enzyme_basic.jl:22
 [17] top-level scope
    @ ~/Seafile/Nube/Facultad/Yjilioff/julia/test_enzyme_basic.jl:48
 [18] include(fname::String)
    @ Main ./sysimg.jl:38
 [19] top-level scope
    @ REPL[8]:1
in expression starting at /Users/iojea/Seafile/Nube/Facultad/Yjilioff/julia/test_enzyme_basic.jl:48
Some type information was truncated. Use `show(err)` to see complete types.

However, if I define smodel and g outside of the training function, I can evaluate the loss function directly. Hence, the problem seems to lie in the nested differentiation, but I do not know the proper way to set it up.
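To be explicit, a standalone evaluation along these lines runs without error (so the inner autodiff call itself seems fine):

# standalone check: build smodel and g at top level and call _loss directly,
# without going through Training.single_train_step!
ps, st = Lux.setup(Xoshiro(0), model)
smodel = StatefulLuxLayer{true}(model, ps, st)
g(out, x) = (out .= smodel(x); nothing)

pts = [rand(Float32, 2) for _ in 1:10]
∇u  = Float32[0, 0]
_loss(model, ps, st, (pts, ∇u, g))   # evaluates fine; only the outer AutoEnzyme() step fails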

Any ideas?

Thanks!
