How do I retain trained model parameters when freezing layers with Lux.jl?

I am very interested in the example on the Freezing Model Parameters page of the LuxDL docs. However, in that example, the parameters are obtained from

ps, st = Lux.setup(rng, model_frozen)

so I cannot find a convenient way to keep the parameters of an existing model. Say I have a pre-trained model that is a Chain of nested Dense and GRU layers, with trained parameters and states ps_trained and st_trained. How can I apply the method from the docs while keeping ps_trained and st_trained?
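A minimal sketch of one way to keep trained values, assuming a recent Lux version where `Lux.Experimental.freeze(layer, ps, st)` takes an existing layer together with its parameters and states and returns the frozen layer plus new parameters and states. The small Dense chain is a hypothetical stand-in for the pre-trained Dense/GRU model, and `ps_trained`/`st_trained` come from `Lux.setup` here only so the snippet runs on its own.

```julia
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

# Hypothetical stand-in for a pre-trained model; in practice ps_trained and
# st_trained would come from training, not from Lux.setup
model = Chain(Dense(3 => 4, tanh), Dense(4 => 1))
ps_trained, st_trained = Lux.setup(rng, model)

# Freeze the first layer, passing in its *trained* parameters and states
# (assumes the Lux.Experimental.freeze(layer, ps, st) form is available)
frozen_1, ps_1, st_1 = Lux.Experimental.freeze(model.layers.layer_1,
    ps_trained.layer_1, st_trained.layer_1)

# Rebuild the chain; Chain names its layers layer_1, layer_2, ... by position,
# so the new ps/st NamedTuples must use the same keys
model_frozen = Chain(frozen_1, model.layers.layer_2)
ps_frozen = (layer_1 = ps_1, layer_2 = ps_trained.layer_2)
st_frozen = (layer_1 = st_1, layer_2 = st_trained.layer_2)

# Compare outputs of the original and the partially frozen model
x = randn(rng, Float32, 3, 2)
y_old, _ = model(x, ps_trained, st_trained)
y_new, _ = model_frozen(x, ps_frozen, st_frozen)
```

Because the frozen layer carries the trained values, `y_new` should match `y_old`; an optimizer working on `ps_frozen` would then only update the second layer.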

Check the sections just before that one. They show how to use the features on the Experimental Features page of the LuxDL docs to accomplish exactly what you asked for.
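Since the layer-mapping approach takes `ps` and `st` as inputs instead of calling `Lux.setup`, it can be applied directly to pre-trained parameters. A sketch, assuming the function form `Lux.Experimental.layer_map` found in recent Lux versions (the thread uses the older `@layer_map` macro, whose callback received the layer name as a `String`). The nested Dense chain is hypothetical and stands in for the pre-trained model, and `freeze_dense` is an illustrative helper name.

```julia
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

# Hypothetical nested model standing in for the pre-trained network
model = Chain(Dense(3 => 4, tanh), Chain(Dense(4 => 4, tanh), Dense(4 => 4)), Dense(4 => 1))
ps_trained, st_trained = Lux.setup(rng, model)

# Freeze weight and bias of every Dense layer; leave any other layer unchanged.
# The 4th argument is the layer's location in the model (its exact type depends
# on the Lux version), so it is ignored here.
freeze_dense(d::Dense, ps, st, _) = Lux.Experimental.freeze(d, ps, st, (:weight, :bias))
freeze_dense(l, ps, st, _) = (l, ps, st)

# layer_map walks the model together with the trained ps/st, so the frozen
# layers keep the trained values
model_frozen, ps_frozen, st_frozen = Lux.Experimental.layer_map(
    freeze_dense, model, ps_trained, st_trained)

x = randn(rng, Float32, 3, 2)
y_old, _ = model(x, ps_trained, st_trained)
y_new, _ = model_frozen(x, ps_frozen, st_frozen)
```

If the trained values were carried over, `y_new` should equal `y_old`. Dispatching on the layer type keeps the freezing rule in one place, which scales better to deeply nested chains than rebuilding each level by hand.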

Yes, I have tried @layer_map before. However, it does not seem to work with a Recurrence(cell = GRUCell(32 => 32)) layer. The error message is:

ERROR: AssertionError: fieldnames(typeof(st_c)) == fieldnames(typeof(ps_c))
Stacktrace:
 [1] layer_map(f::typeof(freeze_except_last_layer), l::Recurrence{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, name::String)
   @ Lux.Experimental ~/.julia/packages/Lux/Al3Ab/src/contrib/map.jl:96
 [2] layer_map(f::typeof(freeze_except_last_layer), l::@NamedTuple{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, name::String)
   @ Lux.Experimental ~/.julia/packages/Lux/Al3Ab/src/contrib/map.jl:108
 [3] layer_map(f::typeof(freeze_except_last_layer), l::Lux.Chain{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, name::String)
   @ Lux.Experimental ~/.julia/packages/Lux/Al3Ab/src/contrib/map.jl:108
 [4] layer_map(f::typeof(freeze_except_last_layer), l::SkipConnection{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, name::String)
   @ Lux.Experimental ~/.julia/packages/Lux/Al3Ab/src/contrib/map.jl:108
 [5] layer_map(f::typeof(freeze_except_last_layer), l::@NamedTuple{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, name::String)
   @ Lux.Experimental ~/.julia/packages/Lux/Al3Ab/src/contrib/map.jl:108
 [6] layer_map(f::typeof(freeze_except_last_layer), l::Lux.Chain{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, name::String)
   @ Lux.Experimental ~/.julia/packages/Lux/Al3Ab/src/contrib/map.jl:108
 [7] top-level scope
   @ ~/.julia/packages/Lux/Al3Ab/src/contrib/map.jl:38
Some type information was truncated. Use `show(err)` to see complete types.

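I believe this is a bug in @layer_map: the assertion requires a layer's states and parameters to have the same field names, and it fails when the mapping reaches the Recurrence layer.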

Can you open an issue with a reproducer?