Training layers of a Flux model separately

Suppose I have a model in Flux consisting of a stack of layers combined with Flux.Chain. The training algorithm I have trains the layers individually.
Would the following (simplified) code work?

layers = [Vector of Layers]
model = Flux.Chain(layers)
for layer in layers
    gs = Flux.gradient(Flux.params(layer)) do
        some_loss_fn(some_inp, some_out)
    end
    Flux.update!(optimizer, Flux.params(layer), gs)
end

I think that is pretty close 🙂 But you probably want to work with the model rather than the array of layers. I modified your code to make a working example. Since I don’t know how you perform inference or what your layerwise losses look like, I just assumed some dummy targets and losses.

using Flux

nodes = [16, 14, 13, 15, 4]
layers = [Dense(nodes[i], nodes[i+1]) for i = 1:length(nodes)-1]
model = Chain(layers...) # splat the layers
opt = Descent(0.01)
batchsize = 64

#= Generate a batch of dummy data: z[1] is the input, z[i] for i > 1 are layerwise targets =#
z = [rand(Float32, n, batchsize) for n in nodes]

myloss(x, y) = sum(abs2, x-y)

function train(model, z, opt, loss)
    for (i, layer) in enumerate(model)
        gs = Flux.gradient(Flux.params(layer)) do
            y = layer(z[i])
            loss(y, z[i+1])
        end
        Flux.update!(opt, Flux.params(layer), gs)
    end
end

train(model, z, opt, myloss)
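If you want to see it working, here is a small sanity check, assuming the definitions above (total_loss is just an illustrative helper I made up, not part of Flux): since each layer is fitted against a fixed target z[i+1] from a fixed input z[i], repeating train on the same dummy batch should drive the summed layerwise loss down.

```julia
# Illustrative helper (not part of Flux): the summed layerwise loss,
# i.e. the same quantities train optimises layer by layer.
total_loss(model, z, loss) =
    sum(loss(layer(z[i]), z[i+1]) for (i, layer) in enumerate(model))

@show total_loss(model, z, myloss)
for _ in 1:10    # a few passes over the same dummy batch
    train(model, z, opt, myloss)
end
@show total_loss(model, z, myloss)  # should typically have dropped;
                                    # if it oscillates, shrink Descent's step size
```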