Hello! I’ve been pulling my hair out trying to implement an ODE where one of the parameters is the output of a neural network (NN). ChatGPT has me going in circles.
Depending on which fixes I try, I get either:
Error in epoch 1: MethodError(+, (ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Fl…
or
SciMLSensitivity.AdjointSensitivityParameterCompatibilityError()
when calculating the gradients
I’ve tried multiple `sensealg` options for the ODE solver, and I am only passing the trainable parameters (the NN parameters). All fixed parameters are globals.
Is there something obvious I’m missing? Could someone point me to a clean example of this type of problem? The code is below.
using DifferentialEquations
using DiffEqFlux
using Lux
using Random
using Optimisers
using ComponentArrays
using OrdinaryDiffEq
using StaticArrays
using LinearAlgebra
using Zygote
using Plots
using Profile
using ProfileView
"""
    DynamicsCtx(model, ps, st)

Immutable bundle of three values stored in the fields `model`, `ps`, and `st`.

Every field has its own type parameter, so each instance is concretely typed.
The field names suggest a network, its parameters, and its state, but nothing
in this file constructs a `DynamicsCtx` — confirm the intended use.
"""
struct DynamicsCtx{TModel, TParams, TState}
    model::TModel
    ps::TParams
    st::TState
end
"""
    plot_trajectory(sol, N::Int)

Plot the (x, y) path of each of the `N` nodes stored in `sol` and return the plot.

`sol` is read as `sol[step][component]` for `step in 1:length(sol)`. Node `i`
takes its x value from component `5(i - 1) + 1` and its y value from component
`5(i - 1) + 2`. The plot object is returned without being displayed.
"""
function plot_trajectory(sol, N::Int)
    stride = 5
    fig = plot(; title="Node Trajectories", xlabel="x", ylabel="y", aspect_ratio=1)
    for node in 1:N
        base = stride * (node - 1)
        steps = 1:length(sol)
        xs = map(step -> sol[step][base + 1], steps)
        ys = map(step -> sol[step][base + 2], steps)
        plot!(fig, xs, ys; label="Node $node")
    end
    return fig
end
"""
    get_hex_neighbors(positions, spacing; tol=1e-3)

Return all index pairs `(i, j)` with `i < j` whose points are `spacing` apart.

Two points are paired when the absolute difference between their Euclidean
distance and `spacing` is smaller than `tol`. The result is a
`Vector{Tuple{Int, Int}}` ordered by `i` and then by `j`.
"""
function get_hex_neighbors(positions, spacing; tol=1e-3)
    n = length(positions)
    near(a, b) = abs(norm(positions[a] - positions[b]) - spacing) < tol
    return Tuple{Int, Int}[(a, b) for a in 1:(n - 1) for b in (a + 1):n if near(a, b)]
end
"""
    generate_hex_grid(radius::Integer, spring_length::Real)

Return the node positions of a hexagonal lattice patch as a vector of 2-D `SVector`s.

Positions are `q * hx + r * hy`, where `hx = spring_length * (1, 0)` and
`hy = spring_length * (1/2, √3/2)`. `q` runs over `-radius+1:radius-1` and `r`
over the range that keeps the point inside the hexagon. A `radius` of 1 yields
a single node at the origin; a `radius` of 0 or less yields an empty vector.

The returned vector has a concrete element type (the type of `hx`), so code
that indexes it stays type-stable.
"""
function generate_hex_grid(radius::Integer, spring_length::Real)
    # Lattice basis vectors. The Float64 literals decide the promoted element type.
    hx = spring_length * SVector(1.0, 0.0)
    hy = spring_length * SVector(0.5, sqrt(3) / 2)

    # Typed container instead of `[]`, which would be a `Vector{Any}`.
    positions = typeof(hx)[]
    for q in (-radius + 1):(radius - 1)
        rmin = max(-radius + 1, -q - radius + 1)
        rmax = min(radius - 1, -q + radius - 1)
        for r in rmin:rmax
            push!(positions, q * hx + r * hy)
        end
    end
    return positions
end
"""
    make_model(N_nodes, N_state)

Build a two-layer `Chain` that takes the full stacked state as input.

The first `Dense` layer maps `N_nodes * N_state` inputs to twice that many
hidden units with `relu`; the second maps the hidden units to `N_nodes`
outputs with `tanh`.
"""
function make_model(N_nodes, N_state)
    n_inputs = N_nodes * N_state
    n_hidden = 2 * n_inputs
    hidden_layer = Dense(n_inputs => n_hidden, relu)
    output_layer = Dense(n_hidden => N_nodes, tanh)
    return Chain(hidden_layer, output_layer)
end
"""
    MultiNodeParams{T}

Fixed system parameters that are not meant to be optimized.

`N` is an `Int` and `spring_list` a vector of `(i, j)` index pairs; the
remaining ten fields share the numeric type `T`. The field names mirror the
module-level constants of similar name (`l0`, `L`, `k`, `b_theta`, `b_gamma`,
`I_gamma`, `r`, `tau`, `noise`, `s`) — presumably they carry the same meaning;
confirm before relying on that.
"""
struct MultiNodeParams{R}
    N::Int
    spring_list::Vector{NTuple{2, Int}}
    l0::R
    L::R
    k::R
    bθ::R
    bγ::R
    Iγ::R
    r::R
    τ::R
    noise::R
    s::R
end
"""
    rws_dynamics_multi(u, a_val, fspring, p)

Return the time derivative `[dx, dy, dθ, dγ, ddγ]` of one node as an `SVector`.

# Arguments
- `u`: five-component node state `(x, y, θ, γ, dγ)`; only `θ`, `γ`, `dγ` are read.
- `a_val`: scalar activity value multiplying `cos(θ)` / `sin(θ)` in `dx` / `dy`.
- `fspring`: three-component net spring force acting on the node.
- `p`: unused.

The function also reads the module-level constants `l0`, `r`, `tau`, `b_theta`,
`b_gamma`, and `I_gamma`.
"""
function rws_dynamics_multi(u, a_val, fspring, p)
    _, _, θ, γ, γ_rate = u
    cγ = cos(γ)
    sγ = sin(γ)

    # Caster arm vector of length `l0` in the plane, and its unit direction.
    arm = SVector(l0 * cγ, l0 * sγ, zero(eltype(u)))
    arm_dir = arm / l0

    # Wheel rate driven by `tau` plus the spring force projected on the arm.
    θ_rate = (1 / b_theta) * (tau + r * dot(fspring, arm_dir))

    # Out-of-plane torque of the spring force about the arm, damped by `b_gamma`.
    torque_z = cross(arm, fspring)[3]
    γ_accel = (1 / I_gamma) * (torque_z - b_gamma * γ_rate)

    x_rate = (r * θ_rate + a_val * cos(θ)) * cγ
    y_rate = (r * θ_rate + a_val * sin(θ)) * sγ
    return SVector(x_rate, y_rate, θ_rate, γ_rate, γ_accel)
end
# Net spring force on node `node`: sum the contribution of every spring in the
# module-level `springs` list that touches it. The total is built by rebinding a
# local, so no array is mutated.
function _net_spring_force(node, positions, ::Type{T}) where {T}
    total = zero(SVector{3, T})
    for (i, j) in springs
        (node == i || node == j) || continue
        d = positions[j] - positions[i]
        D = norm(d)
        # Guard the division by `D`: coincident nodes exert no defined force.
        D > 1e-10 || continue
        f = k * (D - L) * (d / D)
        # The spring pulls node `i` along `+f` and node `j` along `-f`.
        total += node == i ? f : -f
    end
    return total
end

"""
    multi_node_dynamics!(du, u, ps, t)

In-place right-hand side of the multi-node ODE; writes the derivative into `du`.

# Arguments
- `du`: output vector, same layout as `u`; every entry of each node is overwritten.
- `u`: stacked state, five consecutive entries `(x, y, θ, γ, dγ)` per node.
- `ps`: parameters of the neural network `nn_model` (the trainable quantity).
- `t`: time (unused).

The function reads the module-level `N_nodes`, `springs`, `k`, `L`, `nn_model`,
and `nn_state`. The network receives the full state `u` and returns one value
per node, which is scaled by 2 before it is passed to `rws_dynamics_multi`.

`du` is the only array written to. Spring forces are accumulated without
mutating intermediate arrays, because Zygote-based reverse-mode AD cannot
differentiate through mutation of ordinary arrays.
NOTE(review): whether the writes to `du` are acceptable depends on how the
chosen sensitivity algorithm wraps in-place functions — confirm against the
SciMLSensitivity documentation.
"""
function multi_node_dynamics!(du, u, ps, t)
    state_size = 5
    T = eltype(u)

    # Planar node positions lifted to 3-D so `cross` can be used downstream.
    positions = [
        @SVector [u[state_size * (i - 1) + 1], u[state_size * (i - 1) + 2], zero(T)]
        for i in 1:N_nodes
    ]
    net_forces = [_net_spring_force(i, positions, T) for i in 1:N_nodes]

    # Forward pass through the network: one activity value per node.
    a_raw, _ = nn_model(u, ps, nn_state)
    a_val = 2f0 * a_raw  # the final `tanh` layer bounds the scaled output to [-2, 2]

    for i in 1:N_nodes
        idx = state_size * (i - 1)
        u_i = @SVector [u[idx + 1], u[idx + 2], u[idx + 3], u[idx + 4], u[idx + 5]]
        du_i = rws_dynamics_multi(u_i, a_val[i], net_forces[i], nothing)
        for j in 1:state_size
            du[idx + j] = du_i[j]
        end
    end
    return nothing
end
# Physical constants. All values are Float32.
const L = 0.1f0 # [m], spring resting length
const k = 43.6906f0 # [N/m], spring constant
const r = 0.02f0 # [m], wheel radius
const tau = 0.01f0 # [Nm], wheel drive torque
const l0 = 0.01f0 # [m], caster offset
const d0 = 0.0f0 # [m], node radius
const M = 0.1426f0 # [kg], total mass of node
const m = 0.026f0 # [kg], caster mass
const s = 1.f0 # [], slip factor
const b_theta = 0.01f0 # [], wheel damping factor
const I_theta = Float32(1e-6) # [kgm^2], wheel rotational inertia
const b_gamma = .2f0 # [], caster damping factor
const I_gamma = Float32(M*(l0^2)) # [kgm^2], caster rotational inertia
const b_phi = b_gamma # [], base damping factor
const I_phi = Float32(0.5*(M-m)*(d0^2)) # [kgm^2], base rotational inertia
const b_xy = 0.0035f0 # [Ns/m], base damping factor
const noise = 0.00f0
const a = 0.0f0 # activity scale

rng = Random.default_rng()
T = Base.Float32
T_end = 10.0f0
n_hex = 2

# Lattice geometry. `springs`, `N_nodes`, and `N_states` are read by the ODE
# right-hand side on every call, so they are `const` to keep that code
# type-stable.
positions = generate_hex_grid(n_hex, L)
const springs = get_hex_neighbors(positions, L)
const N_nodes = length(positions)
const N_states = 5

# The dynamics functions read the constants above directly; `MultiNodeParams`
# is not instantiated here.

# Create neural network. The model and its state are fixed during training and
# therefore `const`; the parameters are rebound by the training loop and are not.
const nn_model = make_model(N_nodes, N_states)
const nn_setup = Lux.setup(rng, nn_model)
const nn_state = last(nn_setup)
nn_params = ComponentArray(first(nn_setup))

# Initial conditions
u0 = rand(rng, Float32, N_states * N_nodes) .* 0.1f0 # small random initial conditions
tspan = (0.0f0, T_end)
"""
    loss(p)

Solve the multi-node ODE with network parameters `p` and return the summed
squared error between the first state component and `0.1 * sin(t)`.

The problem is built from the module-level `multi_node_dynamics!`, `u0`, and
`tspan`, solved with `Tsit5()` and saved every `0.1` time units. The prediction
is row 1 of `Array(sol)`; the target is evaluated at the solution's save times.
"""
function loss(p)
    prob = ODEProblem(multi_node_dynamics!, u0, tspan, p)
    sol = solve(
        prob,
        Tsit5();
        saveat=0.1f0,
        sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
    )

    # The save times carry no gradient information. Reading the field under
    # `Zygote.ignore` keeps reverse-mode AD from building a second tangent for
    # `sol` that would then have to be added to the one coming from `Array(sol)`.
    # NOTE(review): this is the suspected source of the
    # `MethodError(+, (ODESolution, ...))` — confirm by running.
    save_times = Zygote.ignore() do
        sol.t
    end

    target = Float32.(0.1f0 .* sin.(save_times))
    pred = Array(sol)[1, :]
    return sum(abs2, pred .- target)
end
# Setup optimizer (`Adam` is the current spelling of the deprecated `ADAM`).
opt = Optimisers.Adam(0.01f0)
opt_state = Optimisers.setup(opt, nn_params)

"""
    run_training(opt_state, params; epochs=10)

Run `epochs` gradient steps on the module-level `loss` and return the updated
`(opt_state, params)`.

Each epoch takes a Zygote pullback of `loss` at `params`, seeds it with
`one(l)`, and applies `Optimisers.update!`. An error in an epoch is logged with
its backtrace and the loop continues with the next epoch. An error in the
update step alone is logged separately and the epoch's loss is still printed.

The loop lives in a function so that `opt_state` and `params` are ordinary
locals. At top level of a script, assigning them inside a `for` body would
create new loop-local variables instead of updating the globals.
"""
function run_training(opt_state, params; epochs=10)
    for epoch in 1:epochs
        println("Epoch $epoch")
        try
            l, back = Zygote.pullback(loss, params)
            grad = first(back(one(l)))
            @debug "gradient computed" typeof(l) typeof(grad) eltype(grad)
            try
                opt_state, params = Optimisers.update!(opt_state, params, grad)
                println("Update successful")
            catch e
                println("Error in update step: ", e)
                @show typeof(opt_state), typeof(params), typeof(grad)
                @show eltype(grad)
            end
            println("Epoch $epoch | Loss = $l")
        catch e
            println("Error in epoch $epoch: ", e)
            # Log the stack trace for debugging.
            @error "Training error" exception=(e, catch_backtrace())
        end
    end
    return opt_state, params
end

# Training loop
opt_state, nn_params = run_training(opt_state, nn_params; epochs=10)