Hi,
I’m trying to train neural networks with Lux, Reactant, and the Training API and am confused about how to tie it all together.
When I’m compiling a model with Reactant, should I still use the `TrainState` API?
I tried passing a compiled model into `TrainState` and got this error:
```julia
julia> tstate = Training.TrainState(model_compiled, ps_ra, st_ra, opt)
ERROR: MethodError: no method matching Lux.Training.TrainState(::Reactant.Compiler.Thunk{…}, ::@NamedTuple{…}, ::@NamedTuple{…}, ::Adam{…})
The type `Lux.Training.TrainState` exists, but no method is defined for this combination of argument types when trying to construct it.

Closest candidates are:
  Lux.Training.TrainState(::__T_cache, ::__T_objective_function, ::__T_model, ::__T_parameters, ::__T_states, ::__T_optimizer, ::__T_optimizer_state, ::Int64) where {__T_cache, __T_objective_function, __T_model, __T_parameters, __T_states, __T_optimizer, __T_optimizer_state}
   @ Lux ~/.julia/packages/ConcreteStructs/7Lv7u/src/ConcreteStructs.jl:142
  Lux.Training.TrainState(::AbstractLuxLayer, ::Any, ::Any, ::AbstractRule)
   @ Lux ~/.julia/packages/Lux/uEbqO/src/helpers/training.jl:122

Stacktrace:
 [1] top-level scope
   @ REPL[48]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
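The `MethodError` is plain Julia dispatch: the only user-facing 4-argument candidate listed above is `TrainState(::AbstractLuxLayer, ::Any, ::Any, ::AbstractRule)`, and a compiled `Reactant.Compiler.Thunk` is not a Lux layer. The sketch below checks those type relationships directly; it assumes the same setup as the `Dense(10, 10)` example quoted later in this thread and a working Reactant backend.

```julia
using Lux, Reactant, Random, Optimisers

rdev = reactant_device()
model = Dense(10, 10)
ps, st = Lux.setup(Random.default_rng(), model) |> rdev
x = rand(Float32, 10) |> rdev

# Compiling the forward pass returns a Reactant thunk, not a Lux layer.
model_compiled = @compile model(x, ps, st)

# The 4-argument constructor dispatches on these abstract types.
println(model isa Lux.AbstractLuxLayer)            # true: Dense is a Lux layer
println(model_compiled isa Lux.AbstractLuxLayer)   # false: it is a Reactant.Compiler.Thunk
println(Adam() isa Optimisers.AbstractRule)        # true: Adam is an Optimisers.jl rule
```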
The docs for `TrainState` say that this version is not stable and that I should use the version with the Optimisers.jl API instead. Where can I find that API? The tutorials, e.g. the MLP fitting one, still use this version.
Hi, you need to use the default Lux model (not the compiled one) when constructing the train state. The model will be compiled (and cached) the first time you call `single_train_step` (or `compute_gradients`). The hint in the `TrainState` docs refers to creating the `TrainState` through its default (struct) constructor, i.e. `TrainState(fields…)` (at least that is how I understand it). The constructor you’re using, `TrainState(model, ps, st, opt)`, is already the Optimisers.jl API.
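A minimal end-to-end sketch of that workflow, assuming Lux 1.x with Reactant, Enzyme, and Optimisers installed. The toy data, the `train` helper, the learning rate, and the epoch count are hypothetical choices for illustration only. The plain `Dense` layer goes into `TrainState`, and the first `single_train_step!` call triggers compilation.

```julia
using Lux, Reactant, Enzyme, Random, Optimisers

rdev = reactant_device()

# Toy setup (illustration only): one Dense layer and random data.
model = Dense(10, 10)
ps, st = Lux.setup(Random.default_rng(), model) |> rdev
x = rand(Float32, 10, 32) |> rdev   # 32 samples with 10 features each
y = rand(Float32, 10, 32) |> rdev   # random targets

# Hypothetical helper; wrapping the loop in a function avoids
# soft-scope issues with reassigning `train_state` at top level.
function train(model, ps, st, x, y; epochs = 100)
    # Pass the plain, uncompiled model; `Adam` is an Optimisers.jl rule.
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
    for epoch in 1:epochs
        # The first call compiles the training step with Reactant and caches
        # it inside `train_state`; later calls reuse the cached version.
        _, loss, _, train_state = Training.single_train_step!(
            AutoEnzyme(), MSELoss(), (x, y), train_state
        )
        epoch % 20 == 0 && println("epoch $epoch: loss = $loss")
    end
    return train_state
end

train_state = train(model, ps, st, x, y)
```

The trained parameters can then be read back from the returned state (for example `train_state.parameters`), since that is one of the fields listed in the struct constructor above.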
This seems to be a common mistake; starting with the next release of Lux, users will see the following error when attempting to pass a compiled function to `TrainState`:
julia> train_state = Training.TrainState(model_compiled, ps, st, Adam())
ERROR: ArgumentError: Invalid TrainState construction using a compiled function.
`TrainState` is being constructed with a reactant compiled function, i.e. a
`Reactant.Compiler.Thunk`. This is likely a mistake as the model should be
passed in directly without being compiled first.
This is likely originating from the following style of usage:
```julia
using Lux, Reactant, Random, Optimisers
rdev = reactant_device()
model = Dense(10, 10)
ps, st = Lux.setup(Random.default_rng(), model) |> rdev
x = rand(10) |> rdev
model_compiled = @compile model(x, ps, st)
train_state = Training.TrainState(model_compiled, ps, st, Adam())
```
Instead avoid compiling the model and pass it directly to `TrainState`. When
`single_train_step` or other function are called on the `TrainState`, the
model will be compiled automatically.
```julia
train_state = Training.TrainState(model, ps, st, Adam())
```
For end-to-end usage example refer to the documentation:
<https://lux.csail.mit.edu/stable/manual/compiling_lux_models#compile_lux_model_trainstate>
Stacktrace:
[1] Lux.Training.TrainState(::Reactant.Compiler.Thunk{…}, ps::@NamedTuple{…}, st::@NamedTuple{}, optimizer::Adam{…})
@ LuxReactantExt /mnt/software/lux/Lux.jl/ext/LuxReactantExt/training.jl:5
[2] top-level scope
@ REPL[42]:1
[3] top-level scope
@ none:1
Some type information was truncated. Use `show(err)` to see complete types.
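The message says compilation also happens when "other function[s]" are called on the `TrainState`. As a hedged sketch under the same assumptions as before (Lux 1.x with Reactant, Enzyme, and Optimisers installed; toy random data), here is the two-phase variant that computes gradients and applies them separately. The helper `train_two_phase` is hypothetical; when nothing needs to happen between the two phases, the fused `single_train_step!` is usually the simpler choice.

```julia
using Lux, Reactant, Enzyme, Random, Optimisers

rdev = reactant_device()
model = Dense(10, 10)
ps, st = Lux.setup(Random.default_rng(), model) |> rdev
x = rand(Float32, 10, 32) |> rdev
y = rand(Float32, 10, 32) |> rdev

# Hypothetical helper: same loop as before, split into two phases.
function train_two_phase(model, ps, st, x, y; epochs = 100)
    # Again, the uncompiled model goes into TrainState.
    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
    for _ in 1:epochs
        # Gradient computation is compiled and cached on the first call.
        grads, loss, _, train_state = Training.compute_gradients(
            AutoEnzyme(), MSELoss(), (x, y), train_state
        )
        # Apply the Optimisers.jl update to parameters and optimizer state.
        train_state = Training.apply_gradients!(train_state, grads)
    end
    return train_state
end

train_state = train_two_phase(model, ps, st, x, y)
```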
Awesome, thank you. One suggestion for the error message: put the correct code first. Something like:
When passing a model into `TrainState`, make sure it is not compiled. When
`single_train_step` or other functions are called on the `TrainState`, the
model will be compiled automatically.
```julia
using Lux, Reactant, Random, Optimisers
rdev = reactant_device()
model = Dense(10, 10)
ps, st = Lux.setup(Random.default_rng(), model) |> rdev
x = rand(10) |> rdev
train_state = Training.TrainState(model, ps, st, Adam())

# Do not pass a compiled model into Training.TrainState:
model_compiled = @compile model(x, ps, st)
train_state = Training.TrainState(model_compiled, ps, st, Adam())  # -> this throws an error
```