Modify existing Lux models

Is there a systematic approach to obtain a modified version of an existing Lux model?

The documentation of the Lux layer interface is clear: “once constructed a model architecture cannot change.” However, we may want to replace or modify parts of an existing, possibly pretrained Lux model.

An example of replacement could be using the backbone of a vision model as a feature extractor while replacing the classifier layers to address a different task. An example of modification could be this proposed solution to adapt a vision model trained on RGB images to work on grayscale images (adapting both the architecture and the pretrained parameters of the input layer).
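
For concreteness, the parameter part of that grayscale trick boils down to something like the following sketch (my own illustration, assuming pretrained parameters ps_vision whose first layer is a Conv with Lux's (k, k, in_channels, out_channels) weight layout):

w_rgb = ps_vision.layer_1.weight   # size (k, k, 3, out)
w_gray = sum(w_rgb; dims=3)        # size (k, k, 1, out): collapse the RGB filters into one channel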

I see that I can manually create a new Lux model with the required changes. In the minimal working example below, model_prime is like model but with the second layer changed (both architecture and parameters):

using Lux
using Random

rng = Random.default_rng()
Random.seed!(rng, 0)

model = Lux.Chain(
    Lux.Dense(4, 3),
    Lux.Dense(3, 1)
)

ps, st = Lux.setup(rng, model)

x = randn(rng, Float32, 4, 1);
@show size(model(x, ps, st)[1])

# Like model, but with a different second layer
model_prime = Lux.Chain(
    Lux.Dense(4, 3),
    Lux.Dense(3, 6)
)

# Reuse the existing parameters, overriding only layer_2 with freshly
# initialized values of the matching shapes
ps_prime = merge(
    ps,
    (layer_2 = (weight = randn(rng, Float32, 6, 3), bias = randn(rng, Float32, 6, 1)),)
)

@show size(model_prime(x, ps_prime, st)[1])

However, this becomes inconvenient and error-prone as soon as the model grows more complex. A more systematic approach would be to copy the model and modify only the layers of interest. Yet in my minimal working example the following code

model_copy = deepcopy(model)
model_copy.layers.layer_2 = Lux.Dense(3, 6)

throws ERROR: setfield!: immutable struct of type NamedTuple cannot be changed, because a Chain stores its layers in an immutable NamedTuple.

I would appreciate advice. Should I investigate how Lux models are defined under the hood to find a mechanism? Is there an obvious way to achieve this goal?

Use Setfield.jl

using Setfield
@set! model.layers.layer_2 = Dense(3 => 6)

The same thing can be done for the parameters.
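
For example (a sketch, assuming the parameter layout from the MWE above, where layer_2 must now hold the weight and bias of the new Dense(3 => 6)):

@set! ps.layer_2 = (weight = randn(rng, Float32, 6, 3),
                    bias = randn(rng, Float32, 6, 1))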

Hi @avikpal, I think my question is relevant to this topic, so I’d like to ask it here to see whether what I’ve done is reasonable. My aim is to implement an RNN with return_sequence=true whose input and output are preprocessed and post-processed by Dense layers, like the simple model below:

# GRU_OUT and DENSE_ACT are global defaults defined elsewhere in my code
function GRU_Seq2Seq(nfeatures; gru_out_dim=GRU_OUT, dense_act=DENSE_ACT)
    dense_in1 = Lux.Dense(nfeatures=>gru_out_dim*2, dense_act)
    drop1 = Lux.Dropout(0.5)
    dense_in2 = Lux.Dense(gru_out_dim*2=>gru_out_dim, dense_act)
    gru = Lux.Recurrence(Lux.GRUCell(gru_out_dim=>gru_out_dim); return_sequence=true)
    stack_layer = StackLayer(2)
    dense_out1 = Lux.Dense(gru_out_dim=>16, dense_act)
    drop2 = Lux.Dropout(0.2)
    dense_out2 = Lux.Dense(16=>1, identity)
    return Lux.Chain(dense_in1, drop1, dense_in2, gru, stack_layer, dense_out1, drop2, dense_out2)
end

where the StackLayer is defined to make the RNN output suitable for Dense layers:

# Stacks the Vector of per-step outputs returned by Recurrence into a
# single array along dimension `dims`
struct StackLayer <: Lux.AbstractExplicitLayer
    dims::Int
end

@inline function (s::StackLayer)(x::AbstractArray, ps, st::NamedTuple)
    return stack(x; dims=s.dims), st
end
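
For reference, here is a quick sketch of the shapes involved (Recurrence with return_sequence=true returns a Vector of per-step outputs; the sizes below are made up for illustration):

h = [randn(Float32, 8, 4) for _ in 1:10]             # 10 steps of (features × batch)
y, _ = StackLayer(2)(h, NamedTuple(), NamedTuple())  # stack along dims=2
size(y)                                              # (8, 10, 4): features × seqlen × batch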

I’d like to know whether there is an idiomatic way to define such a Seq2Seq RNN model without hacking together my own StackLayer. Am I missing something in the Lux.jl docs? Post-processing the RNN output with Dense layers feels like a natural thing to want.
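
For example, would wrapping the stacking step in Lux.WrappedFunction be the intended idiom? A sketch (assuming WrappedFunction accepts an arbitrary function here):

stack_layer = Lux.WrappedFunction(x -> stack(x; dims=2))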

Moreover, to test the performance of my seq2seq model, I have also implemented a seq2last model, which returns only the last value:

function GRU_Seq2Last(nfeatures; gru_out_dim=GRU_OUT, dense_act=DENSE_ACT)
    dense_in1 = Lux.Dense(nfeatures=>gru_out_dim*2, dense_act)
    drop1 = Lux.Dropout(0.5)
    dense_in2 = Lux.Dense(gru_out_dim*2=>gru_out_dim, dense_act)
    gru = Lux.Recurrence(Lux.GRUCell(gru_out_dim=>gru_out_dim))  # returns only the last output (return_sequence=false is the default)
    dense_out1 = Lux.Dense(gru_out_dim=>16, dense_act)
    drop2 = Lux.Dropout(0.2)
    dense_out2 = Lux.Dense(16=>1, identity)
    return Lux.Chain(dense_in1, drop1, dense_in2, gru, dense_out1, drop2, dense_out2)
end

Now that GRU_Seq2Last is trained, I’d like to use it to evaluate the test data but return the full sequence instead of just the last value used during training. Currently, I use a really ugly hack to transform the GRU_Seq2Last parameters and states so they can be used by GRU_Seq2Seq, and let that model predict. The code is:

# Remap GRU_Seq2Last parameters/states to the GRU_Seq2Seq layout:
# layers 5-7 shift to slots 6-8, and slot 5 (the StackLayer) gets an
# empty NamedTuple since it carries no parameters or state
function seq2last_to_seq2seq(ps, st)
    ps_out = (layer_1=ps.layer_1, layer_2=ps.layer_2,
              layer_3=ps.layer_3, layer_4=ps.layer_4,
              layer_5=NamedTuple(),
              layer_6=ps.layer_5, layer_7=ps.layer_6,
              layer_8=ps.layer_7)
    st_out = (layer_1=st.layer_1, layer_2=st.layer_2,
              layer_3=st.layer_3, layer_4=st.layer_4,
              layer_5=NamedTuple(),
              layer_6=st.layer_5, layer_7=st.layer_6,
              layer_8=st.layer_7)
    return ps_out, st_out
end

# ps_seq2last, st_seq2last were obtained by training GRU_Seq2Last
ps_seq2seq, st_seq2seq = seq2last_to_seq2seq(ps_seq2last, st_seq2last)
yp_seq2seq = GRU_Seq2Seq(40)(X, ps_seq2seq, st_seq2seq)

Setfield.jl doesn’t seem to help here, since I want to insert a layer (the StackLayer) rather than replace one.
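
For reference, the least manual alternative I could come up with is to build the renumbered NamedTuples programmatically (my own sketch, assuming Lux’s layer_i naming and that the inserted layer carries no parameters or state):

# Hypothetical helper: splice an empty NamedTuple into position `at`
# and renumber the layer_i keys accordingly
function insert_empty_layer(nt::NamedTuple, at::Int)
    vals = collect(Any, values(nt))
    insert!(vals, at, NamedTuple())
    names = ntuple(i -> Symbol(:layer_, i), length(vals))
    return NamedTuple{names}(Tuple(vals))
end

ps_seq2seq = insert_empty_layer(ps_seq2last, 5)
st_seq2seq = insert_empty_layer(st_seq2last, 5)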

Thank you, @avikpal, this solves my question. Let me add, though, that after reading the README of Setfield.jl, I tried your proposed solution with the successor package Accessors.jl rather than with Setfield.jl itself.

Great! Both should work, but I only run tests with Setfield, so that’s what I recommended.
