How to make parameters of function within Flux.Chain trainable?

Hi, I have a NeuralODE defined as follows:

using Flux, DiffEqFlux, OrdinaryDiffEq

dudt = Chain(x -> Dense(10, 20, tanh)(x[1]) + Dense(10, 20, tanh)(x[2]))
n_ode = NeuralODE(dudt, (0., 1.), Tsit5(), saveat=range(0., 1., length=100), reltol=1e-7, abstol=1e-9)

which takes a tuple x as input and passes its two components through two separate branches. When I check the trainable parameters with:

Flux.params(n_ode)

I got:

Params([])

It looks like the parameters of the two Dense layers are not detected as trainable. Is there a way to make them trainable?

Thank you!

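Related to your other question: Flux layers are stateful, so they need to be constructed once and kept around rather than re-created on every forward pass. In your definition the two Dense layers are built inside the anonymous function, so every call gets freshly initialized weights and Flux.params has nothing to collect. If you want to run something through two branches without writing a custom layer or function, use Parallel: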

dudt = Parallel(+, Dense(10, 20, tanh), Dense(10, 20, tanh))
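
To make the difference concrete, here is a minimal sketch comparing the two definitions. It assumes a Flux version where Flux.params is still available (as used in this thread); the names `closure_model`, `parallel_model`, and `x` are only for illustration.

```julia
using Flux

# Layers are constructed inside the anonymous function, so the Chain
# holds only a closure with no stored arrays for Flux to collect.
closure_model = Chain(x -> Dense(10, 20, tanh)(x[1]) + Dense(10, 20, tanh)(x[2]))

# Layers are constructed once and stored as fields of Parallel.
parallel_model = Parallel(+, Dense(10, 20, tanh), Dense(10, 20, tanh))

# A tuple input is split across the branches: first element to the
# first Dense layer, second element to the second.
x = (rand(Float32, 10), rand(Float32, 10))
y = parallel_model(x)

n_closure  = length(Flux.params(closure_model))   # 0
n_parallel = length(Flux.params(parallel_model))  # 4: two weight matrices, two bias vectors
```

The second model exposes four parameter arrays because each Dense layer stores its own weight matrix and bias vector.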

For documentation on how Flux layers work and what it takes to make a model trainable, see the "Basics" and "Advanced Model Building" pages of the Flux documentation.

Thanks for the explanation, this is really helpful!

I created a new struct for the same functionality:

struct TwoInputsLayer
    layer1
    layer2
    op     # operation to aggregate them
end

(m::TwoInputsLayer)(x) = m.op(m.layer1(x[1]), m.layer2(x[2]))
Flux.@functor TwoInputsLayer
tmp_two_inputs = TwoInputsLayer(Dense(10, 20, tanh), Dense(10, 20, tanh), +)
dudt = Chain(tmp_two_inputs)
n_ode = NeuralODE(dudt, (0., 1.), Tsit5(), saveat=range(0., 1., length=100), reltol=1e-7, abstol=1e-9)
Flux.params(n_ode)

Now it seems to track the trainable parameters. Does this definition look good to you, or is there anything else I should be aware of when training models with this struct?

Thanks!
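
One point that may be worth checking with a struct like this: fields declared without types are stored as `Any`, which prevents Julia from specializing the forward pass and can make it slower. A possible variant, offered only as a sketch, adds type parameters to the fields. The name `TypedTwoInputsLayer` is hypothetical (chosen to avoid clashing with the definition above), and the sketch assumes a Flux version where Flux.@functor and Flux.params are available, as used in this thread.

```julia
using Flux

# Same layout as TwoInputsLayer, but each field has a concrete type
# parameter so the compiler can specialize calls on the stored layers.
struct TypedTwoInputsLayer{A,B,F}
    layer1::A
    layer2::B
    op::F      # operation used to aggregate the two branch outputs
end

# Register the struct so Flux can traverse its fields for parameters.
Flux.@functor TypedTwoInputsLayer

# Apply each branch to its own component of the input tuple.
(m::TypedTwoInputsLayer)(x) = m.op(m.layer1(x[1]), m.layer2(x[2]))

typed_layer = TypedTwoInputsLayer(Dense(10, 20, tanh), Dense(10, 20, tanh), +)
out = typed_layer((rand(Float32, 10), rand(Float32, 10)))
n_typed = length(Flux.params(typed_layer))
```

Behaviour is intended to match the untyped struct; only the field declarations change.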
