Hey, I am hitting this too.
This is an MWE that Claude made (the real case depends on a package I am still working on), but it should reproduce the issue:
using Lux, Reactant, Random, Optimisers
# Setup: force a Reactant device so tracing is exercised even without an accelerator.
const dev = reactant_device(; force=true)
# Seed the task-local RNG for reproducible parameter initialization.
const rng = Random.default_rng()
Random.seed!(rng, 123)
# Minimal model: just a simple Dense layer (2 inputs -> 2 outputs).
model = Dense(2 => 2)
# Setup parameters (ps) and state (st) for the model.
ps, st = Lux.setup(rng, model)
# Sum-of-squared-errors loss in the 4-argument signature Lux.Training expects;
# returns (loss, updated_state, extra_stats).
function loss_fn(model, ps, st, data)
    inputs, targets = data
    predictions, new_st = model(inputs, ps, st)
    return sum(abs2, predictions .- targets), new_st, (;)
end
# Synthetic training batch moved onto the Reactant device:
# 2 features x batch of 4 for both inputs and targets.
x = randn(Float32, 2, 4) |> dev
y = randn(Float32, 2, 4) |> dev
data = (x, y)
# This works with Descent (stateless optimizer: its Optimisers.Leaf carries
# no per-parameter state for Reactant to trace through).
opt_descent = Descent(0.01f0)
train_state_descent = Lux.Training.TrainState(model, ps, st, opt_descent)
train_state_descent = train_state_descent |> dev
println("Testing with Descent optimizer...")
try
    (_, loss, _, train_state_descent) = Lux.Training.single_train_step!(
        AutoEnzyme(), loss_fn, data, train_state_descent
    )
    println("✓ Descent works! Loss: $loss")
catch e
    println("✗ Descent failed!")
    # catch_backtrace() reports the stack at the throw site, making the
    # failure output actually diagnosable (plain showerror(stdout, e) hid it).
    showerror(stdout, e, catch_backtrace())
end
# This is the failing case: Adam's Optimisers.Leaf carries per-parameter
# state (moment buffers plus beta powers), which the error log below shows
# tripping Reactant's tracing of the setfield! on the Leaf state.
opt_adam = Adam(0.001f0)
train_state_adam = Lux.Training.TrainState(model, ps, st, opt_adam)
train_state_adam = train_state_adam |> dev
println("\nTesting with Adam optimizer...")
try
    (_, loss, _, train_state_adam) = Lux.Training.single_train_step!(
        AutoEnzyme(), loss_fn, data, train_state_adam
    )
    println("✓ Adam works! Loss: $loss")
catch e
    println("✗ Adam failed!")
    # BUG FIX: backtrace() inside a catch block returns the *current* (handler)
    # stack; catch_backtrace() returns the stack where the exception was thrown.
    showerror(stdout, e, catch_backtrace())
end
and I get this error:
Testing with Adam optimizer...
"Inconsistent guaranteed error IR 51 1 ─ %1 = Base.fieldtype(Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, _3)::Union{Type{Bool}, Type{Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, Type{Adam{Float32, Tuple{Float64, Float64}, Float64}}}\n52 2 ─ %2 = (isa)(%1, Type{Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}})::Bool\n └── goto #4 if not %2\n 3 ─ %4 = %new(Base.var\"#cvt1#1\"{Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, _4)::Base.var\"#cvt1#1\"{Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}\n │ %5 = invoke Base.ntuple(%4::Base.var\"#cvt1#1\"{Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}}}, \$(QuoteNode(Val{3}()))::Val{3})::Tuple{Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}}\n └── goto #5\n 4 ─ %7 = Base.convert(%1, _4)::Tuple{Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}}\n └── goto #5\n 5 ┄ %9 = φ (#3 => %5, #4 => 
%7)::Tuple{Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}}\n53 6 ─ %10 = Base.setfield!(_2, _3, %9)::Tuple{Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}, Union{Reactant.TracedRArray{Float32, 2}, NTuple{_A, Any} where _A}}\n └── return %10\n "
✗ Adam failed!
MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.
Closest candidates are:
(::Type{T})(::T) where T<:Number
@ Core boot.jl:900
Float32(::UInt128)
@ Base float.jl:302
Float32(::UInt8)
@ Base float.jl:245
...
Stacktrace:
[1] top-level scope
@ ~/juliaenvs/MWEreactant/main.jl:59
[2] eval
@ ./boot.jl:430 [inlined]
[3] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
@ Base ./loading.jl:2734
[4] invokelatest(::Any, ::Any, ::Vararg{Any}; kwargs::@Kwargs{})
@ Base ./essentials.jl:1055
[5] invokelatest(::Any, ::Any, ::Vararg{Any})
@ Base ./essentials.jl:1052
[6] inlineeval(m::Module, code::String, code_line::Int64, code_column::Int64, file::String; softscope::Bool)
@ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/eval.jl:271
[7] (::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
@ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/eval.jl:181
[8] withpath(f::VSCodeServer.var"#69#74"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, path::String)
@ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/repl.jl:276
[9] (::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})()
@ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/eval.jl:179
[10] hideprompt(f::VSCodeServer.var"#68#73"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams})
@ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/repl.jl:38
[11] #67
@ ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/eval.jl:150 [inlined]
[12] with_logstate(f::VSCodeServer.var"#67#72"{Bool, Bool, Bool, Module, String, Int64, Int64, String, VSCodeServer.ReplRunCodeRequestParams}, logstate::Base.CoreLogging.LogState)
@ Base.CoreLogging ./logging/logging.jl:524
[13] with_logger
@ ./logging/logging.jl:635 [inlined]
[14] (::VSCodeServer.var"#66#71"{VSCodeServer.ReplRunCodeRequestParams})()
@ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/eval.jl:263
[15] #invokelatest#2
@ ./essentials.jl:1055 [inlined]
[16] invokelatest(::Any)
@ Base ./essentials.jl:1052
[17] (::VSCodeServer.var"#64#65")()
@ VSCodeServer ~/.vscode-server/extensions/julialang.language-julia-1.149.2/scripts/packages/VSCodeServer/src/eval.jl:34
Notice that Descent works (I guess because it is stateless?) while any other optimizer does not. I am on Julia 1.11.7 on Ubuntu 24.04 on WSL2, if that matters.
(MWE) pkg> status
Status `~/juliaenvs/MWEreactant/Project.toml`
[7da242da] Enzyme v0.13.83
[b2108857] Lux v1.22.1
[3bd65402] Optimisers v0.4.6
[3c362404] Reactant v0.2.169
[9a3f8284] Random v1.11.0
(MWE) pkg> status --outdated -m
Status `~/juliaenvs/MWEreactant/Manifest.toml`
⌅ [f6369f11] ForwardDiff v1.0.0 (<v1.2.1): Lux
⌅ [aea7be01] PrecompileTools v1.2.1 (<v1.3.3): julia
⌅ [7cc45869] Enzyme_jll v0.0.201+0 (<v0.0.202+0): Enzyme
@avikpal @wsmoses