Help training a NN that's used as a parameter in an ODE

Hello! I’ve been pulling my hair out trying to implement an ODE where one of the parameters is the output of a neural network. ChatGPT has me going in circles.

Depending on which fixes I try, I get either:
Error in epoch 1: MethodError(+, (ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Fl…
or
SciMLSensitivity.AdjointSensitivityParameterCompatibilityError()
when calculating the gradients.

I’ve tried multiple sensealgs for the ODE solver, and I’m only passing the trainable parameters (the NN parameters); all fixed parameters are globals.

Is there something obvious I’m missing? Could someone point me to a clean example of this type of problem? Code is below.

using DifferentialEquations
using DiffEqFlux
using Lux
using Random
using Optimisers
using ComponentArrays
using OrdinaryDiffEq
using StaticArrays
using LinearAlgebra
using Zygote
using Plots
using Profile
using ProfileView

struct DynamicsCtx{M, P, S}
    model::M
    ps::P
    st::S
end

function plot_trajectory(sol, N::Int)
    state_size = 5
    plt = plot(title="Node Trajectories", xlabel="x", ylabel="y", aspect_ratio=1)

    for i in 1:N
        xi = [sol[u][state_size*(i-1)+1] for u in 1:length(sol)]
        yi = [sol[u][state_size*(i-1)+2] for u in 1:length(sol)]
        plot!(plt, xi, yi, label="Node $i")
    end

    # display(plt);
    return plt
end

# Connect every pair of nodes that sit one lattice spacing apart with a spring
function get_hex_neighbors(positions, spacing; tol=1e-3)
    springs = Tuple{Int, Int}[]
    n = length(positions)
    for i in 1:n-1, j in i+1:n
        dist = norm(positions[i] - positions[j])
        if abs(dist - spacing) < tol
            push!(springs, (i, j))
        end
    end
    return springs
end

function generate_hex_grid(radius::Int, spring_length::Float32)
    positions = SVector{2, Float32}[]  # concretely typed so everything stays Float32
    hx = spring_length * SVector(1.0f0, 0.0f0)
    hy = spring_length * SVector(0.5f0, sqrt(3.0f0) / 2)
    for q in -radius+1:radius-1
        rmin = max(-radius+1, -q - radius + 1)
        rmax = min(radius-1, -q + radius - 1)
        for r in rmin:rmax
            pos = q * hx + r * hy
            push!(positions, pos)
        end
    end
    return positions
end

# Build the NN: maps the full stacked state vector to one activity value per node
function make_model(N_nodes, N_state)
    x_in = N_nodes*N_state
    nθ = 2*x_in
    return Chain(Dense( x_in => nθ, relu), Dense(nθ => N_nodes, tanh))
end



# System parameters - fixed; these don't need to be optimized
struct MultiNodeParams{T}
    N::Int
    spring_list::Vector{Tuple{Int, Int}}
    l0::T
    L::T
    k::T
    bθ::T
    bγ::T
    Iγ::T
    r::T
    τ::T
    noise::T
    s::T
end

# Per-node dynamics; reads the fixed physical constants (l0, b_theta, tau, r,
# I_gamma, b_gamma) from the globals defined below
function rws_dynamics_multi(u, a_val, fspring, p)
    x, y, θ, γ, dγ = u
    cosγ = cos(γ)
    sinγ = sin(γ)

    l = @SVector [l0 * cosγ, l0 * sinγ, zero(eltype(u))]
    l̂ = l / l0

    dθ = (1 / b_theta) * (tau + r * dot(fspring, l̂))
    torque = cross(l, fspring)[3]
    ddγ = (1 / I_gamma) * (torque - b_gamma * dγ)

    dx = (r * dθ + a_val * cos(θ)) * cosγ
    dy = (r * dθ + a_val * sin(θ)) * sinγ

    return @SVector [dx, dy, dθ, dγ, ddγ]
end



# ODE right-hand side: the trainable NN parameters arrive through `ps`;
# the NN model and its state are read from the globals nn_model and nn_state
function multi_node_dynamics!(du, u, ps, t)
    T = eltype(u)

    positions = [@SVector [u[5i - 4], u[5i - 3], zero(T)] for i in 1:N_nodes]
    net_forces = [SVector{3,T}(0, 0, 0) for _ in 1:N_nodes]

    # Get current spring forces
    for (i, j) in springs
        pi, pj = positions[i], positions[j]
        d = pj - pi
        D = norm(d)
        if D > 1e-10  # guard against division by zero for coincident nodes
            f = k * (D - L) * (d / D)
            net_forces[i] += f
            net_forces[j] -= f
        end
    end

    # Forward pass through the NN to get one activity value per node
    a_raw, _ = nn_model(u, ps, nn_state)
    a_val = 2f0 * a_raw  # tanh output scaled to [-2, 2]

    for i in 1:N_nodes
        idx = 5*(i-1)
        u_i = @SVector [u[idx+1], u[idx+2], u[idx+3], u[idx+4], u[idx+5]]
        
        du_i = rws_dynamics_multi(u_i, a_val[i], net_forces[i], nothing)
        for j in 1:5
            du[idx + j] = du_i[j]
        end
    end

    return nothing
end


const L = 0.1f0        # [m], spring resting length
const k = 43.6906f0    # [N/m], spring constant
const r = 0.02f0       # [m], wheel radius
const tau = 0.01f0      # [Nm], wheel drive torque
const l0 = 0.01f0      # [m], caster offset
const d0 = 0.0f0       # [m], node radius
const M = 0.1426f0     # [kg], total mass of node
const m = 0.026f0      # [kg], caster mass
const s = 1.0f0        # [], slip factor
const b_theta = 0.01f0 # [], wheel damping factor
const I_theta = Float32(1e-6) # [kgm^2], wheel rotational inertia
const b_gamma = .2f0  # [], caster damping factor
const I_gamma = Float32(M*(l0^2)) # [kgm^2], caster rotational inertia
const b_phi = b_gamma # [], base damping factor
const I_phi = Float32(0.5*(M-m)*(d0^2)) # [kgm^2], base rotational inertia
const b_xy = 0.0035f0  # [Ns/m], base damping factor
const noise = 0.00f0
const a = 0.0f0 # activity scale

rng = Random.default_rng()
T = Float32
T_end = 10.0f0
n_hex = 2
positions = generate_hex_grid(n_hex, L)
springs = get_hex_neighbors(positions, L)
N_nodes = length(positions)
N_states = 5


# Setup parameters

# Create system parameters (currently unused; all fixed parameters are globals)
# system_params = MultiNodeParams(
#     N_nodes, springs,
#     l0, L, k, b_theta, b_gamma, I_gamma, r, tau, noise, s
# )

# Create neural network
const nn_model = make_model(N_nodes, N_states)
nn_params, nn_state = Lux.setup(rng, nn_model)
nn_params = ComponentArray(nn_params)


# Initial conditions
u0 = rand(Float32, N_states*N_nodes) .* 0.1f0  # small random initial conditions
tspan = (0.0f0, T_end)


function loss(p)
    # Debug output: check what the solver/AD actually receives as parameters
    println("Type of p: ", typeof(p))
    println("Fieldnames of p: ", fieldnames(typeof(p)))
    prob = ODEProblem(multi_node_dynamics!, u0, tspan, p)

    sol = solve(prob, Tsit5(), saveat=0.1f0, sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP()))

    target = Float32.(0.1f0 .* sin.(sol.t))
    pred = Array(sol)[1, :]

    return sum(abs2, pred .- target)
end


# Setup optimizer
opt = Optimisers.Adam(0.01f0)
opt_state = Optimisers.setup(opt, nn_params)

# Training loop
for epoch in 1:10
    global nn_params, opt_state  # both are reassigned inside the loop body
    println("Epoch $epoch")
    try
        println("pullback")
        l, back = Zygote.pullback(loss, nn_params)
        @show typeof(l)
        @show eltype(l)
        @show typeof(back)
        @show eltype(back)
        println("grad")
        grad = first(back(one(l)))  # seed the pullback with 1
        @show typeof(grad)
        @show eltype(grad)
        try
            println("update optimiser")
            opt_state, nn_params = Optimisers.update!(opt_state, nn_params, grad)

            println("Update successful")
        catch e
            println("Error in update step: ", e)
            @show typeof(opt), typeof(nn_params), typeof(grad)
            @show eltype(grad)
        end

        println("Epoch $epoch | Loss = $l")
        
    catch e
        println("Error in epoch $epoch: ", e)
        # Print stack trace for debugging
        @error "Training error" exception=(e, catch_backtrace())
    end
end

Did you try looking at the documentation examples for this? I highly recommend following docs like Automatically Discover Missing Physics by Embedding Machine Learning into Differential Equations · Overview of Julia's SciML.
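
A minimal sketch of that pattern, condensed (the model sizes and names here are illustrative, not your code): pass only the NN parameters, as a ComponentArray, through the solver, and capture the model and its state in the RHS. As far as I know, ZygoteVJP also wants a non-mutating RHS, hence the out-of-place form here.

using OrdinaryDiffEq, SciMLSensitivity, Lux, ComponentArrays, Zygote, Random

rng = Random.default_rng()
model = Chain(Dense(2 => 16, tanh), Dense(16 => 2))
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)

# Out-of-place RHS: `model` and `st` are captured by closure; only the
# trainable parameters `p` flow through the solver
rhs(u, p, t) = first(model(u, p, st))

u0 = Float32[1.0, 0.0]
prob = ODEProblem(rhs, u0, (0.0f0, 1.0f0), ps)

function loss(p)
    sol = solve(remake(prob; p = p), Tsit5(); saveat = 0.1f0,
                sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))
    return sum(abs2, Array(sol))
end

l, back = Zygote.pullback(loss, ps)
grad = first(back(one(l)))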

The answer is to just use a ComponentArray for the parameters, and you won’t get SciMLSensitivity.AdjointSensitivityParameterCompatibilityError(): that error means the parameters are something like a NamedTuple, which doesn’t have vector semantics. From the code you posted you shouldn’t hit that issue, since nn_params is already the ComponentArray you want. But I’d check whether opt_state, nn_params = Optimisers.update!(opt_state, nn_params, grad) keeps the right type, as in the snippet below.
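
For example, a quick sanity check right after the update, using the variable names from your loop:

grad = first(back(one(l)))
opt_state, nn_params = Optimisers.update!(opt_state, nn_params, grad)
# If this fires, the next solve will hit the same compatibility error
nn_params isa ComponentArray || @warn "parameters changed type after update!" typeof(nn_params)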

Can you print the actual error messages and the stack traces? You aren’t sharing the helpful parts.