One follow-up question: How would I use OptimiserChain
together with Lux
? In this toy example:
model = Chain(
Dense(2, 4, tanh),
Dense(4, 1),
)
rng = Random.default_rng()
Random.seed!(rng, 42)
X_train = randn(rng, Float32, 2,100)
y_train = randn(rng, Float32, 1,100)
ps, st = LuxCore.setup(rng, model)
function train_model!(model, ps, st, x, y)
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.01f0))
for iter in 1:1000
_, loss, _, train_state = Lux.Training.single_train_step!(
AutoEnzyme(),
OptimiserChain(MSELoss(), WeightDecay(0.001)),
(x, y), train_state
)
if iter % 100 == 1 || iter == 1000
@printf "Iteration: %04d \t Loss: %10.9g\n" iter loss
end
end
return model, ps, st
end
train_model!(model, ps, st, X_train, y_train)
I get the following error:
MethodError: objects of type OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}} are not callable
The object of type `OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}` exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.
Stacktrace:
[1] generate_wrappers(objective_function::OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}, m::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, ps::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, data::Tuple{Matrix{Float32}, Matrix{Float32}}, ::Static.True)
@ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:247
[2] wrap_objective_function(objective_function::OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}, m::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, ps::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, data::Tuple{Matrix{Float32}, Matrix{Float32}}, first_try::Static.True)
@ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:253
[3] compute_gradients_impl(ad::AutoEnzyme{Nothing, Nothing}, obj_fn::OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}, data::Tuple{Matrix{Float32}, Matrix{Float32}}, ts::Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}})
@ LuxEnzymeExt ~/.julia/packages/Lux/DHtyL/ext/LuxEnzymeExt/training.jl:5
[4] compute_gradients(ad::AutoEnzyme{Nothing, Nothing}, obj_fn::OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}, data::Tuple{Matrix{Float32}, Matrix{Float32}}, ts::Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}})
@ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:200
[5] single_train_step_impl!(backend::AutoEnzyme{Nothing, Nothing}, obj_fn::OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}, data::Tuple{Matrix{Float32}, Matrix{Float32}}, ts::Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}})
@ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:320
[6] single_train_step!(backend::AutoEnzyme{Nothing, Nothing}, obj_fn::OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}, data::Tuple{Matrix{Float32}, Matrix{Float32}}, ts::Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}; return_gradients::Static.True)
@ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:288
[7] single_train_step!(backend::AutoEnzyme{Nothing, Nothing}, obj_fn::OptimiserChain{Tuple{GenericLossFunction{typeof(Lux.LossFunctionImpl.l2_distance_loss), typeof(Statistics.mean)}, WeightDecay}}, data::Tuple{Matrix{Float32}, Matrix{Float32}}, ts::Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, Adam, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}})
@ Lux.Training ~/.julia/packages/Lux/DHtyL/src/helpers/training.jl:284
[8] train_model!(model::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, ps::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, x::Matrix{Float32}, y::Matrix{Float32})
@ Main ~/MEGA/projects/Blog/KAN/Julia_implementation/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X24sZmlsZQ==.jl:6
[9] top-level scope
@ ~/MEGA/projects/Blog/KAN/Julia_implementation/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X24sZmlsZQ==.jl:19
Thanks for your help! 