Bug in solve method for Universal Differential Equations?

I am trying to loosely follow the Universal Differential Equations tutorial, with a little twist: I want to learn the parameters of both the differential equation AND the neural network.

Here is the only part of the code I changed from the tutorial:

```julia
const U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
                    Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
p_nn, st = Lux.setup(rng, U)
const _st = st
# Concatenate DiffEq. params with NN params
p = [rand(rng, Float32,4); p_nn]  # [α; β; γ; δ; p_nn] 

# Define the hybrid model
function ude_dynamics!(du, u, p, t)
    û = U(u, p[5], _st)[1] # Forward pass
    α, β, γ, δ = p[1:4]
    # Lotka-Volterra equations + ANN
    du[1] = α*u[1] - β*u[1]*u[2] + û[1]
    du[2] = γ*u[1]*u[2] - δ*u[2] + û[2]
end

# Define the problem
prob_nn = ODEProblem(ude_dynamics!, Xₙ[:, 1], tspan, p)
```

With this, I am able to evaluate predictions of the model (as p[5] == p_nn).

However, when I try to run the training with Optimization.solve, I get the following error at OptimizationOptimisers.jl:59:

```
ERROR: MethodError: no method matching real(::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}, layer_3::@NamedTuple{…}, layer_4::@NamedTuple{…}})
```

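It seems that solve is trying to find the maximum value of the parameters' numeric type (Float32 in my case), calling real along the way in case the parameters are complex numbers. Since p mixes four Float32 values with the NamedTuple of network parameters, real ends up being called on that NamedTuple, for which no method exists.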
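The root cause can be reproduced without Lux or Optimization.jl. Below is a minimal sketch, using a small hand-written NamedTuple as a hypothetical stand-in for `p_nn`. It shows that the concatenation in the question produces a `Vector{Any}` whose fifth element is the whole NamedTuple, and that `real` has no method for it:

```julia
# Hypothetical stand-in for the Lux parameters: a nested NamedTuple of arrays
p_nn = (layer_1 = (weight = rand(Float32, 5, 2), bias = rand(Float32, 5)),)

# Same concatenation pattern as in the question: the NamedTuple is not an
# array, so vcat stores it as a single element next to the four Float32 values
p = [rand(Float32, 4); p_nn]

println(eltype(p))   # Any: Float32 and the NamedTuple type only share Any
println(length(p))   # 5: four numbers plus the whole NamedTuple

# real is not defined for NamedTuples, matching the reported MethodError
err = try
    real(p[5])
    nothing
catch e
    e
end
println(err isa MethodError)
```

Code that assumes the parameter vector has a numeric element type will fail on such a vector, which is why a flat container with named blocks is needed instead.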

I tried to hack my way around this by defining a method that returns a Float32 number:

```julia
import Base: real
real(p::NamedTuple{T}) where T = 0f0
```

I then got a fairly informative error message telling me to use a ComponentArray for my parameter structure.

What would be the best way to do so?

The solution turned out to be quite simple. I just needed to define a ComponentArray with two fields:

```julia
using ComponentArrays

p = ComponentArray(NN=p_nn, LV=rand(rng, Float32,4))
```
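With this layout, the dynamics function also has to read its parameters by name (`p.NN`, `p.LV`) rather than by position (`p[5]`, `p[1:4]`). Below is a minimal sketch, assuming Lux and ComponentArrays are installed. The seed, the `rbf` activation and the initial state `u0` are hypothetical stand-ins for the tutorial's definitions:

```julia
using Lux, ComponentArrays, Random

# Hypothetical setup mirroring the question (seed, rbf and u0 are assumptions)
rng = Random.Xoshiro(0)
rbf(x) = exp.(-(x .^ 2))

const U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
                    Lux.Dense(5, 2))
p_nn, st = Lux.setup(rng, U)
const _st = st

# One flat Float32 vector with named blocks: NN for the network, LV for the ODE
p = ComponentArray(NN = p_nn, LV = rand(rng, Float32, 4))

function ude_dynamics!(du, u, p, t)
    û = U(u, p.NN, _st)[1]     # forward pass using only the network block
    α, β, γ, δ = p.LV           # Lotka-Volterra parameters, accessed by name
    du[1] = α * u[1] - β * u[1] * u[2] + û[1]
    du[2] = γ * u[1] * u[2] - δ * u[2] + û[2]
    return nothing
end

# Evaluate the right-hand side once at a hypothetical state
u0 = Float32[1.0, 1.0]
du = zeros(Float32, 2)
ude_dynamics!(du, u0, p, 0.0f0)
```

Because `p` is now a single vector with a numeric element type, it can be passed to `ODEProblem` in place of the concatenated vector, and the optimizer sees ordinary Float32 parameters.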