Oops, yeah, my bad (shouldn't respond to questions in the morning without coffee), I did not look at the code properly. @compile is not needed for single_train_step!, only for the model forward pass.
Your original code should work with just the following changes:
--- ../Lux.jl/envs/reactant/test.jl 2025-10-10 13:06:37.612543325 -0400
+++ ../Lux.jl/envs/reactant/test2.jl 2025-10-10 13:08:49.101224655 -0400
@@ -19,10 +19,11 @@
     Dense(4 => 2, gelu),
 )
-ps, st = Lux.setup(rng, mod)
+ps, st = Lux.setup(rng, mod) |> xdev
 ϕ = randn(Float32, 2, 10) |> xdev
-y, _ = mod(ϕ, ps, st)
+compiled_mod = @compile mod(ϕ, ps, st)
+y, _ = compiled_mod(ϕ, ps, st)
 function _loss(m, p, s, x)
     y, st = m(x, p, s)
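For reference, the patched forward-pass portion of the script would look roughly like this. This is only a sketch assembled from the diff context: the `reactant_device()` call for `xdev`, the `rng` setup, and the first `Dense` layer's sizes are assumptions, since the original script is truncated.

```julia
using Lux, Random, Reactant

rng = Random.default_rng()          # assumed; the original rng setup is not shown
xdev = reactant_device()            # assumed device constructor for xdev

# Layer sizes assumed to match ϕ (2 × 10 input) and the quoted output layer
mod = Chain(
    Dense(2 => 4, gelu),
    Dense(4 => 2, gelu),
)

# Parameters/state must be moved to the Reactant device before compiling
ps, st = Lux.setup(rng, mod) |> xdev
ϕ = randn(Float32, 2, 10) |> xdev

# @compile traces the forward pass once; reuse the compiled callable afterwards
compiled_mod = @compile mod(ϕ, ps, st)
y, _ = compiled_mod(ϕ, ps, st)
```

Note that only the forward pass is compiled here; single_train_step! handles its own compilation internally.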
Would it be possible to outline the minimal requirements or best practices when using AutoEnzyme() for those who are still getting familiar with Reactant?
I have updated the performance section (it will take some time for the docs build to go through) and marked the sections that no longer apply to Reactant. Happy to make additional changes.