Flux: How to create a custom multi-layer model with some parameters shared across layers?

Hi all,

I’m somewhat new to Julia and Flux, and trying to train a model similar to a standard dense multi-layer neural network, but with sharing of some trainable parameters between layers.

To give a concrete example (not exactly what I want, but it's close enough to illustrate the problem I'm facing in Flux):
The model has as parameters a sequence of matrices A_l and \Lambda_l, with the latter diagonal and positive-definite.
For layers l=1, ... , L-1:
x_{l+1} = \sigma (\Lambda_{l+1}^{-1} A_l \Lambda_l x_l)
and a final output layer
y = A_L \Lambda_L x_L

The main difficulty is that each matrix \Lambda_l for l=2, ..., L appears in both layer l and layer l-1. For this reason I cannot just use Chain, at least as far as I know.

What is the best way of coding this in Flux?

I have tried the code below. It works until the last line, which gives the error:

ERROR: Only reference types can be differentiated with Params.

I have searched for this error and no solutions I found address exactly this problem. I understand that the problem is related to having a vector of arrays in the struct defining the model. But is there a better way of representing such a structure with a flexible number of layers? Or is there a way of getting Flux to differentiate with respect to the arrays As and ds?

Grateful for any assistance!

using Flux
using LinearAlgebra  # provides diagm

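mutable struct Multi
    As  # weight matrices A_1, ..., A_L
    ds  # log-diagonals: Λ_l = diagm(exp.(ds[l])), so Λ_l is positive-definite
end
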
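function (m::Multi)(x)
    L = length(m.As)
    # Hidden layers: x_{l+1} = σ(Λ_{l+1}^{-1} A_l Λ_l x_l)
    for l = 1:(L-1)
        Λ = diagm(exp.(m.ds[l]))
        V = diagm(exp.(-m.ds[l+1]))  # inverse of Λ_{l+1}
        A = m.As[l]
        x = σ.( V * A * Λ * x )
    end
    # Output layer: y = A_L Λ_L x_L
    Λ = diagm(exp.(m.ds[L]))
    A = m.As[L]
    return A * Λ * x
end

Flux.@functor Multi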

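ni, nh, no = 3, 4, 2  # input, hidden, output sizes (example values; not given in the original post)
m = Multi([randn(nh,ni), randn(no,nh)],[randn(ni), randn(nh)])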

x = 0
y = 1

m(1)  # check that the model evaluates

function loss(x,y)
    ŷ = m(x)
    sum((y .- ŷ).^2)  # elementwise square; ^2 on a non-square array would fail
end

grads = gradient(() -> loss(x, y), params(m))

Chain uses a tuple; can you do the same?

struct Multi{T1<:Tuple,T2<:Tuple}
    As::T1
    ds::T2
end

m = Multi((randn(nh,ni), randn(no,nh)), (randn(ni), randn(nh)))

AFAICT Multi doesn’t need to be mutable either.
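
In case it helps, here is a fuller sketch of how the tuple-based version might look end to end. It is only an illustration under some assumptions: the sizes ni, nh, no are made-up example values, the loss signature loss(model, x, y) is my own choice, and the diagonal matrices are applied by broadcasting exp.(ds[l]) rather than building them with diagm (same result for vector inputs, without the dense matrices). It also takes the gradient with respect to the model itself instead of going through params(m), which is a different calling style from the one in the question.

```julia
using Flux  # for σ and gradient

# Immutable struct with tuple fields, as suggested above.
struct Multi{T1<:Tuple,T2<:Tuple}
    As::T1  # weight matrices A_1, ..., A_L
    ds::T2  # log-diagonals: Λ_l = Diagonal(exp.(ds[l]))
end

function (m::Multi)(x)
    L = length(m.As)
    for l in 1:(L - 1)
        # x_{l+1} = σ(Λ_{l+1}^{-1} A_l Λ_l x_l), with the diagonal
        # products written as elementwise scalings of a vector
        x = σ.(exp.(-m.ds[l + 1]) .* (m.As[l] * (exp.(m.ds[l]) .* x)))
    end
    # Output layer: y = A_L Λ_L x_L
    return m.As[L] * (exp.(m.ds[L]) .* x)
end

ni, nh, no = 3, 4, 2  # assumed example sizes
m = Multi((randn(nh, ni), randn(no, nh)), (randn(ni), randn(nh)))

x = randn(ni)
y = randn(no)

loss(model, x, y) = sum(abs2, y .- model(x))

# Gradient with respect to the whole model; the result mirrors the
# struct's fields, so grads.As and grads.ds are tuples of arrays.
grads = gradient(model -> loss(model, x, y), m)[1]
```

Because m.ds[2] is used in both the hidden layer and the output layer, grads.ds[2] should contain the summed contribution from both uses, which is the parameter sharing asked about.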

Thanks for the suggestion, but unfortunately this gives the same error: Only reference types can be differentiated with Params.

Edit: actually it does seem to work, I was doing something silly before. Thanks for the help!