Yes, that is the general idea, but there are a few tweaks needed.
For `Flux.train!` you need to supply the parameters you want to update. This is conveniently done using `Flux.params`, which works for all layers in Flux as well as for `Chain`s which only have "raw" layers in them.

Unfortunately, when layers are wrapped in functions like your `m` above, they are no longer accessible in this way. Another issue is that `m` will create a new `Conv` every time it is called, so whatever parameters the layer has will be discarded the next time the function is called.
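For example, a minimal sketch of how `Flux.params` collects parameters from a plain `Chain` (layer sizes here are arbitrary):

```julia
using Flux

# Flux.params traverses the Chain and collects every trainable array
c = Chain(Dense(4, 3, relu), Dense(3, 2))
ps = Flux.params(c)
length(ps)  # 4 — weight and bias of each of the two Dense layers
```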
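To illustrate the issue, here is a sketch of the problematic pattern (a hypothetical reduced version of your `m`):

```julia
using Flux

# Anti-pattern: the Conv is constructed inside the function body, so it
# is recreated (with fresh, untrained weights) on every call, and
# Flux.params cannot see it from the outside
m(x) = Conv((2,2), 1=>1)(x)

Flux.params(m)  # empty — nothing for train! to update
```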
Here is one possible way to work around this:
```julia
julia> function createmodel(convlayer, denselayer)
           function m((x, xconv))
               x2 = convlayer(xconv) |> flatten
               x1 = cat(x, x2, dims=1)
               return denselayer(x1)
           end
           return m, params(convlayer, denselayer)
       end
createmodel (generic function with 1 method)

julia> m, ps = createmodel(Conv((2,2), 1=>1), Dense(371, 10));

julia> Flux.train!((x, y) -> Flux.mse(m(x), y), ps, [((ones(10,2), rand(20,20,1,2)), ones(10,2))], Descent())
```
As you can see, `createmodel` returns not just the function to be optimized but also its parameters. You could just as well remove that part of `createmodel` and create the layers first, so that you have a reference to them outside of the model function.
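That second option could look something like this (a sketch using the same layer sizes as above; the argument destructuring `m((x, xconv))` is just one way to accept the input tuple):

```julia
using Flux

# Create the layers first so we hold references to them ourselves
convlayer = Conv((2,2), 1=>1)
denselayer = Dense(371, 10)

# The model closure uses the layers, but the parameters stay reachable
m((x, xconv)) = denselayer(cat(x, Flux.flatten(convlayer(xconv)), dims=1))

ps = Flux.params(convlayer, denselayer)
Flux.train!((x, y) -> Flux.mse(m(x), y), ps, [((ones(10,2), rand(20,20,1,2)), ones(10,2))], Descent())
```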
A third option is to create a hacky functor for `m`, as described here: Writing complex Flux Models
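A sketch of that functor approach (the struct and field names here are my own invention, not from the linked post):

```julia
using Flux

# Wrap the layers in a struct and register it with @functor so that
# Flux.params can find the trainable arrays inside it
struct ConvCat
    conv
    dense
end
Flux.@functor ConvCat

(m::ConvCat)((x, xconv)) = m.dense(cat(x, Flux.flatten(m.conv(xconv)), dims=1))

m = ConvCat(Conv((2,2), 1=>1), Dense(371, 10))
ps = Flux.params(m)  # finds the Conv and Dense parameters through the struct
```

This has the advantage that `m` itself can be passed around as a single trainable object, just like a built-in layer.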