Is there a systematic approach to obtain a modified version of an existing Lux model?
The documentation of the Lux layer interface is clear: “once constructed a model architecture cannot change.” However, we may want to replace or modify parts of an existing, possibly pretrained Lux model.
An example of replacement could be using the backbone of a vision model as a feature extractor while replacing the classifier layers to address a different task. An example of modification could be this proposed solution to adapt a vision model trained on RGB images to work on grayscale images (adapting both the architecture and the pretrained parameters of the input layer).
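For the replacement case, the best I have come up with is to reuse the trained layers directly when constructing a new Chain. The sketch below uses a tiny two-layer Chain as a stand-in for a backbone plus classifier (the layer_1/layer_2 names are the defaults Lux.Chain assigns); I am not sure this is the intended idiom:

```julia
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

# A small stand-in for a pretrained model: "backbone" + "classifier" head.
model = Lux.Chain(Lux.Dense(4, 3), Lux.Dense(3, 1))
ps, st = Lux.setup(rng, model)

# Head swap: reuse the backbone layer (and its parameters/state) unchanged,
# and attach a freshly initialized head for the new task.
new_head = Lux.Dense(3, 6)
model_prime = Lux.Chain(model.layers.layer_1, new_head)
ps_head, st_head = Lux.setup(rng, new_head)
ps_prime = (layer_1 = ps.layer_1, layer_2 = ps_head)
st_prime = (layer_1 = st.layer_1, layer_2 = st_head)

x = randn(rng, Float32, 4, 1)
y, _ = model_prime(x, ps_prime, st_prime)
@show size(y)
```

This keeps the backbone's pretrained parameters untouched while letting Lux.setup handle initialization of the new head, but it requires manually reassembling the parameter NamedTuple, which again seems fragile for deep models.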
I see that I can manually create a new Lux model with the required changes. In the minimal working example below, model_prime is like model, but with the second layer changed (both architecture and parameters):
using Lux
using Random
rng = Random.default_rng()
Random.seed!(rng, 0)
model = Lux.Chain(
    Lux.Dense(4, 3),
    Lux.Dense(3, 1)
)
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 4, 1);
@show size(model(x, ps, st)[1])
model_prime = Lux.Chain(
    Lux.Dense(4, 3),
    Lux.Dense(3, 6)
)
ps_prime = merge(
    ps,
    (layer_2 = (weight = randn(rng, Float32, 6, 3), bias = randn(rng, Float32, 6, 1)),)
)
@show size(model_prime(x, ps_prime, st)[1])
However, this seems inconvenient and even error-prone as soon as the model becomes more complex. A more systematic approach, I thought, would be to copy the model and modify only the layers of interest. Yet in my minimal working example the following code
model_copy = deepcopy(model)
model_copy.layers.layer_2 = Lux.Dense(3, 6)
throws an ERROR: setfield!: immutable struct of type NamedTuple cannot be changed.
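Since the layer tree is immutable, one workaround I can imagine is to rebuild the container functionally instead of mutating it: merge a replacement into model.layers (a NamedTuple) and construct a fresh Chain from the result. This is only a sketch using the public Chain constructor and Lux.setup; I do not know whether it is idiomatic:

```julia
using Lux, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

model = Lux.Chain(Lux.Dense(4, 3), Lux.Dense(3, 1))
ps, st = Lux.setup(rng, model)

# Rebuild the Chain with the layer of interest swapped out, rather than
# trying to mutate the immutable NamedTuple of layers in place.
new_layer = Lux.Dense(3, 6)
new_layers = merge(model.layers, (layer_2 = new_layer,))
model_prime = Lux.Chain(values(new_layers)...)

# Let Lux initialize the replacement's parameters and state, then splice
# them into the existing (pretrained) parameter/state trees.
ps_new, st_new = Lux.setup(rng, new_layer)
ps_prime = merge(ps, (layer_2 = ps_new,))
st_prime = merge(st, (layer_2 = st_new,))

x = randn(rng, Float32, 4, 1)
y, _ = model_prime(x, ps_prime, st_prime)
@show size(y)
```

Even if this works, it only handles a top-level Chain; for a layer nested deep inside a larger model I would still need to repeat the merge at every level, which is what I hoped a systematic mechanism would avoid.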
I would appreciate advice. Should I investigate how Lux models are defined under the hood to find a mechanism? Is there an obvious way to achieve this goal?