where Dense(10,6) is the union of the output neurons of the Dense(10,4) from m2 and the Dense(10,2) from m1, each keeping its respective input weights from their shared 10-neuron input layer.
What is the best way to construct m3?
Here’s some stuff I have kicking around from a project a long time ago. If you can figure out how to adjust it to your needs, I think all the pieces are here:
# My needs were a bit more generic than just flux, so I
# needed methods for getting e.g. the weights, etc. and
# also for dealing with other layer types, so this can be
# simplified for your use case if you're 100% flux.
"""
    stack_layers(layers...; [σ = nothing], [block_diagonal_weights = false])

Vertically stack compatible FFNN layers into one layer of the same type.

- `σ` - activation function for the stacked layer. If the layers do not all
  share the same activation, this keyword is required.
- `block_diagonal_weights` - whether to stack the weight matrices *vertically*
  or into a *block diagonal*, i.e. `[W₁; W₂; ...]` versus
  `[W₁ 0 ...; 0 W₂ 0 ...; ...]`. Vertical stacking is appropriate only when
  the layers read the same input; block diagonal keeps their inputs disjoint.
"""
function stack_layers(layers::LayerType...; σ = nothing, block_diagonal_weights = false) where LayerType
    if σ === nothing  # was `σ == nothing`; `===` is the reliable nothing-check
        σ = activation(layers[1])
        @assert all(l -> σ == activation(l), layers) "All the layers to be concatenated
            must have the same activation function unless the `σ` keyword is set"
    end
    Ws = (weights(l) for l in layers)
    W = block_diagonal_weights ? block_diagonal(Ws...) : vcat(Ws...)
    # The generator must be splatted (as with `Ws...` above): `vcat(gen)`
    # treats the generator as a single scalar element and returns a
    # 1-element vector containing it, rather than concatenating the biases.
    b = vcat((bias(l) for l in layers)...)
    return LayerType(W, b, σ)
end
"""
    stack_networks(chains::NetworkType...)

Vertically stack compatible FFNNs. Each `chain` must be an iterable/indexable object
consisting of FFNN "layers" that implement `weights`, `bias`, and `activation`.
"""
function stack_networks(chains::NetworkType...) where NetworkType
    c1 = first(chains)
    @assert all(length(c1) .== length.(chains)) "All of the Chains must be the same length to be `stack`ed"
    # Layer 1 of every chain reads the shared network input, so its weights are
    # stacked vertically; every deeper layer gets block-diagonal weights so each
    # sub-network keeps operating on its own (now disjoint) slice of neurons.
    # A comprehension is used instead of mutating `collect(c1)` in place, since
    # `collect` may infer a concrete eltype the re-stacked layers don't fit.
    C = [stack_layers((chain[i] for chain in chains)..., block_diagonal_weights = (i != 1))
         for i in 1:length(c1)]
    # `Base.typename(T)` returns a `Core.TypeName`, which is NOT callable —
    # its `.wrapper` field holds the un-parameterized type (e.g. `Chain` from
    # `Chain{Tuple{...}}`), which is what we construct with the new layers.
    return Base.typename(NetworkType).wrapper(C...)
end
# Accessor methods required by `stack_layers`/`stack_networks` for Flux
# `Dense` layers. NOTE(review): these use old-style Flux field names
# (`W`/`b`/`σ`); newer Flux versions expose `weight`/`bias` — confirm
# against the Flux version in use.
weights(layer::Dense) = layer.W
bias(layer::Dense) = layer.b
activation(layer::Dense) = layer.σ
"""
    block_diagonal(arrays...)

Construct a fully populated block diagonal matrix from a sequence of AbstractVectors or AbstractMatrices.

# Example
julia> LI = collect(LinearIndices((3,3)))
3×3 Array{Int64,2}:
1 4 7
2 5 8
3 6 9
julia> block_diagonal(LI, LI)
6×6 Array{Int64,2}:
1 4 7 0 0 0
2 5 8 0 0 0
3 6 9 0 0 0
0 0 0 1 4 7
0 0 0 2 5 8
0 0 0 3 6 9
"""
function block_diagonal(arrays::AbstractMatrix...)
    T = promote_type(eltype.(arrays)...)
    nrows = sum(size.(arrays, 1))
    ncols = sum(size.(arrays, 2))
    out = Matrix{T}(undef, nrows, ncols)
    # For non-Number eltypes (e.g. Any), `zero(T)` is undefined, so fall
    # back to an integer zero for the off-diagonal fill.
    fill!(out, zero(T <: Number ? T : Int))
    row = col = 0
    for block in arrays
        r, c = size(block)
        out[row .+ (1:r), col .+ (1:c)] .= block
        row += r
        col += c
    end
    return out
end
# Fallback: normalize arbitrary inputs (vectors, scalars) to matrix form,
# then re-dispatch to the all-matrix method above. A scalar becomes a
# 1-element vector first and is promoted to a 1×1 matrix on the next pass.
block_diagonal(arrays...) = block_diagonal(map(_as_matrix, arrays)...)
_as_matrix(M::AbstractMatrix) = M                # already a matrix — pass through
_as_matrix(v::AbstractVector) = reshape(v, :, 1) # column vector → n×1 matrix (shares data)
_as_matrix(x) = [x]                              # scalar → 1-element vector