How to use Flux to write a Multi-Head Output Network

What AI doesn’t seem to know is that you need to tell Flux to look inside your struct for parameters, by marking it as a layer with the Flux.@layer macro. Otherwise Flux.setup finds nothing to train:

julia> Flux.setup(Adam(), PolicyNetwork(2, 3, 4))
┌ Warning: setup found no trainable parameters in this model
└ @ Optimisers ~/.julia/packages/Optimisers/yDIWk/src/interface.jl:32
()

julia> Flux.@layer PolicyNetwork  # Defines methods for functor, show

julia> Flux.setup(Adam(), PolicyNetwork(2, 3, 4))  # now this sees parameters
(hidden1 = (layers = ((weight = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0; 0.0 0.0;

After this, PolicyNetwork(2, 3, 4) |> gpu will move all the parameters, so your second definition should not be needed.

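Note also that adding type parameters might be a good idea, for performance. If the fields are left untyped they are stored as Any, so Julia cannot infer what each sub-layer returns: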

struct PolicyNetwork{A,B,C,D}
    hidden1::A
    hidden2::B
    mu::C
    std::D
end
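
Putting the pieces together, here is a minimal sketch of how the whole thing might look. The constructor and the forward pass are my guesses, since your original definitions are not shown in this post: I am assuming the three arguments are input size, hidden width and output size, that each field is a plain Dense layer, and that softplus is used to keep the std head non-negative. Swap in whatever your real layers are.

```julia
using Flux

struct PolicyNetwork{A,B,C,D}
    hidden1::A
    hidden2::B
    mu::C
    std::D
end

# Tell Flux to look inside the struct for parameters
Flux.@layer PolicyNetwork

# Hypothetical constructor: (input size, hidden width, output size) is an assumption
function PolicyNetwork(n_in::Int, n_hidden::Int, n_out::Int)
    PolicyNetwork(
        Dense(n_in => n_hidden, relu),
        Dense(n_hidden => n_hidden, relu),
        Dense(n_hidden => n_out),            # mean head
        Dense(n_hidden => n_out, softplus),  # std head, softplus is an assumed choice
    )
end

# Forward pass: shared trunk, then one output per head
function (m::PolicyNetwork)(x)
    h = m.hidden2(m.hidden1(x))
    return m.mu(h), m.std(h)
end

model = PolicyNetwork(2, 3, 4)
x = rand(Float32, 2, 5)               # 2 features, batch of 5
mu, std = model(x)                    # two arrays, one per head
opt_state = Flux.setup(Adam(), model) # no warning once @layer is applied
```

With the type parameters in place, each field has a concrete type, so the calls in the forward pass can be inferred. Moving the model with model |> gpu should work on this struct unchanged.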