`apply` no longer works with StatefulLuxLayer (since Lux v1.19)

```julia
using Lux, Random

nn = Lux.Dense(10, 10, tanh)

ps, st = Lux.setup(Xoshiro(2024), nn)

nn_st = Lux.StatefulLuxLayer{true}(nn, nothing, st, nothing)

input_data = rand(10)
Lux.apply(nn_st, input_data, ps) # fails on Lux v1.21.0 but worked on Lux v1.18.0
# ERROR: MethodError: no method matching apply(::StatefulLuxLayer{…}, ::Vector{…}, ::@NamedTuple{…})
# The function `apply` exists, but no method is defined for this combination of argument types.

# Closest candidates are:
#   apply(::AbstractLuxLayer, ::Any, ::Any, ::Any)

nn_st(input_data, ps) # works in both cases
```

Since v1.19, StatefulLuxLayer has been moved to LuxCore, and the documentation says it is not an AbstractLuxLayer.
Since that update, `apply` no longer works. Only `nn_st(x, ps)` works, but I thought using `apply` was preferred (see docs).

Since I just want to use StatefulLuxLayer to keep the state while changing parameters, `Lux.StatefulLuxLayer{true}(nn, nothing, st, nothing)` with `apply` seemed like the best option.

What is the best way to do that after Lux v1.19?

BTW, I don't know what the second `nothing` in the StatefulLuxLayer definition refers to (the code comes from one of my former students) :roll_eyes:.

Removing the `apply` dispatch was accidental. I will patch it in the next release, in a few hours (fix: accidental apply dispatch removal by avik-pal · Pull Request #1476 · LuxDL/Lux.jl · GitHub). That dispatch was never meant to be used publicly, though :sweat_smile:; `nn_st(input_data)` is the correct call.
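For reference, a minimal sketch of the recommended call, assuming Lux v1's three-argument constructor `StatefulLuxLayer{true}(model, ps, st)`. As far as I can tell, the second `nothing` in the question's four-argument form fills an internal slot that is only used when the state type is not fixed (`{false}`), and the three-argument form fills it for you; treat that reading as an assumption, not documented API.

```julia
using Lux, Random

nn = Dense(10, 10, tanh)
ps, st = Lux.setup(Xoshiro(2024), nn)

# Three-argument form: model, parameters, states (state type fixed = true).
# Assumption: this fills the internal fourth slot with `nothing` for us.
nn_st = Lux.StatefulLuxLayer{true}(nn, ps, st)

x = rand(Float32, 10)
y = nn_st(x)        # uses the stored parameters; the state is kept inside nn_st
y2 = nn_st(x, ps)   # same call with the parameters passed explicitly
```

Calling the layer returns only the output, not an `(output, state)` tuple as the plain `Lux.apply(nn, x, ps, st)` does.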

Try this:

```julia
using Setfield, Lux

@set! nn_st.ps = ps_new
```
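Spelled out a little more, assuming Setfield.jl is installed next to Lux: `@set!` builds a copy of the layer with its `ps` field replaced and rebinds `nn_st` to that copy, so it should also work when the layer was created with `nothing` as its parameters, as in the question. `ps_new` is a hypothetical second parameter set drawn from a different seed.

```julia
using Lux, Random, Setfield

nn = Dense(10, 10, tanh)
_, st = Lux.setup(Xoshiro(2024), nn)

# Layer created without parameters, as in the original question
nn_st = Lux.StatefulLuxLayer{true}(nn, nothing, st, nothing)

# Hypothetical new parameter set (e.g. another initialization or an optimizer step)
ps_new, _ = Lux.setup(Xoshiro(7), nn)

# Rebind nn_st to a copy whose `ps` field holds ps_new
@set! nn_st.ps = ps_new

x = rand(Float32, 10)
y = nn_st(x)   # runs the wrapped Dense layer with ps_new
```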

Thanks!
So, looking at the PR's code change, it seems that

```julia
apply(nn_st, x, ps_new)
# or
nn_st(x, ps_new)
# or
@set! nn_st.ps = ps_new # and
nn_st(x)
```

all do exactly the same thing, so I can pick whichever syntax I prefer, and there is no penalty for changing parameters one way or the other (correct?)

Correct, there are no penalties.

`nn_st(x, ps_new)` comes from SciML land and is the most widely used form, so I would recommend it, but all the others are equivalent.
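One way to check the equivalence empirically, assuming Setfield.jl and a plain `Dense` layer, which has no state to update. It uses the non-mutating `@set`, so `nn_st` itself keeps its original parameters; `ps_new` is a hypothetical parameter set made by scaling the originals. The `apply(nn_st, ...)` form is left out because it depends on the patched release.

```julia
using Lux, Random, Setfield

nn = Dense(10, 10, tanh)
ps, st = Lux.setup(Xoshiro(2024), nn)
nn_st = Lux.StatefulLuxLayer{true}(nn, ps, st)

# Hypothetical "new" parameters: the original arrays scaled by 2
ps_new = (weight = 2 .* ps.weight, bias = 2 .* ps.bias)

x = rand(Float32, 10)

y_explicit = nn_st(x, ps_new)              # parameters passed at call time
nn_st_new = @set nn_st.ps = ps_new         # copy of the layer with swapped parameters
y_swapped = nn_st_new(x)
y_plain, _ = Lux.apply(nn, x, ps_new, st)  # reference: the wrapped layer called directly

@assert y_explicit ≈ y_swapped
@assert y_explicit ≈ y_plain
```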

I just hit another bug introduced by the update.
This feels more like a breaking change, so it might be expected, but in case it isn't, I wanted to share it.

I had saved some models:

```julia
julia> typeof(nn_st_t)
StatefulLuxLayer{Static.True, CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Serialization.__deserialized_types__.var"#5#6", Nothing, @NamedTuple{}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}, Nothing, @NamedTuple{}}
```

In that case, `nn_st_t(x)` does not work.

Apparently `fixed_state_type` must now be a `Val{true}()` instead of a `Static.True`, so I set it with

```julia
julia> @set! nn_st_t.fixed_state_type = Val{true}()
julia> typeof(nn_st_t)
StatefulLuxLayer{Val{true}, CompactLuxLayer{:₋₋₋no_special_dispatch₋₋₋, Serialization.__deserialized_types__.var"#5#6", Nothing, @NamedTuple{}, Lux.CompactMacroImpl.ValueStorage{@NamedTuple{}, @NamedTuple{}}, Tuple{Tuple{}, Tuple{}}}, Nothing, @NamedTuple{}}
```

which solves the issue.

Did you serialize the StatefulLuxLayer? In general, models are not guaranteed to be stable under serialization. You should save the parameters and states instead, which are guaranteed not to change within the same major version (similar to PyTorch's `state_dict()`). See Training a Simple LSTM | Lux.jl Docs.
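A minimal sketch of that advice: save only the parameters and states, and rebuild the model from code when loading. The linked tutorial saves with JLD2; this sketch uses the Serialization stdlib only to stay dependency-free, and that format is itself not guaranteed to be stable across Julia versions. `build_model` and the file name are hypothetical.

```julia
using Lux, Random, Serialization

# Hypothetical helper: the architecture lives in code, not in the saved file
build_model() = Dense(10, 10, tanh)

nn = build_model()
ps, st = Lux.setup(Xoshiro(2024), nn)

# Save only the parameters and states (hypothetical file name in a temp dir)
path = joinpath(mktempdir(), "model_ps_st.jls")
serialize(path, (; ps, st))

# Later, e.g. after a Lux update within the same major version:
saved = deserialize(path)
nn_loaded = build_model()
nn_st = Lux.StatefulLuxLayer{true}(nn_loaded, saved.ps, saved.st)

x = rand(Float32, 10)
y = nn_st(x)
```

Because the wrapper (and, for `@compact` models, the anonymous function type) is recreated by the current code, changes to internal fields such as `fixed_state_type` do not affect the saved data.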