Yes, that is the general idea, but a few tweaks are needed.

To use `Flux.train!` you need to supply the parameters you want to update. This is conveniently done using `Flux.params`, which works for all layers in Flux as well as for `Chain`s which only contain "raw" layers.
Unfortunately, when layers are wrapped in functions like your `m` above, they are no longer accessible in this way. Another issue is that `m` will create a new `Conv` every time it is called, so whatever parameters the layer has will be discarded the next time the function is called.
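For illustration, here is a guess at the kind of pattern that causes both problems (I'm assuming your `m` looks roughly like this; the exact definition wasn't shown):

```julia
using Flux

# Assumed sketch of the problematic pattern: Flux.params can't see inside
# the closure, and the layers are rebuilt (with fresh random weights) on
# every call, so nothing learned is ever kept.
m = function(x)
    convlayer = Conv((2,2), 1=>1)     # new layer, new weights, every call!
    x2 = convlayer(x[2]) |> flatten
    return Dense(371, 10)(cat(x[1], x2, dims=1))  # likewise recreated
end
```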
Here is one possible way to work around this:
```julia
julia> function createmodel(convlayer, denselayer)
           return function(x)
               x2 = convlayer(x[2]) |> flatten
               x1 = cat(x[1], x2, dims=1)
               return denselayer(x1)
           end, params(convlayer, denselayer)
       end
createmodel (generic function with 1 method)

julia> m, ps = createmodel(Conv((2,2), 1=>1), Dense(371, 10));

julia> Flux.train!((x, y) -> Flux.mse(m(x), y), ps, [(((ones(10,2), rand(20,20,1,2))), ones(10, 2))], Descent())
```
As you can see, `createmodel` returns not just the function to be optimized but also its parameters. You could just as well remove that part of `createmodel` and create the layers first, so that you have a reference to them outside of `m`.
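A minimal sketch of that alternative, assuming the same layer sizes as above:

```julia
using Flux

# Create the layers up front so references to them live outside m.
convlayer = Conv((2,2), 1=>1)
denselayer = Dense(371, 10)

m = x -> denselayer(cat(x[1], flatten(convlayer(x[2])), dims=1))
ps = Flux.params(convlayer, denselayer)   # the layers are reachable here
```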
A third option is to create a hacky functor for `m`, as described here: Writing complex Flux Models - #4 by DrChainsaw
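In case it helps, a sketch of what a functor-style version could look like: a callable struct holding the layers, registered with `Flux.@functor` so `Flux.params` can find its fields (the struct name and fields here are just illustrative):

```julia
using Flux

# Callable struct holding the layers; @functor registers its fields
# so Flux.params(m) can collect their parameters.
struct MyModel{C, D}
    conv::C
    dense::D
end
Flux.@functor MyModel

(m::MyModel)(x) = m.dense(cat(x[1], flatten(m.conv(x[2])), dims=1))

m = MyModel(Conv((2,2), 1=>1), Dense(371, 10))
ps = Flux.params(m)   # finds the parameters of both layers
```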