So I have a large spatial vector as input, from which I extract some features using convolution, but I also have some other non-spatial inputs that I’d like to concatenate to those features. Is there a way to do that with Flux? Maybe with a skip connection, but I’m not sure how to handle the different parts of the input.
using Flux
x = (rand(10,1,1,1), rand(3))  # (spatial input, non-spatial input)
mconv = Chain(
    Conv((10,1), 1=>5),
    Flux.flatten,
)
vcat( mconv(x[1]), x[2] ) #?
This seems to work (even with named tuples, nice).
x = (spatial = rand(10,1,1,1), nonspatial = rand(5))
mconv = Chain(
    Conv((10,1), 1=>5),
    Flux.flatten,
)
m = Chain(
    x -> vcat(mconv(x.spatial), x.nonspatial),
    Dense(10,1)
)
Never mind, that doesn’t work (it seems to “freeze” the first layer).
For inputs with multiple elements, you can define a struct and apply Flux.@functor to it so that Flux can collect its parameters. The issue in your second solution is that Flux doesn’t know how to reach the mconv parameters, since mconv is only referenced inside the anonymous function.
To solve this:
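x = (spatial = rand(10,1,1,1), nonspatial = rand(5))
# Wrapper layer: applies the wrapped model to the spatial part and appends the non-spatial part.
struct MConv
    model
end
Flux.@functor MConv  # lets Flux find the parameters inside `model`
(mconv::MConv)(x) = vcat(mconv.model(x.spatial), x.nonspatial)
model = MConv(Chain(
    Conv((10,1), 1=>5),
    Flux.flatten,
))
m = Chain(model, Dense(10, 1))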
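If you would rather not define a struct, something along these lines with Flux's built-in `Parallel` layer might also do the job. This is only a minimal sketch and not from the posts above: it assumes a reasonably recent Flux (0.13 or later), swaps the named tuple for a plain tuple, uses Float32 inputs, and gives the non-spatial input an explicit batch dimension of 1. The gradient call at the end is just a sanity check that the convolution weights are reachable.

```julia
using Flux

# Assumption: inputs are passed as a plain tuple (spatial, nonspatial),
# each with a trailing batch dimension of size 1.
x = (rand(Float32, 10, 1, 1, 1), rand(Float32, 5, 1))

mconv = Chain(
    Conv((10, 1), 1 => 5),  # 10×1×1×1 input -> 1×1×5×1 output
    Flux.flatten,           # -> 5×1 matrix
)

# Called with a tuple, Parallel applies the i-th layer to the i-th element
# and combines the results with the connection function (here vcat).
m = Chain(
    Parallel(vcat, mconv, identity),
    Dense(10 => 1),
)

y = m(x)  # forward pass on the tuple input

# Differentiate with respect to the model itself. The nested field access
# follows the layer structure: Chain -> Parallel -> mconv -> Conv.
grads = Flux.gradient(model -> sum(model(x)), m)
conv_grad = grads[1].layers[1].layers[1].layers[1].weight
```

Because `mconv` is an ordinary field of `Parallel` here rather than something hidden inside an anonymous function, Flux can see its parameters without any extra annotation. If `conv_grad` comes back as an array rather than `nothing`, the first layer is receiving gradients.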