I am very interested in the example given in the docs page Freezing Model Parameters | LuxDL Docs. However, in that example, the parameters are obtained from `ps, st = Lux.setup(rng, model_frozen)`, which initializes them from scratch.
I can't find a convenient way to keep the parameters of an existing model. Say I have a pre-trained model that is a `Chain` of nested `Dense` and `GRU` layers, with trained parameters `ps_trained` and states `st_trained`. How can I apply the method from the docs while reusing `ps_trained` and `st_trained`?
avikpal
October 30, 2023, 10:43pm
Check the sections just before that. They show how to use the features described in Experimental Features | LuxDL Docs to do exactly what you asked for.
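A minimal sketch of that approach, assuming the `Lux.Experimental` API of the Lux version used in this thread (where `layer_map` passes each layer's name as a `String`). `@layer_map` walks the existing model and hands each layer its own slice of the trained parameters and states, so nothing is re-initialized. The small `Chain` below is only a stand-in for the real pre-trained model, and `freeze_all_but_last` is a hypothetical helper.

```julia
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

# Stand-in for the pre-trained model. In practice `ps_trained` and
# `st_trained` would come from training, not from `Lux.setup`.
model = Chain(Dense(3 => 4, tanh), Dense(4 => 4, tanh), Dense(4 => 1))
ps_trained, st_trained = Lux.setup(rng, model)

# Hypothetical helper: freeze the weight and bias of every Dense layer except
# the last one. `layer_map` calls it once per layer with that layer's own
# parameters, states and dotted name (e.g. "model.layers.layer_1").
function freeze_all_but_last(l, ps, st, name::String)
    if l isa Dense && name != "model.layers.layer_3"
        return Lux.Experimental.freeze(l, ps, st, (:weight, :bias))
    end
    return l, ps, st  # leave every other layer unchanged
end

# The macro uses the variable name `model` as the root of the layer names.
model_frozen, ps_frozen, st_frozen =
    Lux.Experimental.@layer_map freeze_all_but_last model ps_trained st_trained

x = randn(rng, Float32, 3, 2)
y, _ = model_frozen(x, ps_frozen, st_frozen)
```

As I understand `Lux.Experimental.freeze`, the frozen values move out of `ps_frozen` and travel with `st_frozen`, so both must be passed to `model_frozen`, and an optimizer set up on `ps_frozen` would only update the last layer.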
Yeah, I have tried `@layer_map` before. However, it does not seem to work with the `Recurrence(cell = GRUCell(32 => 32))` layer. The error message is:
```
ERROR: AssertionError: fieldnames(typeof(st_c)) == fieldnames(typeof(ps_c))
Stacktrace:
 [1] layer_map(f::typeof(freeze_except_last_layer), l::Recurrence{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, name::String)
   @ Lux.Experimental ~/.julia/packages/Lux/Al3Ab/src/contrib/map.jl:96
 [2] layer_map(f::typeof(freeze_except_last_layer), l::@NamedTuple{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, name::String)
   @ Lux.Experimental ~/.julia/packages/Lux/Al3Ab/src/contrib/map.jl:108
 [3] layer_map(f::typeof(freeze_except_last_layer), l::Lux.Chain{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, name::String)
   @ Lux.Experimental ~/.julia/packages/Lux/Al3Ab/src/contrib/map.jl:108
 [4] layer_map(f::typeof(freeze_except_last_layer), l::SkipConnection{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, name::String)
   @ Lux.Experimental ~/.julia/packages/Lux/Al3Ab/src/contrib/map.jl:108
 [5] layer_map(f::typeof(freeze_except_last_layer), l::@NamedTuple{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, name::String)
   @ Lux.Experimental ~/.julia/packages/Lux/Al3Ab/src/contrib/map.jl:108
 [6] layer_map(f::typeof(freeze_except_last_layer), l::Lux.Chain{…}, ps::@NamedTuple{…}, st::@NamedTuple{…}, name::String)
   @ Lux.Experimental ~/.julia/packages/Lux/Al3Ab/src/contrib/map.jl:108
 [7] top-level scope
   @ ~/.julia/packages/Lux/Al3Ab/src/contrib/map.jl:38
Some type information was truncated. Use `show(err)` to see complete types.
```
I believe there is a bug in `@layer_map`.
Can you open an issue with a reproducer?
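A reproducer could look something like the sketch below. It is an untested guess that a single `Recurrence`-wrapped `GRUCell` inside a `Chain` is enough to hit the same assertion; the 32-unit size comes from the post above, while the `Dense` head and the `freeze_dense` helper are hypothetical. It assumes the same Lux version as the stack trace, where layer names are passed as `String`.

```julia
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

# Hypothetical minimal model: a recurrent GRU layer followed by a Dense head.
model = Chain(Recurrence(GRUCell(32 => 32)), Dense(32 => 1))
ps, st = Lux.setup(rng, model)

# Freeze every Dense layer and leave all other layers untouched.
freeze_dense(l::Dense, ps, st, name::String) =
    Lux.Experimental.freeze(l, ps, st, (:weight, :bias))
freeze_dense(l, ps, st, name::String) = (l, ps, st)

# Suspected (not verified) to throw
#   AssertionError: fieldnames(typeof(st_c)) == fieldnames(typeof(ps_c))
# once layer_map reaches the Recurrence layer; the error is printed, not rethrown.
try
    Lux.Experimental.@layer_map freeze_dense model ps st
    println("layer_map succeeded")
catch err
    showerror(stdout, err)
    println()
end
```

Keeping the model this small makes it easier to tell whether the failure comes from `Recurrence` itself rather than from the surrounding `SkipConnection` and nested `Chain` layers in the original model.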