Hi,
I am trying to use Lux with Enzyme, and I cannot get it to work in a simple example. The only complete (and working) example that I have found is this: Compiling Lux Models using Reactant.jl | Lux.jl Docs, which is focused on Reactant compilation, and not particularly on Enzyme itself. I copied this example as follows:
using Lux, Reactant, Enzyme, Random, Optimisers, Printf
model = Chain(
    Dense(2 => 4, gelu),
    Dense(4 => 4, gelu),
    Dense(4 => 2)
)
ps, st = Lux.setup(Random.default_rng(), model)
x_ra = [randn(Float32, 2, 32) for _ in 1:32]
y_ra = [xᵢ .^ 2 for xᵢ in x_ra]
xdev = reactant_device()   # device definition from the docs example, dropped when copying
ps_ra = ps |> xdev
st_ra = st |> xdev
dataloader = DeviceIterator(xdev, zip(x_ra, y_ra))
function train_model(model, ps, st, dataloader)
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
    for iteration in 1:1000
        for (i, (xᵢ, yᵢ)) in enumerate(dataloader)
            _, loss, _, train_state = Training.single_train_step!(
                AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
            if (iteration % 100 == 0 || iteration == 1) && i == 1
                @printf("Iter: [%4d/%4d]\tLoss: %.8f\n", iteration, 1000, loss)
            end
        end
    end
    return train_state
end
train_model(model, ps_ra, st_ra, dataloader)
This works as expected. However, if I remove Reactant from the code (essentially removing all calls to the Reactant device xdev), I obtain the following error:
ERROR: LoadError: MethodError: no method matching dparameters(::Nothing)
The function `dparameters` exists, but no method is defined for this combination of argument types.
Closest candidates are:
dparameters(::Lux.Training.TrainingBackendCache, ::Static.True)
@ Lux ~/.julia/packages/Lux/HD428/src/helpers/training.jl:87
dparameters(::Lux.Training.TrainingBackendCache, ::Static.False)
@ Lux ~/.julia/packages/Lux/HD428/src/helpers/training.jl:84
dparameters(::Lux.Training.TrainingBackendCache)
@ Lux ~/.julia/packages/Lux/HD428/src/helpers/training.jl:83
Stacktrace:
[1] compute_gradients_impl(ad::AutoEnzyme{…}, obj_fn::typeof(_loss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
@ LuxEnzymeExt ~/.julia/packages/Lux/HD428/ext/LuxEnzymeExt/training.jl:3
[2] compute_gradients
@ ~/.julia/packages/Lux/HD428/src/helpers/training.jl:198 [inlined]
[3] single_train_step_impl!(backend::AutoEnzyme{…}, obj_fn::typeof(_loss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
@ Lux.Training ~/.julia/packages/Lux/HD428/src/helpers/training.jl:301
[4] single_train_step!(backend::AutoEnzyme{…}, obj_fn::typeof(_loss), data::Tuple{…}, ts::Lux.Training.TrainState{…})
@ Lux.Training ~/.julia/packages/Lux/HD428/src/helpers/training.jl:276
[5] train_model(model::Chain{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, dataloader::DeviceIterator{…})
@ Main ~/Seafile/Nube/Facultad/Yjilioff/julia/ejemplo_enzyme.jl:27
[6] top-level scope
@ ~/Seafile/Nube/Facultad/Yjilioff/julia/ejemplo_enzyme.jl:38
[7] include(fname::String)
@ Main ./sysimg.jl:38
[8] top-level scope
@ REPL[7]:1
in expression starting at /Users/iojea/Seafile/Nube/Facultad/Yjilioff/julia/ejemplo_enzyme.jl:38
Some type information was truncated. Use `show(err)` to see complete types.
I tried to track this error down by following the call to dparameters. I am not sure exactly what is happening, but the problem seems to be that train_state.cache == nothing. However, this is also the case when the Reactant device is used, and in that case everything runs smoothly.
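For reference, this is how I checked it, by inspecting a freshly constructed TrainState before any training step (the cache field is the one that appears in the stacktrace above):
ts = Training.TrainState(model, ps, st, Adam(0.001f0))
ts.cache === nothing   # true: no backend cache has been attached yet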
Since the example is pretty simple, my question is: is it possible to use AutoEnzyme() as a direct replacement for AutoZygote(), or is Reactant also needed?
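To make the question concrete: by "direct replacement" I mean keeping everything on the CPU and only swapping the AD backend in the training call, roughly like this (same model, parameters and loss as above, no xdev):
train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
# With AutoZygote() this step runs fine; with AutoEnzyme() it throws the MethodError above.
_, loss, _, train_state = Training.single_train_step!(
    AutoEnzyme(), MSELoss(), (x_ra[1], y_ra[1]), train_state)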
The big picture
My real problem is not exactly how to use Enzyme. I am trying to solve a PDE using a neural network. In particular, I have a loss function that depends on the Jacobian of the model with respect to the input. Following the docs (Nested Automatic Differentiation | Lux.jl Docs), I produced a working example using ForwardDiff for this Jacobian and AutoZygote() for the training. I copy a minimal working example below. Note that it is not meaningful in a mathematical or physical sense; it is just a simplified example meant to highlight some important details.
using Lux, Random, Optimisers, Zygote
import ForwardDiff
function _loss(model,ps,st,data)
    points,f = data
    smodel = StatefulLuxLayer{true}(model,ps,st)
    s = zero(eltype(points[1]))
    for p in points
        J = ForwardDiff.jacobian(smodel,p)
        ∇u = J[1,:]
        divϕ = J[2,1] + J[3,2]
        ϕ = smodel(p)[2:3]
        s += sum((ϕ-∇u).^2) + (divϕ+f(p))^2
    end
    return s/length(points),st,nothing
end
create_points(N) = [rand(Float32,2) for _ in 1:N]
function create_points!(points)
    for p in points
        p .= rand(Float32,2)
    end
end
function train_model(model,epochs;n_points=1000,step = 0.01f0)
    rng = Xoshiro(0)
    opt = Adam(step)
    ps, st = Lux.setup(rng, model)
    state = Training.TrainState(model,ps,st,opt)
    f(x) = -4.0f0 + 2sum(x.^2)
    points = create_points(n_points)
    for epoch in 1:epochs
        create_points!(points)
        gs,l,stats,state = Training.single_train_step!(AutoZygote(), _loss, (points,f), state)
    end
    return state
end
layers = 15
model = Chain(
    Dense(2=>layers,sigmoid),
    Dense(layers=>layers,sigmoid),
    Dense(layers=>layers,sigmoid),
    Dense(layers=>layers,sigmoid),
    Dense(layers=>3)
)
state = train_model(model,5000)
The key point here is that at each iteration I am sampling many points in a domain and I need to compute the Jacobian at each one of these points. This looks rather expensive. It would be better to use ForwardDiff.jacobian!(J,smodel,p) with a cached J. However, this fails due to Zygote's lack of support for mutating arrays.
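Concretely, what I would like to write is something like the following sketch (J_cache is a preallocated buffer reused for every point; the in-place jacobian! call is exactly the mutation that Zygote rejects):
function _loss_cached(model,ps,st,data)
    points,f = data
    smodel = StatefulLuxLayer{true}(model,ps,st)
    J_cache = Matrix{Float32}(undef,3,2)          # 3 outputs × 2 inputs, reused across points
    s = zero(eltype(points[1]))
    for p in points
        ForwardDiff.jacobian!(J_cache,smodel,p)   # writes into J_cache instead of allocating a new J
        ∇u = J_cache[1,:]
        divϕ = J_cache[2,1] + J_cache[3,2]
        ϕ = smodel(p)[2:3]
        s += sum((ϕ-∇u).^2) + (divϕ+f(p))^2
    end
    return s/length(points),st,nothing
end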
This leads me to Enzyme. As far as I can see, nested automatic differentiation is not supported with Enzyme yet. However, I could use finite differences for the computation of the Jacobian, provided that the gain in performance from the reduction in allocations is worth it.
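What I have in mind is something along these lines (fd_jacobian! is a hypothetical helper of my own, not from any package; it only mutates preallocated buffers, so I would expect Enzyme to handle it):
# Central-difference Jacobian of f at p, written into the preallocated matrix J.
function fd_jacobian!(J,f,p;h=1.0f-3)
    for j in eachindex(p)
        pj = p[j]
        p[j] = pj + h
        fp = f(p)
        p[j] = pj - h
        fm = f(p)
        p[j] = pj                      # restore the perturbed coordinate
        J[:,j] .= (fp .- fm) ./ (2h)
    end
    return J
end
# usage inside the loss: fd_jacobian!(J_cache, smodel, p)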
Of course, using Enzyme + Reactant should be even better. But if I try this I get other errors, essentially due to scalar indexing of GPU arrays. I am willing to explore this path and try to fix the errors, but I would like to begin at the beginning, with a simple working example using plain Enzyme first, and then build from it.
Sorry for the long post, and thanks in advance.