Suppose I have a model in Flux built from a stack of layers combined with Flux.Chain. The training algorithm I have trains the layers individually. Would the following (simplified) code work?
layers = [Vector of Layers]
model = Flux.Chain(layers)
for layer in layers
    gs = Flux.gradient(Flux.params(layer)) do
        some_loss_fn(some_inp, some_out)
    end
    Flux.update!(optimizer, Flux.params(layer), gs)
end
I think that is pretty close, but you probably want to work with the model rather than the array of layers. I modified your code into a working example. Since I don’t know how you perform inference or what your layerwise losses look like, I just assumed some dummy targets and losses.
using Flux

nodes = [16, 14, 13, 15, 4]
layers = [Dense(nodes[i], nodes[i+1]) for i in 1:length(nodes)-1]
model = Chain(layers...)  # splat the layers
opt = Descent(0.01)
batchsize = 64

#= Generate a batch of dummy data. z[1] is the input;
   z[i+1] is the layerwise target for layer i. =#
z = [rand(Float32, n, batchsize) for n in nodes]
myloss(x, y) = sum(abs2, x - y)

function train(model, z, opt, loss)
    for (i, layer) in enumerate(model)
        gs = Flux.gradient(Flux.params(layer)) do
            y = layer(z[i])
            loss(y, z[i+1])
        end
        Flux.update!(opt, Flux.params(layer), gs)
    end
end

train(model, z, opt, myloss)
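As a quick sanity check, you can record each layer's loss before and after a few epochs and confirm it drops. This is just a sketch using the same dummy setup and the same implicit-parameter (Flux.params) API as above; the smaller learning rate of 1e-3 is my own choice, since plain Descent on an unnormalized sum-of-squares loss over a batch of 64 can be unstable at larger steps:

using Flux

nodes = [16, 14, 13, 15, 4]
layers = [Dense(nodes[i], nodes[i+1]) for i in 1:length(nodes)-1]
model = Chain(layers...)
opt = Descent(1e-3)   # deliberately small step for the unnormalized sum-loss
batchsize = 64
z = [rand(Float32, n, batchsize) for n in nodes]
myloss(x, y) = sum(abs2, x - y)

# Same layerwise training loop as above: each layer maps its own
# input z[i] toward its own target z[i+1].
function train(model, z, opt, loss)
    for (i, layer) in enumerate(model)
        gs = Flux.gradient(Flux.params(layer)) do
            loss(layer(z[i]), z[i+1])
        end
        Flux.update!(opt, Flux.params(layer), gs)
    end
end

# Loss of layer i on its own (input, target) pair
layerloss(i) = myloss(model[i](z[i]), z[i+1])

before = [layerloss(i) for i in 1:length(nodes)-1]
for _ in 1:50
    train(model, z, opt, myloss)
end
after = [layerloss(i) for i in 1:length(nodes)-1]

println("before: ", before)
println("after:  ", after)
# each entry of `after` should be smaller than the matching `before`

Because every layer trains against a fixed target z[i+1] rather than the downstream loss, each layer's problem is an independent least-squares fit, so all four losses should fall independently of one another.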