Custom loss functions in `Lux.jl`

I am currently getting accustomed to the Lux.jl package. I want to use a custom loss function that includes an L2 regularization term on the model weights. I defined my custom loss function as

function loss_function(model, ps, st, (x, y))
    T = eltype(Base.Flatten(first(ps)))

    loss_mse = MSELoss()(model, ps, st, (x,y))[1]

    loss_reg = zero(T)
    for p in ps
        loss_reg += sum(abs2, Base.Flatten(p))
    end

    loss_total = loss_mse + convert(T, 0.001) * loss_reg

    return loss_total, st, NamedTuple()
end

I’d like to get some feedback from more experienced users of Lux.jl. Is this a good approach? It works with basic models in Lux.jl. My questions:

  • Is calling MSELoss() inside the function a good idea, or is it better to move it outside loss_function?
  • I use Enzyme.jl as the AD backend, which prefers type stability. I struggled a bit to figure out the return type. Is my approach appropriate?

Thanks for the responses and those awesome packages!

Is calling MSELoss() inside the function a good idea, or is it better to move it outside loss_function?

Either is fine. If you are moving it to global scope, remember to annotate it with const.
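For example, a minimal sketch of the global-constant variant (the name mse_loss is hypothetical; the regularization part stays as in your post):

const mse_loss = MSELoss()   # const so the global is not untyped

function loss_function(model, ps, st, (x, y))
    loss_mse = mse_loss(model, ps, st, (x, y))[1]
    # ... add the regularization term here, as before ...
    return loss_mse, st, NamedTuple()
end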

For an L2 regularization term on the model weights, it is more efficient to use WeightDecay from Optimisers.jl as part of your optimizer.

loss_reg = zero(T)
for p in ps
    loss_reg += sum(abs2, Base.Flatten(p))
end

This won’t work in the general case. If you really want to use this approach (instead of the optimizer route), use Functors.fleaves(ps) and then iterate over it. But this method is type unstable.
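For reference, a hedged sketch of the fleaves variant (keeping the 0.001 factor from your post, written as 1f-3; as noted above, it is not type stable):

using Functors

function loss_function(model, ps, st, (x, y))
    loss_mse = MSELoss()(model, ps, st, (x, y))[1]
    # Functors.fleaves collects the leaf arrays of the parameter NamedTuple
    loss_reg = sum(sum(abs2, p) for p in Functors.fleaves(ps))
    return loss_mse + 1f-3 * loss_reg, st, NamedTuple()
end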


Thanks for the fast and concise response! WeightDecay is the way to go 🙂

One follow-up question: How would I use OptimiserChain together with Lux? In this toy example:

using Lux, LuxCore, Optimisers, Random, Printf
using Enzyme  # loaded so the AutoEnzyme() backend is available

model = Chain(
    Dense(2, 4, tanh),
    Dense(4, 1),
)

rng = Random.default_rng()
Random.seed!(rng, 42)

X_train = randn(rng, Float32, 2,100)
y_train = randn(rng, Float32, 1,100)

ps, st = LuxCore.setup(rng, model)

function train_model!(model, ps, st, x, y)

    train_state = Lux.Training.TrainState(model, ps, st, Adam(0.01f0))

    for iter in 1:1000
        _, loss, _, train_state = Lux.Training.single_train_step!(
            AutoEnzyme(),
            OptimiserChain(MSELoss(), WeightDecay(0.001)),
            (x, y), train_state
        )
        if iter % 100 == 1 || iter == 1000
            @printf "Iteration: %04d \t Loss: %10.9g\n" iter loss
        end
    end

    return model, ps, st
end

train_model!(model, ps, st, X_train, y_train)

I get the following error:

MethodError: objects of type OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}} are not callable
The object of type `OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}` exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.

Stacktrace:
 [1] generate_wrappers(objective_function::OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}, m::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, ps::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, data::Tuple{Matrix{Float32}, Matrix{Float32}}, ::Static.True)
   @ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:247
 [2] wrap_objective_function(objective_function::OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}, m::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, ps::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, data::Tuple{Matrix{Float32}, Matrix{Float32}}, first_try::Static.True)
   @ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:253
 [3] compute_gradients_impl(ad::AutoEnzyme{Nothing, Nothing}, obj_fn::OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}, data::Tuple{Matrix{Float32}, Matrix{Float32}}, ts::Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}})
   @ LuxEnzymeExt ~/.julia/packages/Lux/DHtyL/ext/LuxEnzymeExt/training.jl:5
 [4] compute_gradients(ad::AutoEnzyme{Nothing, Nothing}, obj_fn::OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}, data::Tuple{Matrix{Float32}, Matrix{Float32}}, ts::Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}})
   @ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:200
 [5] single_train_step_impl!(backend::AutoEnzyme{Nothing, Nothing}, obj_fn::OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}, data::Tuple{Matrix{Float32}, Matrix{Float32}}, ts::Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}})
   @ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:320
 [6] single_train_step!(backend::AutoEnzyme{Nothing, Nothing}, obj_fn::OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}, data::Tuple{Matrix{Float32}, Matrix{Float32}}, ts::Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}; return_gradients::Static.True)
   @ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:288
 [7] single_train_step!(backend::AutoEnzyme{Nothing, Nothing}, obj_fn::OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}, data::Tuple{Matrix{Float32}, Matrix{Float32}}, ts::Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}})
   @ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:284
 [8] train_model!(model::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, ps::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, x::Matrix{Float32}, y::Matrix{Float32})
   @ Main ~/MEGA/projects/Blog/KAN/Julia_implementation/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X24sZmlsZQ==.jl:6
 [9] top-level scope
   @ ~/MEGA/projects/Blog/KAN/Julia_implementation/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X24sZmlsZQ==.jl:19

Thanks for your help! 🙂

train_state = Lux.Training.TrainState(model, ps, st, OptimiserChain(Adam(0.01f0), WeightDecay(...)))

instead of putting the loss function into the OptimiserChain.
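For example (a sketch assuming the 0.001f0 decay from your earlier post; the rest of the loop is unchanged):

train_state = Lux.Training.TrainState(
    model, ps, st,
    OptimiserChain(Adam(0.01f0), WeightDecay(0.001f0)),
)

_, loss, _, train_state = Lux.Training.single_train_step!(
    AutoEnzyme(), MSELoss(), (x, y), train_state
)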


Thanks, that works!