Reactant.jl and Lux.jl don't work with Adam optimiser

I am writing a training program for the VAEAC model, but when I use Optimisers.Adam(lr) I get the error below. With Optimisers.Descent(lr), for example, the problem does not occur.

Error message:

ERROR: MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:900
  Float32(::IrrationalConstants.Logπ)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/lWTip/src/macro.jl:132
  Float32(::UInt8)
   @ Base float.jl:245
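
For context, the failure does not seem specific to the VAEAC model. A minimal sketch of the same pattern (untested; the Dense layer and MSE loss here are placeholders, not my actual code):

using Lux, Reactant, Optimisers, Random

Reactant.set_default_backend("gpu")
dev = reactant_device()

model = Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model) |> dev

x = randn(Float32, 4, 8) |> dev
y = randn(Float32, 2, 8) |> dev

function mse(model, ps, st, (x, y))
    ŷ, st2 = Lux.apply(model, x, ps, st)
    return sum(abs2, ŷ .- y) / size(y, 2), st2, (;)
end

# Adam keeps per-parameter state (momenta plus the beta-power tuple);
# Descent is stateless, which matches where the two diverge for me.
ts = Lux.Training.TrainState(model, ps, st, Optimisers.Adam(0.001f0))
Lux.Training.single_train_step!(Lux.AutoEnzyme(), mse, (x, y), ts)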

My training function:

function train_vaeac(; epochs=20, lr=0.001f0, batch_size=100)
    model = VAEAC(input_dim, latent_dim, hidden_dim)
    ps, st = Lux.setup(Random.default_rng(), model)

    Reactant.set_default_backend("gpu")
    dev = reactant_device()

    ps = ps |> dev; st = st |> dev

    data = load_binary_mnist_matrix() #|> dev
    loader = make_loader(data; batchsize=batch_size, shuffle=true)
    loader_dev = DeviceIterator(dev, loader)

    ts = Lux.Training.TrainState(model, ps, st, Optimisers.Adam(lr))

    for epoch in 1:epochs
        tot = 0f0
        nb = 0
        for xb in loader_dev
            mask = Float32.(generate_mask(size(xb))) |> dev
            ε = randn(Float32, latent_dim, size(xb, 2)) |> dev

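            # snapshot of the step inputs, kept only for debugging; not used by the training step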
            debug_data = deepcopy((model, ps, st, mask, ε))
            _, loss, _, ts = Lux.Training.single_train_step!(Lux.AutoEnzyme(), loss_fn, (xb, mask, ε), ts)
            tot += loss
            nb += 1
        end
        @info "epoch=$epoch avg_loss=$(tot/nb)"
    end
    return ts
end

The loss function I use, with its helpers:

function loss_fn(model, ps, st, (x, mask, ε))
    (logits, μq, logσq, μp, logσp), st2 = Lux.apply(model, (x, mask, ε), ps, st)
    recon = bce_with_logits_masked(logits, x, mask)
    kl = kl_diag_gaussians(μq, logσq, μp, logσp)
    (recon + kl), st2, (; recon, kl)
end

function bce_with_logits_masked(logits, x, mask)
    w = 1f0 .- mask
    per_elem = softplus.(logits) .- x .* logits
    s = sum(w .* per_elem)
    return s / size(x, 2)
end
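
For reference, the per-element term is the numerically stable binary cross-entropy with logits: with z = logits and targets x ∈ {0, 1},

-(x*log(σ(z)) + (1 - x)*log(1 - σ(z))) = softplus(z) - x*z,

so the function sums this over the unmasked entries (w = 1 .- mask) and averages over the batch dimension size(x, 2).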

function kl_diag_gaussians(μq, logσq, μp, logσp)
    σq2 = exp.(2f0 .* logσq)
    σp2 = exp.(2f0 .* logσp)
    t = (σq2 .+ (μq .- μp).^2) ./ σp2 .- 1f0 .+ 2f0 .* (logσp .- logσq)
    0.5f0 * sum(t) / size(μq, 2)
end
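
This is the closed-form KL divergence between diagonal Gaussians q = N(μq, σq²) and p = N(μp, σp²), summed over dimensions and averaged over the batch:

KL(q ‖ p) = Σ ( (σq² + (μq - μp)²) / (2σp²) - 1/2 + log(σp) - log(σq) ),

which is 0.5f0 * t above with the 1/2 factored out.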

Can you post the whole backtrace? cc @avikpal

@avikpal
It’s here:

ERROR: MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
The type `Float32` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  (::Type{T})(::T) where T<:Number
   @ Core boot.jl:900
  Float32(::IrrationalConstants.Logπ)
   @ IrrationalConstants ~/.julia/packages/IrrationalConstants/lWTip/src/macro.jl:132
  Float32(::UInt8)
   @ Base float.jl:245
  ...

Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:0 [inlined]
  [2] call_with_reactant(::Reactant.MustThrowError, ::Type{Float32}, ::Reactant.TracedRNumber{Float32})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:764
  [3] convert
    @ ./number.jl:7 [inlined]
  [4] convert(none::Type{Float32}, none::Reactant.TracedRNumber{Float32})
    @ Reactant ./<missing>:0
  [5] convert
    @ ./number.jl:7 [inlined]
  [6] call_with_reactant(::Reactant.MustThrowError, ::typeof(convert), ::Type{Float32}, ::Reactant.TracedRNumber{Float32})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:0
  [7] cvt1
    @ ./essentials.jl:612 [inlined]
  [8] ntuple
    @ ./ntuple.jl:49 [inlined]
  [9] convert
    @ ./essentials.jl:614 [inlined]
 [10] convert(none::Type{Tuple{…}}, none::Tuple{Reactant.TracedRNumber{…}, Reactant.TracedRNumber{…}})
    @ Reactant ./<missing>:0
 [11] call_with_reactant(::Reactant.MustThrowError, ::typeof(convert), ::Type{…}, ::Tuple{…})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:420
 [12] cvt1
    @ ./essentials.jl:612 [inlined]
 [13] ntuple
    @ ./ntuple.jl:50 [inlined]
 [14] ntuple(none::Base.var"#cvt1#1"{Tuple{…}, Tuple{…}}, none::Val{3})
    @ Reactant ./<missing>:0
 [15] cvt1
    @ ./essentials.jl:610 [inlined]
 [16] ntuple
    @ ./ntuple.jl:50 [inlined]
 [17] call_with_reactant(::typeof(ntuple), ::Base.var"#cvt1#1"{Tuple{…}, Tuple{…}}, ::Val{3})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:0
 [18] convert
    @ ./essentials.jl:614 [inlined]
 [19] setproperty!
    @ ./Base.jl:52 [inlined]
 [20] setproperty!(none::Optimisers.Leaf{…}, none::Symbol, none::Tuple{…})
    @ Reactant ./<missing>:0
 [21] setproperty!
    @ ./Base.jl:51 [inlined]
 [22] call_with_reactant(::Reactant.MustThrowError, ::typeof(setproperty!), ::Optimisers.Leaf{…}, ::Symbol, ::Tuple{…})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:0
 [23] #_update!#10
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:96 [inlined]
 [24] var"#_update!#10"(none::IdDict{…}, none::IdDict{…}, none::typeof(Optimisers._update!), none::Optimisers.Leaf{…}, none::Reactant.TracedRArray{…})
    @ Reactant ./<missing>:0
 [25] #_update!#10
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:93 [inlined]
 [26] call_with_reactant(::Optimisers.var"##_update!#10", ::IdDict{…}, ::IdDict{…}, ::typeof(Optimisers._update!), ::Optimisers.Leaf{…}, ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:0
 [27] _update!
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:92 [inlined]
 [28] #8
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:85 [inlined]
 [29] map
    @ ./tuple.jl:383 [inlined]
 [30] map
    @ ./namedtuple.jl:266 [inlined]
 [31] mapvalue
    @ ~/.julia/packages/Optimisers/V8kHf/src/utils.jl:2 [inlined]
 [32] #_update!#7
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:85 [inlined]
 [33] _update!
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:81 [inlined]
 [34] #8
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:85 [inlined]
 [35] map
    @ ./tuple.jl:386 [inlined]
 [36] map(none::Optimisers.var"#8#9"{…}, none::Tuple{…}, none::Tuple{…})
    @ Reactant ./<missing>:0
 [37] getindex
    @ ./tuple.jl:31 [inlined]
 [38] map
    @ ./tuple.jl:386 [inlined]
 [39] call_with_reactant(::typeof(map), ::Optimisers.var"#8#9"{…}, ::Tuple{…}, ::Tuple{…})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:0
 [40] map
    @ ./namedtuple.jl:266 [inlined]
 [41] map(none::Function, none::@NamedTuple{…}, none::Tuple{…})
    @ Reactant ./<missing>:0
 [42] call_with_reactant(::typeof(map), ::Function, ::@NamedTuple{…}, ::@NamedTuple{…})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:764
--- the above 12 lines are repeated 1 more time ---
 [55] mapvalue
    @ ~/.julia/packages/Optimisers/V8kHf/src/utils.jl:2 [inlined]
 [56] #_update!#7
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:85 [inlined]
 [57] _update!
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:81 [inlined]
 [58] update!
    @ ~/.julia/packages/Optimisers/V8kHf/src/interface.jl:77 [inlined]
 [59] compute_gradients_internal_and_step!
    @ ~/.julia/packages/Lux/gmUbf/ext/LuxReactantExt/training.jl:90 [inlined]
 [60] compute_gradients_internal_and_step!(none::typeof(ProbAbEx.loss_fn), none::ProbAbEx.VAEAC{…}, none::Tuple{…}, none::@NamedTuple{…}, none::@NamedTuple{…}, none::@NamedTuple{…})
    @ Reactant ./<missing>:0
 [61] GenericMemory
    @ ./boot.jl:516 [inlined]
 [62] IdDict
    @ ./iddict.jl:31 [inlined]
 [63] IdDict
    @ ./iddict.jl:49 [inlined]
 [64] make_zero (repeats 2 times)
    @ ~/.julia/packages/EnzymeCore/lmG5F/src/EnzymeCore.jl:587 [inlined]
 [65] compute_gradients_internal_and_step!
    @ ~/.julia/packages/Lux/gmUbf/ext/LuxReactantExt/training.jl:85 [inlined]
 [66] call_with_reactant(::typeof(LuxReactantExt.compute_gradients_internal_and_step!), ::typeof(ProbAbEx.loss_fn), ::ProbAbEx.VAEAC{…}, ::Tuple{…}, ::@NamedTuple{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
    @ Reactant ~/.julia/packages/Reactant/xQ2Wo/src/utils.jl:0
 [67] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/xQ2Wo/src/TracedUtils.jl:192
 [68] make_mlir_fn
    @ ~/.julia/packages/Reactant/xQ2Wo/src/TracedUtils.jl:94 [inlined]
 [69] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}; optimize::Bool, no_nan::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/xQ2Wo/src/Compiler.jl:373
 [70] compile_mlir!
    @ ~/.julia/packages/Reactant/xQ2Wo/src/Compiler.jl:364 [inlined]
 [71] compile_xla(f::Function, args::Tuple{…}; client::Nothing, optimize::Bool, no_nan::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/xQ2Wo/src/Compiler.jl:894
 [72] compile_xla
    @ ~/.julia/packages/Reactant/xQ2Wo/src/Compiler.jl:884 [inlined]
 [73] compile(f::Function, args::Tuple{…}; client::Nothing, optimize::Bool, sync::Bool, no_nan::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/xQ2Wo/src/Compiler.jl:922
 [74] compile
    @ ~/.julia/packages/Reactant/xQ2Wo/src/Compiler.jl:921 [inlined]
 [75] macro expansion
    @ ~/.julia/packages/Reactant/xQ2Wo/src/Compiler.jl:613 [inlined]
 [76] single_train_step_impl!(backend::Lux.Training.ReactantBackend, objective_function::typeof(ProbAbEx.loss_fn), data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ LuxReactantExt ~/.julia/packages/Lux/gmUbf/ext/LuxReactantExt/training.jl:39
 [77] single_train_step!(backend::ADTypes.AutoEnzyme{…}, obj_fn::typeof(ProbAbEx.loss_fn), data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/gmUbf/src/helpers/training.jl:275
 [78] train_vaeac(; epochs::Int64, lr::Float32, batch_size::Int64)
    @ ProbAbEx ~/ProbAbEx/src/samplers/VAEAC_training.jl:135
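
Reading the trace, the error seems to surface inside Optimisers._update! (frames [20]–[26]): setproperty! on an Optimisers.Leaf tries to convert the traced optimizer state back into plain Float32 fields. That would at least be consistent with the stateful Adam failing while the stateless Descent works.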

Can you post the full script? I can run it locally.

Thank you for offering to help. The problem was solved by clearing all installed packages and downloading the project's dependencies again.
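
For anyone who hits the same thing, what I ran was something along these lines (from memory, so treat it as a sketch rather than the exact commands):

using Pkg
Pkg.activate(".")
rm("Manifest.toml"; force=true)  # drop the pinned versions
Pkg.resolve()                    # re-resolve Reactant/Lux/Optimisers to compatible versions
Pkg.instantiate()                # download all project dependencies again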