How to combine two layers?

Let’s say we have trained a model m1

m1 = Chain(Dense(10, 10, relu), Dense(10,2), softmax)

and we have then substituted Dense(10,2) with Dense(10,4) and trained the latter layer only, getting a model m2

m2 = Chain(Dense(10, 10, relu), Dense(10,4), softmax)

such that m1[1] == m2[1].

Let’s say that we now want to combine them into a model m3

m3 = Chain(Dense(10, 10, relu), Dense(10,6), softmax)

where Dense(10,6) is the union of the output neurons of Dense(10,4) of m2 and Dense(10,2) of m1, with their respective input weights from their shared 10-neuron input layer.
What is the best way to construct m3?

Here’s some stuff I have kicking around from a project a long time ago. If you can figure out how to adjust it to your needs, I think all the pieces are here:

# My needs were a bit more generic than just Flux, so I
# needed methods for getting e.g. the weights, etc., and
# also for dealing with other layer types, so this can be
# simplified for your use case if you're 100% Flux.
"""
    stack_layers(layers...; [σ = nothing], [block_diagonal_weights = false])

Vertically stack compatible FFNN layers into a single layer of the same type.

Every layer must implement `weights`, `bias`, and `activation`, and the layer type
must be constructible as `LayerType(W, b, σ)`.

# Keywords
 - `σ` - set the activation function for the output layer. If all layers do not
   have the same activation, then this keyword is required.
 - `block_diagonal_weights` - whether to stack weights *vertically* or into a
   *block diagonal*, i.e. `[W₁; W₂; ...]` or `[W₁ 0 ...; 0 W₂ 0 ...; ...]`.

The biases are always concatenated vertically, regardless of `block_diagonal_weights`.

# Throws
 - `ArgumentError` if no layers are given.
 - `AssertionError` if `σ` is not set and the layers' activations differ.
"""
function stack_layers(layers::LayerType...; σ = nothing, block_diagonal_weights = false) where LayerType
    isempty(layers) && throw(ArgumentError("stack_layers requires at least one layer"))
    if σ === nothing
        # No explicit activation given: all layers must agree on one.
        σ = activation(first(layers))
        @assert all(l-> σ == activation(l), layers) "All the layers to be concatenated
            must have the same activation function unless the `σ` keyword is set"
    end
    Ws = map(weights, layers)
    W = block_diagonal_weights ? block_diagonal(Ws...) : vcat(Ws...)
    # The per-layer biases must be splatted: `vcat` applied to a single generator
    # does not concatenate the arrays it yields.
    b = vcat(map(bias, layers)...)
    return LayerType(W, b, σ)
end

"""
    stack_networks(chains::NetworkType...)

Vertically stack compatible FFNNs. Each `chain` must be an indexable object with a
`length`, consisting of FFNN "layers" that implement `weights`, `bias`, and
`activation`.

The first layers are stacked with vertically concatenated weights; every later layer
is stacked with block-diagonal weights (see `stack_layers`). The stacked layers are
passed, splatted, to the unparameterized constructor of `NetworkType`.

# Throws
 - `ArgumentError` if no networks are given.
 - `AssertionError` if the networks do not all have the same length.
"""
function stack_networks(chains::NetworkType...) where NetworkType
    isempty(chains) && throw(ArgumentError("stack_networks requires at least one network"))
    c1 = first(chains)
    @assert all(length(c1) .== length.(chains)) "All of the Chains must be the same length to be `stack`ed"
    # Only layer 1 uses plain vertical stacking; presumably because all networks read the
    # same input, while deeper layers must keep each network's activations separate.
    stacked = [
        stack_layers((chain[i] for chain in chains)...; block_diagonal_weights = (i != 1))
        for i in 1:length(c1)
    ]
    # `Base.typename` returns a `Core.TypeName`, which is not callable; its `wrapper`
    # field is the unparameterized type (e.g. `Chain`) whose constructor we need.
    NT = Base.typename(NetworkType).wrapper
    return NT(stacked...)
end

# Accessors that `stack_layers` and `stack_networks` rely on, defined here for `Dense`.
# NOTE(review): they read the fields `W`, `b`, and `σ` — confirm these names against
# the `Dense` definition of the library version in use.

# Weight matrix of the layer.
function weights(D::Dense)
    return D.W
end

# Bias of the layer.
function bias(D::Dense)
    return D.b
end

# Activation function of the layer.
function activation(D::Dense)
    return D.σ
end

"""
    block_diagonal(arrays...)

Build a dense block diagonal matrix whose diagonal blocks are the given `arrays`, in
order. Everything outside the blocks is filled with zeros.

Arguments that are not matrices are converted first: an `AbstractVector` is reshaped
into a single-column matrix and any other value is wrapped in a one-element array.
The element type of the result is the promoted element type of all blocks.

# Example
    julia> LI = collect(LinearIndices((3,3)))
    3×3 Array{Int64,2}:
     1  4  7
     2  5  8
     3  6  9

    julia> block_diagonal(LI, LI)
    6×6 Array{Int64,2}:
     1  4  7  0  0  0
     2  5  8  0  0  0
     3  6  9  0  0  0
     0  0  0  1  4  7
     0  0  0  2  5  8
     0  0  0  3  6  9
"""
function block_diagonal(arrays::AbstractMatrix...)
    T = promote_type(map(eltype, arrays)...)
    nrows = sum(map(B -> size(B, 1), arrays))
    ncols = sum(map(B -> size(B, 2), arrays))
    # For non-numeric element types fall back to an integer zero as the filler.
    filler = zero(T <: Number ? T : Int)
    A = fill!(Matrix{T}(undef, nrows, ncols), filler)
    row_offset = 0
    col_offset = 0
    for B in arrays
        h, w = size(B)
        # Copy the block just below and to the right of the previous one.
        A[(row_offset + 1):(row_offset + h), (col_offset + 1):(col_offset + w)] .= B
        row_offset += h
        col_offset += w
    end
    return A
end

# Fallback: normalise every argument with `_as_matrix`, then dispatch again.
block_diagonal(arrays...) = block_diagonal(map(_as_matrix, arrays)...)

# A vector becomes a single-column matrix (sharing its data).
_as_matrix(A::AbstractVector) = reshape(A, :, 1)
# Matrices are passed through unchanged.
_as_matrix(M::AbstractMatrix) = M
# Anything else is wrapped in a one-element vector, which the fallback method of
# `block_diagonal` then reshapes on its next pass.
_as_matrix(x) = [x]

1 Like