Reactant.jl and Lux.jl don't work with Adam optimiser

I am writing a training program for the VAEAC model, but when I use Optimisers.Adam(lr), I get an error. When I use, for example, Optimisers.Descent(lr), this problem does not occur.

Error message:

ERROR: MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:900
  Float32(::IrrationalConstants.Logπ)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/lWTip/src/macro.jl:132
  Float32(::UInt8)
   @ Base float.jl:245

My training function:

function train_vaeac(; epochs=20, lr=0.001f0, batch_size=100)
    model = VAEAC(input_dim, latent_dim, hidden_dim)
    ps, st = Lux.setup(Random.default_rng(), model)

    Reactant.set_default_backend("gpu")
    dev = reactant_device()

    ps = ps |> dev; st = st |> dev

    data = load_binary_mnist_matrix() #|> dev
    loader = make_loader(data; batchsize=batch_size, shuffle=true)
    loader_dev = DeviceIterator(dev, loader)

    ts = Lux.Training.TrainState(model, ps, st, Optimisers.Adam(lr))

    for epoch in 1:epochs
        tot = 0f0
        nb = 0
        for xb in loader_dev
            mask = Float32.(generate_mask(size(xb))) |> dev
            ε = randn(Float32, latent_dim, size(xb, 2)) |> dev

            debug_data = deepcopy((model,ps, st, mask, ε))
            _, loss, _, ts = Lux.Training.single_train_step!(Lux.AutoEnzyme(), loss_fn, (xb, mask, ε), ts)
            tot += loss
            nb += 1
        end
        @info "epoch=$epoch avg_loss=$(tot/nb)"
    end
    return ts
end

The loss function (and its helpers) that I use:

function loss_fn(model, ps, st, (x, mask, ε))
    (logits, μq, logσq, μp, logσp), st2 = Lux.apply(model, (x, mask, ε), ps, st)
    recon = bce_with_logits_masked(logits, x, mask)
    kl = kl_diag_gaussians(μq, logσq, μp, logσp)
    (recon + kl), st2, (; recon, kl)
end

function bce_with_logits_masked(logits, x, mask)
    w = 1f0 .- mask
    per_elem = softplus.(logits) .- x .* logits
    s = sum(w .* per_elem)
    return s / size(x, 2)
end

function kl_diag_gaussians(μq, logσq, μp, logσp)
    σq2 = exp.(2f0 .* logσq)
    σp2 = exp.(2f0 .* logσp)
    t = (σq2 .+ (μq .- μp).^2) ./ σp2 .- 1f0 .+ 2f0 .* (logσp .- logσq)
    0.5f0 * sum(t) / size(μq, 2)
end

Can you post the whole backtrace? cc @avikpal

@avikpal
It’s here:


ERROR: MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:900
  Float32(::IrrationalConstants.Logπ)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/lWTip/src/macro.jl:132
  Float32(::UInt8)
   @ Base float.jl:245
  ...

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:0 [inlined]
  [2] call_with_reactant(::Reactant.MustThrowError, ::Type{Float32}, ::Reactant.TracedRNumber{Float32})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:764
  [3] convert
    @ ./number.jl:7 [inlined]
  [4] convert(none::Type{Float32}, none::Reactant.TracedRNumber{Float32})
    @ Reactant ./<missing>:0
  [5] convert
    @ ./number.jl:7 [inlined]
  [6] call_with_reactant(::Reactant.MustThrowError, ::typeof(convert), ::Type{Float32}, ::Reactant.TracedRNumber{Float32})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:0
  [7] cvt1
    @ ./essentials.jl:612 [inlined]
  [8] ntuple
    @ ./ntuple.jl:49 [inlined]
  [9] convert
    @ ./essentials.jl:614 [inlined]
 [10] convert(none::Type{Tuple{…}}, none::Tuple{Reactant.TracedRNumber{…}, Reactant.TracedRNumber{…}})
    @ Reactant ./<missing>:0
 [11] call_with_reactant(::Reactant.MustThrowError, ::typeof(convert), ::Type{…}, ::Tuple{…})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:420
 [12] cvt1
    @ ./essentials.jl:612 [inlined]
 [13] ntuple
    @ ./ntuple.jl:50 [inlined]
 [14] ntuple(none::Base.var"#cvt1#1"{Tuple{…}, Tuple{…}}, none::Val{3})
    @ Reactant ./<missing>:0
 [15] cvt1
    @ ./essentials.jl:610 [inlined]
 [16] ntuple
    @ ./ntuple.jl:50 [inlined]
 [17] call_with_reactant(::typeof(ntuple), ::Base.var"#cvt1#1"{Tuple{…}, Tuple{…}}, ::Val{3})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:0
 [18] convert
    @ ./essentials.jl:614 [inlined]
 [19] setproperty!
    @ ./Base.jl:52 [inlined]
 [20] setproperty!(none::Optimisers.Leaf{…}, none::Symbol, none::Tuple{…})
    @ Reactant ./<missing>:0
 [21] setproperty!
    @ ./Base.jl:51 [inlined]
 [22] call_with_reactant(::Reactant.MustThrowError, ::typeof(setproperty!), ::Optimisers.Leaf{…}, ::Symbol, ::Tuple{…})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:0
 [23] #_update!#10
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:96 [inlined]
 [24] var"#_update!#10"(none::IdDict{…}, none::IdDict{…}, none::typeof(Optimisers._update!), none::Optimisers.Leaf{…}, none::Reactant.TracedRArray{…})
    @ Reactant ./<missing>:0
 [25] #_update!#10
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:93 [inlined]
 [26] call_with_reactant(::Optimisers.var"##_update!#10", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::Optimisers.Leaf{…}, ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:0
 [27] _update!
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:92 [inlined]
 [28] #8
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:85 [inlined]
 [29] map
    @ ./tuple.jl:383 [inlined]
 [30] map
    @ ./namedtuple.jl:266 [inlined]
 [31] mapvalue
    @ ~/.julia/packages/Optimisers/V8kHf/src/utils.jl:2 [inlined]
 [32] #_update!#7
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:85 [inlined]
 [33] _update!
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:81 [inlined]
 [34] #8
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:85 [inlined]
 [35] map
    @ ./tuple.jl:386 [inlined]
 [36] map(none::Optimisers.var"#8#9"{…}, none::Tuple{…}, none::Tuple{…})
    @ Reactant ./<missing>:0
 [37] getindex
    @ ./tuple.jl:31 [inlined]
 [38] map
    @ ./tuple.jl:386 [inlined]
 [39] call_with_reactant(::typeof(map), ::Optimisers.var"#8#9"{…}, ::Tuple{…}, ::Tuple{…})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:0
 [40] map
    @ ./namedtuple.jl:266 [inlined]
 [41] map(none::Function, none::@NamedTuple{…}, none::Tuple{…})
    @ Reactant ./<missing>:0
 [42] call_with_reactant(::typeof(map), ::Function, ::@NamedTuple{…}, ::@NamedTuple{…})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:764
--- the above 12 lines are repeated 1 more time ---
 [55] mapvalue
    @ ~/.julia/packages/Optimisers/V8kHf/src/utils.jl:2 [inlined]
 [56] #_update!#7
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:85 [inlined]
 [57] _update!
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:81 [inlined]
 [58] update!
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:77 [inlined]
 [59] compute_gradients_internal_and_step!
    @ ~/.julia/packages/Lux/gmUbf/ext/LuxReactantExt/training.jl:90 [inlined]
 [60] compute_gradients_internal_and_step!(none::typeof(ProbAbEx.loss_fn), none::ProbAbEx.VAEAC{…}, none::Tuple{…}, none::@NamedTuple{…}, none::@NamedTuple{…}, none::@NamedTuple{…})
    @ Reactant ./<missing>:0
 [61] GenericMemory
    @ ./boot.jl:516 [inlined]
 [62] IdDict
    @ ./iddict.jl:31 [inlined]
 [63] IdDict
    @ ./iddict.jl:49 [inlined]
 [64] make_zero (repeats 2 times)
    @ ~/.julia/packages/EnzymeCore/lmG5F/src/EnzymeCore.jl:587 [inlined]
 [65] compute_gradients_internal_and_step!
    @ ~/.julia/packages/Lux/gmUbf/ext/LuxReactantExt/training.jl:85 [inlined]
 [66] call_with_reactant(::typeof(LuxReactantExt.compute_gradients_internal_and_step!), ::typeof(ProbAbEx.loss_fn), ::ProbAbEx.VAEAC{…}, ::Tuple{…}, ::@NamedTuple{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:0
 [67] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/xQ2Wo/src/TracedUtils.jl:192
 [68] make_mlir_fn
    @ ~/.julia/packages/Reactant/xQ2Wo/src/TracedUtils.jl:94 [inlined]
 [69] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}; optimize::Bool, no_nan::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/xQ2Wo/src/Compiler.jl:373
 [70] compile_mlir!
    @ ~/.julia/packages/Reactant/xQ2Wo/src/Compiler.jl:364 [inlined]
 [71] compile_xla(f::Function, args::Tuple{…}; client::Nothing, optimize::Bool, no_nan::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/xQ2Wo/src/Compiler.jl:894
 [72] compile_xla
    @ ~/.julia/packages/Reactant/xQ2Wo/src/Compiler.jl:884 [inlined]
 [73] compile(f::Function, args::Tuple{…}; client::Nothing, optimize::Bool, sync::Bool, no_nan::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/xQ2Wo/src/Compiler.jl:922
 [74] compile
    @ ~/.julia/packages/Reactant/xQ2Wo/src/Compiler.jl:921 [inlined]
 [75] macro expansion
    @ ~/.julia/packages/Reactant/xQ2Wo/src/Compiler.jl:613 [inlined]
 [76] single_train_step_impl!(backend::Lux.Training.ReactantBackend, objective_function::typeof(ProbAbEx.loss_fn), data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ LuxReactantExt ~/.julia/packages/Lux/gmUbf/ext/LuxReactantExt/training.jl:39
 [77] single_train_step!(backend::ADTypes.AutoEnzyme{…}, obj_fn::typeof(ProbAbEx.loss_fn), data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/gmUbf/src/helpers/training.jl:275
 [78] train_vaeac(; epochs::Int64, lr::Float32, batch_size::Int64)
    @ ProbAbEx ~/ProbAbEx/src/samplers/VAEAC_training.jl:135

Can you post the full script so I can run it locally?

Thank you for offering to help. The problem was solved by clearing all packages and re-downloading the project dependencies.

I just hit this as well, seemingly out of the blue. I have a GitHub workflow here that reproduces the error: Updates11 · grero/RecurrentNetworkModels.jl@7ce9e3d · GitHub

This is different from the case above. In Updates11 · grero/RecurrentNetworkModels.jl@7ce9e3d · GitHub, the RNG type being carried around is not compatible with Reactant; use Random.default_rng().
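
For example, something along these lines in the setup code (a sketch; `model` stands in for the network defined in that workflow):

using Lux, Random

rng = Random.default_rng()        # plain task-local RNG, which Reactant handles fine
ps, st = Lux.setup(rng, model)    # rather than an RNG stored inside / carried by a custom type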

Apologies for hijacking this topic. I just wanted to add that my issue was solved by this update.

Hey I am hitting this too :frowning:

This is an MWE that Claude made (the real case depends on a package I am still working on), but it should reproduce the error:

using Lux, Reactant, Random, Optimisers

# Setup
const dev = reactant_device(; force=true)
const rng = Random.default_rng()
Random.seed!(rng, 123)

# Minimal model: just a simple Dense layer
model = Dense(2 => 2)

# Setup parameters and state
ps, st = Lux.setup(rng, model)

# Simple MSE loss function
function loss_fn(model, ps, st, data)
    x, y = data
    ŷ, st = model(x, ps, st)
    loss = sum(abs2, ŷ .- y)
    return loss, st, (;)
end

# Create synthetic data on device
x = dev(randn(Float32, 2, 4))  # 2 features, batch size 4
y = dev(randn(Float32, 2, 4))
data = (x, y)

# This works with Descent
opt_descent = Descent(0.01f0)
train_state_descent = Lux.Training.TrainState(model, ps, st, opt_descent)
train_state_descent = train_state_descent |> dev

println("Testing with Descent optimizer...")
try
    (_, loss, _, train_state_descent) = Lux.Training.single_train_step!(
        AutoEnzyme(), loss_fn, data, train_state_descent
    )
    println("✓ Descent works! Loss: $loss")
catch e
    println("✗ Descent failed!")
    showerror(stdout, e)
end

# This should fail with Adam
opt_adam = Adam(0.001f0)
train_state_adam = Lux.Training.TrainState(model, ps, st, opt_adam)
train_state_adam = train_state_adam |> dev

println("\nTesting with Adam optimizer...")
try
    (_, loss, _, train_state_adam) = Lux.Training.single_train_step!(
        AutoEnzyme(), loss_fn, data, train_state_adam
    )
    println("✓ Adam works! Loss: $loss")
catch e
    println("✗ Adam failed!")
    showerror(stdout, e, backtrace())
end

and I get this error:

Testing with Adam optimizer...
"Inconsistent guaranteed error IR 51 1 ─ %1  = Base.fieldtype(Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, _3)::Union{Type{Bool}, Type{Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, Type{Adam{Float32, Tuple{Float64, Float64}, Float64}}}\n52 2 ─ %2  = (isa)(%1, Type{Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}})::Bool\n   └──       goto #4 if not %2\n   3 ─ %4  = %new(Base.var\"#cvt1#1\"{Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, _4)::Base.var\"#cvt1#1\"{Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}\n   │   %5  = invoke Base.ntuple(%4::Base.var\"#cvt1#1\"{Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, \$(QuoteNode(Val{3}()))::Val{3})::Tuple{Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}}\n   └──       goto #5\n   4 ─ %7  = Base.convert(%1, _4)::Tuple{Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}}\n   └──       goto #5\n   5 ┄ %9  = φ (#3 => %5, #4 => %7)::Tuple{Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}}\n53 6 ─ %10 = Base.setfield!(_2, _3, %9)::Tuple{Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}}\n   └──       return %10\n   "
✗ Adam failed!
MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:900
  Float32(::UInt128)
   @ Base float.jl:302
  Float32(::UInt8)
   @ Base float.jl:245
  ...

Stacktrace:
  [1] top-level scope
    @ ~/juliaenvs/MWEreactant/main.jl:59
  [2] eval
    @ ./boot.jl:430 [inlined]
  [3] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:2734
  [4] invokelatest(::Any, ::Any, ::Vararg{Any}; kwargs::@Kwargs{})
    @ Base ./essentials.jl:1055
  [5] invokelatest(::Any, ::Any, ::Vararg{Any})
    @ Base ./essentials.jl:1052
  [6] inlineeval(m::Module, code::String, code_line::Int64, code_column::Int64, file::String; softscope::Bool)
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/eval.jl:271
  [7] (::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/eval.jl:181
  [8] withpath(f::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, path::String)
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/repl.jl:276
  [9] (::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/eval.jl:179
 [10] hideprompt(f::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/repl.jl:38
 [11] #67
    @ ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/eval.jl:150 [inlined]
 [12] with_logstate(f::VSCodeServer.var"#67#72"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, logstate::Base.CoreLogging.LogState)
    @ Base.CoreLogging ./logging/logging.jl:524
 [13] with_logger
    @ ./logging/logging.jl:635 [inlined]
 [14] (::VSCodeServer.var"#66#71"{VSCodeServer.ReplRunCodeRequestParams})()
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/eval.jl:263
 [15] #invokelatest#2
    @ ./essentials.jl:1055 [inlined]
 [16] invokelatest(::Any)
    @ Base ./essentials.jl:1052
 [17] (::VSCodeServer.var"#64#65")()
    @ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/eval.jl:34

Notice that Descent works (I guess because it is stateless?), while any other optimizer does not. I am on Julia 1.11.7 on Ubuntu 24.04 under WSL2, if that matters.
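
For what it's worth, the stateless/stateful difference is visible with plain Optimisers, no Reactant involved (a sketch):

using Optimisers

x = rand(Float32, 2, 2)

# Descent keeps no per-parameter state, so there is nothing to mutate:
Optimisers.setup(Descent(0.01f0), x)   # roughly Leaf(Descent(0.01), nothing)

# Adam's Leaf carries two moment arrays plus a tuple of scalar β accumulators;
# that scalar tuple is what the trace above ends up trying to overwrite with
# TracedRNumbers via setproperty! on the Leaf:
Optimisers.setup(Adam(0.001f0), x)     # roughly Leaf(Adam(...), (m, v, (β1, β2)))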

(MWE) pkg> status
Status `~/juliaenvs/MWEreactant/Project.toml`
  [7da242da] Enzyme v0.13.83
  [b2108857] Lux v1.22.1
  [3bd65402] Optimisers v0.4.6
  [3c362404] Reactant v0.2.169
  [9a3f8284] Random v1.11.0

(MWE) pkg> status --outdated -m
Status `~/juliaenvs/MWEreactant/Manifest.toml`
⌅ [f6369f11] ForwardDiff v1.0.0 (<v1.2.1): Lux
⌅ [aea7be01] PrecompileTools v1.2.1 (<v1.3.3): julia
⌅ [7cc45869] Enzyme_jll v0.0.201+0 (<v0.0.202+0): Enzyme

@avikpal @wsmoses

Move ps/st to the device before constructing the TrainState; otherwise the optimiser state won't be traced correctly.

julia> ps = ps |> dev
(weight = ConcreteIFRTArray{Float32, 2, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}, Nothing}(Float32[-1.1059412 -0.40677243; 0.051962912 0.21263216]), bias = ConcreteIFRTArray{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}, Nothing}(Float32[-0.6907691, 0.55278593]))

julia> st = st |> dev
NamedTuple()

julia> opt_adam = Adam(0.001f0)
Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8)

julia> train_state_adam = Lux.Training.TrainState(model, ps, st, opt_adam)
TrainState
    model: Dense(2 => 2)
    # of parameters: 6
    # of states: 0
    optimizer: ReactantOptimiser{Adam{ConcreteIFRTNumber{Float32, ShardInfo{NoSharding, Nothing}}, Tuple{ConcreteIFRTNumber{Float64, ShardInfo{NoSharding, Nothing}}, ConcreteIFRTNumber{Float64, ShardInfo{NoSharding, Nothing}}}, ConcreteIFRTNumber{Float64, ShardInfo{NoSharding, Nothing}}}}(Adam(eta=ConcreteIFRTNumber{Float32, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.001f0), beta=(ConcreteIFRTNumber{Float64, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.9), ConcreteIFRTNumber{Float64, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.999)), epsilon=ConcreteIFRTNumber{Float64, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(1.0e-8)))
    step: 0

julia> (_, loss, _, train_state_adam) = Lux.Training.single_train_step!(
               AutoEnzyme(), loss_fn, data, train_state_adam
           )
((weight = Concre...
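
In script form, the fix for the MWE above is just a reordering (a sketch, reusing the same names):

# Move ps and st to the Reactant device *before* building the TrainState,
# and don't pipe the TrainState through the device afterwards.
ps, st = Lux.setup(rng, model)
ps = ps |> dev
st = st |> dev

train_state_adam = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))

(_, loss, _, train_state_adam) = Lux.Training.single_train_step!(
    AutoEnzyme(), loss_fn, data, train_state_adam
)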