How to use a custom loss function with Reactant?

I’m trying to make this code work:

using MLDatasets,Lux,Reactant,Enzyme,Random,OneHotArrays,MLUtils,Optimisers,Statistics
Reactant.set_default_backend("gpu")

const rng = Random.MersenneTwister(123)
const dev = xla_device()


# Load the MNIST dataset
mnist_train = MLDatasets.MNIST(:train)
mnist_test = MLDatasets.MNIST(:test)

# Make A MLP model
function make_mlp(num_layer::Int,input_dim ::Int,hidden_dim::Int,output_dim::Int)
    model = Chain(
        Dense(input_dim,hidden_dim,relu),
        [Dense(hidden_dim,hidden_dim,relu) for i in 1:num_layer-2]...,
        Dense(hidden_dim,output_dim)
    )
    ps,st = Lux.setup(rng,model) |> dev
    return model,ps,st
end

function loss_fn(model,ps,st,d)
    x,y = d
    ŷ,_ = model(x,ps,st)
    return Lux.MSELoss()(ŷ,y)
end

function main()
    num_layers = 2
    hidden_dim = 32
    num_classes = 10
    batch_size = 256
    num_epochs = 32
    learning_rate = 1e-3
    imgs_train = reshape(mnist_train.features,
        (
            size(mnist_train.features,1)*size(mnist_train.features,2),
            size(mnist_train.features,3)
        )
        )
    imgs_test = reshape(mnist_test.features,
        (
            size(mnist_test.features,1)*size(mnist_test.features,2),
            size(mnist_test.features,3)
        )
        )
    class = 0:9
    targets_train = onehotbatch(mnist_train.targets,class)

    model,ps,st = make_mlp(num_layers,28*28,hidden_dim,num_classes)

    dl_train = DataLoader((imgs_train,targets_train),batchsize=batch_size,shuffle=true,partial=false)|> dev
    opt = Training.TrainState(model,ps,st,Adam(learning_rate))
    for epoch in 1:num_epochs
        for (i,(x,y)) in enumerate(dl_train)
            _,loss,_,opt = Training.single_train_step!(
                AutoEnzyme(),
                loss_fn,
                (x,y),
                opt
            )
            if i==1
                @info "Epoch: $epoch, Loss: $loss"
            end
        end
    end
    ps = opt.parameters
    targets = mnist_test.targets
    modelcomp = @compile model(dev(imgs_test),ps,st)
    ŷ,_ = modelcomp(dev(imgs_test),ps,st)
    ŷ = argmax.(eachcol(Array(ŷ))) .- 1
    acc = mean(ŷ .== targets)*100
    @info "Accuracy: $acc"
    return opt
end

opt = main()

It crashes at this line:

 _,loss,_,opt = Training.single_train_step!(
                AutoEnzyme(),
                loss_fn,
                (x,y),
                opt
            )

but it works if I simply do:

 _,loss,_,opt = Training.single_train_step!(
                AutoEnzyme(),
                MSELoss(),
                (x,y),
                opt
            )

I followed this tutorial: https://lux.csail.mit.edu/stable/manual/compiling_lux_models

Truncated error trace:

BoundsError: attempt to access Reactant.TracedRNumber{Float32} at index [2]
Stacktrace:
  [1] indexed_iterate
    @ ./tuple.jl:101 [inlined]
  [2] indexed_iterate(none::Reactant.TracedRNumber{Float32}, none::Int64, none::Nothing)
    @ Reactant ./<missing>:0
  [3] indexed_iterate
    @ ./tuple.jl:101 [inlined]
  [4] call_with_reactant(::typeof(Base.indexed_iterate), ::Reactant.TracedRNumber{Float32}, ::Int64, ::Nothing)
    @ Reactant ~/.julia/packages/Reactant/WudhJ/src/utils.jl:0
  [5] wrapped_objective_function
    @ ~/.julia/packages/Lux/fMnM0/ext/LuxReactantExt/training.jl:9 [inlined]
  [6] wrapped_objective_function(none::typeof(loss_fn), none::Chain{…}, none::@NamedTuple{…}, none::@NamedTuple{…}, none::Tuple{…}, none::LuxReactantExt.StatsAndNewStateWrapper)
    @ Reactant ./<missing>:0
  [7] wrapped_objective_function
    @ ~/.julia/packages/Lux/fMnM0/ext/LuxReactantExt/training.jl:9 [inlined]
  [8] call_with_reactant(::typeof(LuxReactantExt.wrapped_objective_function), ::typeof(loss_fn), ::Chain{…}, ::@NamedTuple{…}, ::@NamedTuple{…}, ::Tuple{…}, ::LuxReactantExt.StatsAndNewStateWrapper)
    @ Reactant ~/.julia/packages/Reactant/WudhJ/src/utils.jl:0
  [9] (::Reactant.TracedUtils.var"#8#18"{…})()
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:182
 [10] block!(f::Reactant.TracedUtils.var"#8#18"{…}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Block.jl:201
 [11] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:169
 [12] make_mlir_fn
    @ ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:86 [inlined]
 [13] overload_autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Const{…}, ::Duplicated{…}, ::Const{…}, ::Const{…}, ::Const{…})
    @ Reactant ~/.julia/packages/Reactant/WudhJ/src/Interpreter.jl:238
 [14] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Const{…}, ::Duplicated{…}, ::Const{…}, ::Const{…}, ::Const{…})
    @ Reactant ~/.julia/packages/Reactant/WudhJ/src/Overlay.jl:32
 [15] macro expansion
    @ ~/.julia/packages/Enzyme/DiEvV/src/sugar.jl:275 [inlined]
 [16] gradient
    @ ~/.julia/packages/Enzyme/DiEvV/src/sugar.jl:263 [inlined]
 [17] compute_gradients_internal
    @ ~/.julia/packages/Lux/fMnM0/ext/LuxReactantExt/training.jl:17 [inlined]
 [18] compute_gradients_internal_and_step!
    @ ~/.julia/packages/Lux/fMnM0/ext/LuxReactantExt/training.jl:115 [inlined]
 [19] compute_gradients_internal_and_step!(none::typeof(loss_fn), none::Chain{…}, none::Tuple{…}, none::@NamedTuple{…}, none::@NamedTuple{…}, none::@NamedTuple{…})
    @ Reactant ./<missing>:0
 [20] compute_gradients_internal
    @ ~/.julia/packages/Lux/fMnM0/ext/LuxReactantExt/training.jl:16 [inlined]
 [21] compute_gradients_internal_and_step!
    @ ~/.julia/packages/Lux/fMnM0/ext/LuxReactantExt/training.jl:115 [inlined]
 [22] call_with_reactant(::typeof(LuxReactantExt.compute_gradients_internal_and_step!), ::typeof(loss_fn), ::Chain{…}, ::Tuple{…}, ::@NamedTuple{…}, ::@NamedTuple{…}, ::@NamedTuple{…})
    @ Reactant ~/.julia/packages/Reactant/WudhJ/src/utils.jl:0
 [23] (::Reactant.TracedUtils.var"#8#18"{…})()
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:182
 [24] block!(f::Reactant.TracedUtils.var"#8#18"{…}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Block.jl:201
 [25] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
    @ Reactant.TracedUtils ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:169
 [26] make_mlir_fn
    @ ~/.julia/packages/Reactant/WudhJ/src/TracedUtils.jl:86 [inlined]
 [27] #10
    @ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:319 [inlined]
 [28] block!(f::Reactant.Compiler.var"#10#15"{…}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Block.jl:201
 [29] #9
    @ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:318 [inlined]
 [30] mmodule!(f::Reactant.Compiler.var"#9#14"{…}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Module.jl:92
 [31] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}; optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:315
 [32] compile_mlir!
    @ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:314 [inlined]
 [33] (::Reactant.Compiler.var"#32#34"{Bool, typeof(LuxReactantExt.compute_gradients_internal_and_step!), Tuple{…}})()
    @ Reactant.Compiler ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:799
 [34] context!(f::Reactant.Compiler.var"#32#34"{…}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/WudhJ/src/mlir/IR/Context.jl:76
 [35] compile_xla(f::Function, args::Tuple{…}; client::Nothing, optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:796
 [36] compile_xla
    @ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:791 [inlined]
 [37] compile(f::Function, args::Tuple{…}; client::Nothing, optimize::Bool, sync::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:823
 [38] compile
    @ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:822 [inlined]
 [39] macro expansion
    @ ~/.julia/packages/Reactant/WudhJ/src/Compiler.jl:536 [inlined]
 [40] single_train_step_impl!(backend::Lux.Training.ReactantBackend, objective_function::typeof(loss_fn), data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ LuxReactantExt ~/.julia/packages/Lux/fMnM0/ext/LuxReactantExt/training.jl:79
 [41] single_train_step!(backend::AutoEnzyme{…}, obj_fn::typeof(loss_fn), data::Tuple{…}, ts::Lux.Training.TrainState{…})
    @ Lux.Training ~/.julia/packages/Lux/fMnM0/src/helpers/training.jl:276
 [42] main()
    @ Main ./REPL[9]:29
 [43] top-level scope
    @ REPL[10]:1

Never mind, sorry. The objective function passed to Training.single_train_step! has to return a 3-tuple (loss, updated_state, stats); since my loss_fn returned only the scalar loss, Lux tried to destructure it, which is the BoundsError on the TracedRNumber above. A loss should be defined like this:

function loss_fn(model,ps,st,d)
    x,y = d
    ŷ,stn = model(x,ps,st)
    return Lux.MSELoss()(ŷ,y),stn,(;)
end
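For completeness, here is a minimal sketch of how the corrected loss plugs into the training loop, assuming the same model, data loader, and hyperparameters as in the original code (the variable name train_state is my own; the original used opt):

function loss_fn(model, ps, st, (x, y))
    # Lux expects the objective function to return (loss, updated_state, stats)
    ŷ, st_new = model(x, ps, st)
    return Lux.MSELoss()(ŷ, y), st_new, (;)
end

train_state = Training.TrainState(model, ps, st, Adam(1e-3))
for (x, y) in dl_train
    # single_train_step! returns (grads, loss, stats, train_state)
    _, loss, _, train_state = Training.single_train_step!(
        AutoEnzyme(), loss_fn, (x, y), train_state
    )
end

The built-in MSELoss() works directly because Lux wraps known loss functions to produce that 3-tuple for you; a custom objective has to return it explicitly.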