Adding more inputs after convolution in Flux?

So I have a large spatial vector as input, from which I extract some features using convolution, but I also have some other non-spatial inputs that I’d like to concatenate to those features. Is there a way to do that with Flux? Maybe with a skip connection, but I’m not sure how to handle the different parts of the input.

using Flux

x = (rand(Float32, 10,1,1,1), rand(Float32, 3))

mconv = Chain(
    Conv((10,1), 1=>5),
    Flux.flatten,
)
    
vcat( mconv(x[1]), x[2] ) #?
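A quick shape check for the concatenation above, assuming Float32 inputs (the element type Flux layers use by default). The Conv((10,1), 1=>5) layer slides a 10×1 kernel over a 10×1 input, so each of the 5 output channels holds a single value, and Flux.flatten turns the 1×1×5×1 output into a 5×1 matrix that vcat can stack on top of the 3 extra features.

```julia
using Flux

# Spatial input in WHCN layout (width × height × channels × batch), plus 3 extra features
x = (rand(Float32, 10, 1, 1, 1), rand(Float32, 3))

mconv = Chain(
    Conv((10, 1), 1 => 5),  # 10×1 kernel over a 10×1 input -> 1×1×5×1 output
    Flux.flatten,           # -> 5×1 matrix (features × batch)
)

features = mconv(x[1])
combined = vcat(features, x[2])  # the vector is treated as one column -> 8×1

println(size(features))  # (5, 1)
println(size(combined))  # (8, 1)
```

This only lines up for a batch of one sample; with N samples, the non-spatial part would need to be a 3×N matrix so that its columns match the flattened features.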

This seems to work (even with named tuples, nice).

x = (spatial = rand(Float32, 10,1,1,1), nonspatial = rand(Float32, 5))

mconv = Chain(
    Conv((10,1), 1=>5),
    Flux.flatten,
)
    
m = Chain(
    x -> vcat(mconv(x.spatial), x.nonspatial),
    Dense(10,1)
)

Never mind, that doesn’t work: the first layer seems to get “frozen” (its weights don’t change during training).
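One way to check which parameters Flux can actually reach is Flux.destructure, which walks a model through its functor structure and returns every trainable array concatenated into one flat vector (plus a function to rebuild the model). This is a sketch assuming a reasonably recent Flux; the expected counts follow from the layer sizes used in this thread.

```julia
using Flux

mconv = Chain(Conv((10, 1), 1 => 5), Flux.flatten)
dense = Dense(10, 1)

# destructure returns (flat parameter vector, function that rebuilds the model)
flat_conv, _ = Flux.destructure(mconv)
flat_dense, _ = Flux.destructure(dense)

# Conv: 10*1*1*5 weights + 5 biases = 55; Dense: 10 weights + 1 bias = 11
println(length(flat_conv))   # 55
println(length(flat_dense))  # 11
```

A model that exposes both parts should report 66 parameters in total. If the model built with the anonymous-function wrapper reports only the 11 Dense parameters, the convolution is invisible to Flux, which would explain the frozen first layer (the exact handling of closures may depend on the Flux/Functors version).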

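For a model whose input has several parts, you can define a struct that holds the sub-model and apply Flux.@functor to it, so that Flux can collect its parameters. The problem with your second solution is that mconv is only captured inside an anonymous function, and Flux doesn’t know how to reach the mconv parameters through it, so they are never collected or updated.

To solve this: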

using Flux

x = (spatial = rand(Float32, 10,1,1,1), nonspatial = rand(Float32, 5))

struct MConv
    model
end

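Flux.@functor MConv  # lets Flux look inside MConv to find the Conv parameters

# run the convolution on the spatial part, then append the non-spatial features
(mconv::MConv)(x) = vcat(mconv.model(x.spatial), x.nonspatial)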

model = MConv(Chain(
    Conv((10,1), 1=>5),
    Flux.flatten,
))

m = Chain(model, Dense(10, 1))
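To confirm that the convolution is trained with this approach, here is a minimal sketch of a single optimisation step using Flux’s explicit-gradient API (Flux.setup, Flux.gradient, Flux.update!), assuming a Flux version that provides it along with the Adam optimiser. The target y is made up purely for illustration. The sketch checks that the parameter count includes the convolution and that the convolution weights actually change after the step.

```julia
using Flux

x = (spatial = rand(Float32, 10, 1, 1, 1), nonspatial = rand(Float32, 5))
y = rand(Float32, 1, 1)  # hypothetical target, only for illustration

struct MConv
    model
end

Flux.@functor MConv  # lets Flux walk into MConv and find the Conv parameters

(mconv::MConv)(x) = vcat(mconv.model(x.spatial), x.nonspatial)

m = Chain(MConv(Chain(Conv((10, 1), 1 => 5), Flux.flatten)), Dense(10, 1))

# 55 Conv parameters + 11 Dense parameters should all be visible
nparams = length(Flux.destructure(m)[1])

w_before = copy(m[1].model[1].weight)  # Conv weights before training

opt_state = Flux.setup(Adam(1f-2), m)
grads = Flux.gradient(model -> Flux.mse(model(x), y), m)
Flux.update!(opt_state, m, grads[1])  # updates the arrays inside m in place

changed = m[1].model[1].weight != w_before
println(nparams)  # 66
println(changed)  # true
```

The difference from the anonymous-function version is that MConv stores the convolution chain in a field that Flux can traverse, so both the optimiser setup and the gradient see it. Newer Flux releases also document Flux.@layer for marking custom layers; Flux.@functor is kept here to match the answer.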