I see. In that case I do not recommend the solution proposed in that thread: overloading `params` as necessary would involve some pretty severe type piracy. I think the `Concat`
struct is actually the way to go. That post is quite old, and Flux has changed somewhat since then (e.g. `@treelike`
has been deprecated in favor of `@functor`, to my knowledge), so the following should be the way to do it now:
# Container layer holding several branches in `catted`; calling it applies the
# branches and `vcat`s their outputs (see the call method defined for `Concat`).
#
# Construction:
#   Concat(f1, f2, ...)   -> branches stored as the tuple (f1, f2, ...)
#   Concat(collection)    -> a Tuple / NamedTuple / AbstractVector is stored as-is,
#                            so the one-argument constructor can still rebuild a
#                            layer from an existing `catted` field
#   Concat(f)             -> a single non-collection branch is wrapped as (f,)
#
# The type bound on `T` is what makes the last case work: without it, a lone
# layer such as `Concat(Chain(a, b))` would hit the default constructor and be
# stored unwrapped, and since a `Chain` is iterable its layers would then be
# mistaken for separate branches.
struct Concat{T<:Union{Tuple,NamedTuple,AbstractVector}}
    catted::T
end
Concat(xs...) = Concat(xs)
Flux.@functor Concat

# Two possible forward passes; keep exactly one of them.
#
# 1. Apply every branch to the SAME input and stack the results,
#    [f1(x); f2(x); ...], which is more image-processing-like:
# function (C::Concat)(x)
#     mapreduce(f -> f(x), vcat, C.catted)
# end
#
# 2. Apply branch i to input i and stack the results,
#    [f1(x1); f2(x2); ...], like in the original issue (defined below).
#
# `xs` must hold exactly one input per branch. `mapreduce` over the tuple of
# branches zips its arguments, which would silently truncate to the shorter of
# the two, so a length mismatch is rejected up front with a DimensionMismatch.
function (C::Concat)(xs)
    nbranches, ninputs = length(C.catted), length(xs)
    nbranches == ninputs || throw(DimensionMismatch(
        "Concat has $nbranches branches but received $ninputs inputs"))
    return mapreduce((f, xi) -> f(xi), vcat, C.catted, xs)
end
# The example from the post:
julia> A1 = Dense(5, 5); B1 = Dense(5, 5); A2 = Dense(5, 5);
B2 = Dense(5, 5); C2 = Dense(5, 5); D = Dense(10, 5);
E = Dense(5, 5);
julia> model = Chain(Concat(Chain(A1, B1), Chain(A2, B2, C2)), D, E);
# with the [f1(x1); f2(x2) ...] form, the input to `Concat` is
# now an iterable of arrays, one per branch
julia> input = [rand(5), rand(5)];
julia> model(input)
5-element Array{Float32,1}:
0.71026456
-0.7175311
-0.30744648
-0.8034078
0.8368984
And `params` works on the `Concat` layer as expected:
julia> Flux.params(model[1])
Params([Float32[0.23016948 0.05715016 … 0.62341815 -0.5434392; -0.37009057 -0.3970632 … 0.62382203 0.41673365; … ; 0.061818454 0.16528758 … -0.2425396 0.603057; -0.260415 0.33849806 … -0.32674253 0.50213516], Float32[0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.57999533 -0.34143996 … -0.14378327 0.4262758; -0.764463 -0.55587304 … -0.45619386 -0.2658353; … ; -0.45302996 0.5440691 … -0.6935754 -0.68062; -0.5139856 -0.12330929 … 0.17966792 -0.40127572], Float32[0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.061877552 0.5428227 … -0.31964645 0.76445305; 0.015927391 0.21079396 … -0.6141678 -0.17098361; … ; 0.34704348 -0.64504695 … -0.67650115 0.28716063; -0.5823784 0.16855657 … 0.4026331 -0.18516172], Float32[0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.42521372 0.13450116 … 0.21042147 -0.65950245; -0.5163227 0.31146872 … 0.5260253 -0.059268236; … ; -0.01898363 -0.010103932 … 0.5870273 -0.6593026; -0.23822755 -0.76355386 … -0.4103931 0.7052716], Float32[0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.48998186 0.76523256 … 0.42303228 -0.017605193; 0.09143734 0.28700328 … 0.24258374 0.46789637; … ; 0.749469 0.26842743 … 0.36736602 0.13056715; 0.44405162 -0.24533193 … -0.63523906 0.18275999], Float32[0.0, 0.0, 0.0, 0.0, 0.0]])