AD of UDEs - how to set it up?

Hi! I'm new to Julia and SciML, but I am really curious to see how combining ODEs with neural networks could improve the predictive performance of the model we are developing for solvent-based CO2 capture. My current setup includes:

  • dynamic states,
  • multiple independent time series of real data,
  • exogenous inputs,
  • unknown dynamics modelled as a neural network (NN).

I use DataInterpolations.jl to handle the inputs. I want to train the NN by minimising a loss function (MSE across all the independent time series).

I have made an MWE (tried to keep it as minimal as possible) with 5 series of data, each with ~200 data points.

My initial challenge is the following: computing the gradient of the loss function with AD (using Zygote.jl or ForwardDiff.jl) takes a long time. It seems like I am missing something or doing something wrong.

using BenchmarkTools

@btime Zygote.gradient(θ -> loss_multi(θ, prob_list, dataset, solver, sensealg1), _θ0)

@btime ForwardDiff.gradient(θ -> loss_multi(θ, prob_list, dataset, solver, sensealg1), _θ0)

outputs:

7.955 s (64378604 allocations: 3.56 GiB)
173.575 ms (5457036 allocations: 734.76 MiB)

which seems like a lot for just computing the gradient..?! Below is the MWE - maybe someone can spot something that is clearly set up in a stupid/inefficient way.

The true system (used for data generation) is:
dx₁/dt = -α·(x₁ - u₁) + sin(γ·x₂)·u₂
dx₂/dt = -β·(x₂ - u₃)

and the system with the neural network is:

dx₁/dt = -α·(x₁ - u₁) + NN(x₂, u₂)
dx₂/dt = -β·(x₂ - u₃)

i.e. a system which is partially known and partially unknown; x are the states and u are the exogenous inputs.

Below is the code:

Package imports and data generation:

# import packages
import OrdinaryDiffEq as ODE
import SciMLSensitivity as SMS
import ComponentArrays
import DataInterpolations as DI
using Zygote
using ForwardDiff

# Standard Libraries
import Statistics

# External Libraries
import Lux
import StableRNGs
import StaticArrays as SA
using Plots

# --- Type definitions ---
struct Series
    t::Vector{Float64}                     # length T
    Y::Matrix{Float64}                     # (n_y, T)
    U_funcs::Vector{DI.ConstantInterpolation}
    X0::NTuple{2,Float64}                  # initial state (x1, x2)
    tspan::Tuple{Float64,Float64}
end

struct Dataset
    series::Vector{Series}
end

## Generate simulated data ##

# Number of data series 
N_series = 5

# Time vector for each series (different lengths)
tf = rand(5.0:10.0, N_series) # final time between 5 and 10 h
#dt = 30/3600 # time step of 30 sec
dt = 2/60
t = [0.0:dt:tf[i] for i in 1:N_series] # time vector for each series

# Input functions for each series
# input values:
u1 = [[1.0, 0.5, 2.5], [0.8, 0.7, 2.0], [1.2, 0.4, 3.0], [1.1, 0.6, 2.8], [0.9, 0.55, 2.2]]
u2 = [[2.0, 1.5, 3.5], [1.8, 1.7, 3.0], [2.2, 1.4, 4.0], [2.1, 1.6, 3.8], [1.9, 1.55, 3.2]]
u3 = [[3.0, 2.5, 4.5], [2.8, 2.7, 4.0], [3.2, 2.4, 5.0], [3.1, 2.6, 4.8], [2.9, 2.55, 4.2]]
# Time points of steps:
tu1 = [[0.0, tf[i]/2, 4tf[i]/5] for i in 1:N_series]
tu2 = [[0.0, tf[i]/3, 2tf[i]/3] for i in 1:N_series]
tu3 = [[0.0, tf[i]/4, 3tf[i]/4] for i in 1:N_series]

# Create DataInterpolations.jl ConstantInterpolation functions for each series
U_funcs = Vector{DI.ConstantInterpolation}[]
for i in 1:N_series
    push!(U_funcs, [
        DI.ConstantInterpolation(u1[i], tu1[i]; extrapolation = DI.ExtrapolationType.Constant),
        DI.ConstantInterpolation(u2[i], tu2[i]; extrapolation = DI.ExtrapolationType.Constant),
        DI.ConstantInterpolation(u3[i], tu3[i]; extrapolation = DI.ExtrapolationType.Constant)
    ])
end

# Create ODE problem and solve to generate data
p_true = [0.5, 0.3, 3.0] # true parameters

# True system equation
function ffun(x, p, t, u)
    alpha, beta, gamma = p

    dx1 = -alpha * (x[1] - u[1](t)) + sin(x[2]*gamma) * u[2](t)
    dx2 = -beta * (x[2] - u[3](t))

    return [dx1, dx2]
end

# Initial conditions for each series
x0_series = [[2.0, 1.0], [1.5, 1.5], [2.5, 0.5], [2.2, 1.2], [1.8, 0.8]]

# Generate data for each series
X = Matrix{Float64}[]
for i in 1:N_series
    prob = ODE.ODEProblem((x, p, t) -> ffun(x, p, t, [U_funcs[i][1], U_funcs[i][2], U_funcs[i][3]]),
        x0_series[i], (0.0, tf[i]), p_true)
    sol = ODE.solve(prob, ODE.Tsit5(), saveat = t[i])
    push!(X, Array(sol))
end

# Add noise to the data in each series
rng = StableRNGs.StableRNG(1234)
Y = Matrix{Float64}[]
for i in 1:N_series
    ȳ = Statistics.mean(X[i], dims = 2)
    noise_magnitude = 1e-2
    Y_noise = X[i] .+ (noise_magnitude * ȳ) .* randn(rng, eltype(X[i]), size(X[i]))
    push!(Y, Y_noise)
end

# Build Dataset
dataset = Dataset(Series[])
for i in 1:N_series
    s = Series(t[i], Y[i], U_funcs[i], (x0_series[i][1], x0_series[i][2]), (0.0, tf[i]))
    push!(dataset.series, s)
end 

Define the model and loss function, and compute the gradient:

# --- Model Definition ---
# Neural Network
neural_net = Lux.Chain(
    Lux.Dense(2, 10, x -> exp.(-(x.^2))),
    Lux.Dense(10, 10, x -> exp.(-(x.^2))),
    Lux.Dense(10, 1) # linear head
)

rng = StableRNGs.StableRNG(1234)
p0, st = Lux.setup(rng, neural_net) 
const _st = st


# in-place ude_dynamics!
function ude_dynamics!(dx, x, p, t, U_funcs)

    # NN part
    input_NN = [x[2], U_funcs[2](t)]
    NN, _ = neural_net(input_NN, p, _st)  # Lux forward pass returns (output, state)

    dx[1] = -0.5 * (x[1] - U_funcs[1](t)) + NN[1]
    dx[2] = -0.3 * (x[2] - U_funcs[3](t))
    return nothing
end

# out-of-place ude_dynamics
function ude_dynamics(x, p, t, U_funcs)

    # NN part
    input_NN = [x[2], U_funcs[2](t)]
    NN, _ = neural_net(input_NN, p, _st)  # Lux forward pass

    dx1 = -0.5 * (x[1] - U_funcs[1](t)) + NN[1]
    dx2 = -0.3 * (x[2] - U_funcs[3](t))
    return [dx1, dx2]
end


# Build ODE problems
function make_problems(ds::Dataset, θ0)
    prob_list = ODE.ODEProblem[]
    for s in ds.series

        # in-place ude
        # f!(du, u, θ, t) = ude_dynamics!(du, u, θ, t, s.U_funcs, ctx)
        # push!(prob_list, ODE.ODEProblem(f!, collect(s.X0), s.tspan, θ0))

        # out-of-place ude
        f(u, θ, t) = ude_dynamics(u, θ, t, s.U_funcs)
        push!(prob_list, ODE.ODEProblem(f, collect(s.X0), s.tspan, θ0))
    end
    return prob_list
end

# Define loss function over all series
function loss_multi(θ, prob_list, dataset, solver, sensealg1)
    loss = 0.0
    for i in eachindex(prob_list)
        _prob = ODE.remake(prob_list[i], p = θ)
        sol = ODE.solve(_prob, solver;
            saveat = dataset.series[i].t,
            save_everystep = false, dense = false,
            abstol = 1e-6, reltol = 1e-6,
            sensealg = sensealg1)
        X̂ = Array(sol) # (2, T)
        loss += Statistics.mean(abs2, dataset.series[i].Y .- X̂)
    end
    return loss
end

# Create ODE problems for all series
θ0 = p0  # initial NN parameters from Lux.setup
prob_list = make_problems(dataset, θ0)

# Solver and sensitivity algorithm
solver = ODE.Tsit5()
sensealg1 = SMS.QuadratureAdjoint(autojacvec = SMS.ZygoteVJP())
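# Aside (not part of the original MWE): for a small 2-state system it can be worth
# benchmarking other SciMLSensitivity choices as well, e.g. a Gauss-quadrature adjoint
# or plain forward-mode sensitivities; these are suggestions to try, not a known fix.
sensealg2 = SMS.GaussAdjoint(autojacvec = SMS.ZygoteVJP())
sensealg3 = SMS.ForwardDiffSensitivity()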

_θ0 = ComponentArrays.ComponentVector{Float64}(θ0)

using BenchmarkTools

@btime Zygote.gradient(θ -> loss_multi(θ, prob_list, dataset, solver, sensealg1), _θ0)
@btime ForwardDiff.gradient(θ -> loss_multi(θ, prob_list, dataset, solver, sensealg1), _θ0)
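
Once the gradient call itself is fast enough, the plan is to feed this loss into an optimiser. Below is a minimal training sketch in the style of the SciML UDE tutorials; it assumes Optimization.jl and OptimizationOptimisers.jl are added to the environment (they are not used in the MWE above), and the Adam step size / iteration count are placeholder values:

import Optimization
import OptimizationOptimisers

# Wrap the multi-series loss so Optimization.jl differentiates it with Zygote
optf = Optimization.OptimizationFunction(
    (θ, _) -> loss_multi(θ, prob_list, dataset, solver, sensealg1),
    Optimization.AutoZygote())
optprob = Optimization.OptimizationProblem(optf, _θ0)

# First-order training with Adam (placeholder hyperparameters)
res = Optimization.solve(optprob, OptimizationOptimisers.Adam(1e-3); maxiters = 200)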

This is not an answer to your question, but out of curiosity I asked Claude to adapt an existing filtering tutorial to fit your system, and this is what it came up with.

In classic Claude style, it’s a bit verbose in how it generates data etc., but it did at least produce something that works with minimal intervention 🙂


# Learning Unknown Dynamics with Neural Networks and Optimization
#
# This example demonstrates how to combine:
# - Known first-principles dynamics
# - Neural networks for unknown coupling terms
# - Extended Kalman Filter for state estimation
# - Optimization to learn neural network parameters
#
# Problem from: https://discourse.julialang.org/t/ad-of-udes-how-to-set-it-up/133461

using LowLevelParticleFilters
using Random, SeeToDee, StaticArrays, Plots, LinearAlgebra, Statistics
using LowLevelParticleFilters: SimpleMvNormal
using Lux, ComponentArrays, Optim

Random.seed!(42)

# =============================================================================
## DATA GENERATION - True System
# =============================================================================

# True system parameters
const α = 0.5f0
const β = 0.3f0
const γ = 3.0f0

# True coupling function (unknown in practice)
# This is what the neural network needs to learn: sin(γ·x₂)·u₂
true_coupling(x2, u2) = sin(γ * x2) * u2

# True system dynamics
# dx₁/dt = -α(x₁ - u₁) + sin(γ·x₂)·u₂
# dx₂/dt = -β(x₂ - u₃)
function true_dynamics(state, u, p, t)
    x1, x2 = state
    u1, u2, u3 = u[1], u[2], u[3]

    coupling = true_coupling(x2, u2)

    dx1 = -α * (x1 - u1) + coupling
    dx2 = -β * (x2 - u3)

    SA[dx1, dx2]
end

# Discretize true system
const Ts = Float32(2/60)
discrete_true_dynamics = SeeToDee.Rk4(true_dynamics, Ts)

# Generate multiple time series with different initial conditions
# Following the forum post setup exactly
function generate_data(; n_series=5)
    rng = Random.default_rng()
    Random.seed!(rng, 123)

    # Input step values for each series (3 steps per input, per series)
    u1_vals = [[1.0f0, 0.5f0, 2.5f0], [0.8f0, 0.7f0, 2.0f0], [1.2f0, 0.4f0, 3.0f0],
                [1.1f0, 0.6f0, 2.8f0], [0.9f0, 0.55f0, 2.2f0]]
    u2_vals = [[2.0f0, 1.5f0, 3.5f0], [1.8f0, 1.7f0, 3.0f0], [2.2f0, 1.4f0, 4.0f0],
                [2.1f0, 1.6f0, 3.8f0], [1.9f0, 1.55f0, 3.2f0]]
    u3_vals = [[3.0f0, 2.5f0, 4.5f0], [2.8f0, 2.7f0, 4.0f0], [3.2f0, 2.4f0, 5.0f0],
                [3.1f0, 2.6f0, 4.8f0], [2.9f0, 2.55f0, 4.2f0]]

    # Final times for each series
    tf = rand(5.0:10.0, n_series)

    all_data = []

    for series_idx in 1:n_series
        # Different initial conditions for each series
        x0 = SA[randn(rng) * 0.5f0, randn(rng) * 0.5f0]

        # Time points where inputs change (following forum post pattern)
        tu1 = [0.0f0, tf[series_idx]/2, 4*tf[series_idx]/5]
        tu2 = [0.0f0, tf[series_idx]/3, 2*tf[series_idx]/3]
        tu3 = [0.0f0, tf[series_idx]/4, 3*tf[series_idx]/4]

        # Number of time steps for this series
        n_steps = Int(ceil(tf[series_idx] / Ts))

        # Generate step input vectors
        u = Vector{SVector{3, Float32}}(undef, n_steps)
        for i in 1:n_steps
            t = (i-1) * Ts

            # Determine which step we're in for each input (piecewise constant)
            u1 = t < tu1[2] ? u1_vals[series_idx][1] : (t < tu1[3] ? u1_vals[series_idx][2] : u1_vals[series_idx][3])
            u2 = t < tu2[2] ? u2_vals[series_idx][1] : (t < tu2[3] ? u2_vals[series_idx][2] : u2_vals[series_idx][3])
            u3 = t < tu3[2] ? u3_vals[series_idx][1] : (t < tu3[3] ? u3_vals[series_idx][2] : u3_vals[series_idx][3])

            u[i] = SA[u1, u2, u3]
        end

        # Simulate true system
        x = Vector{SVector{2, Float32}}(undef, n_steps)
        x[1] = x0
        for i in 2:n_steps
            x[i] = discrete_true_dynamics(x[i-1], u[i-1], nothing, (i-2)*Ts)
        end

        # Add measurement noise
        R_meas = 0.05f0
        y = [x_i + R_meas * randn(rng, SVector{2, Float32}) for x_i in x]

        push!(all_data, (x=x, u=u, y=y, x0=x0, tf=tf[series_idx]))
    end

    return all_data
end

println("Generating data...")
data_series = generate_data(n_series=5)

# Visualize first series
let data = data_series[1]
    p1 = plot([x[1] for x in data.x], label="x₁", lw=2)
    plot!([x[2] for x in data.x], label="x₂", lw=2)
    title!("True States (Series 1)")
    ylabel!("State")

    p2 = plot([u[1] for u in data.u], label="u₁", lw=2, seriestype=:steppost)
    plot!([u[2] for u in data.u], label="u₂", lw=2, seriestype=:steppost)
    plot!([u[3] for u in data.u], label="u₃", lw=2, seriestype=:steppost)
    ylabel!("Input")
    xlabel!("Time step")
    title!("Inputs")

    plot(p1, p2, layout=(2,1), size=(800, 600)) |> display
end

# =============================================================================
## NEURAL NETWORK SETUP - Learn Unknown Coupling
# =============================================================================

# Neural network to learn coupling function sin(γ·x₂)·u₂
# Inputs: [x₂, u₂]
# Output: coupling term (scalar)

const nn_input_dim = 2   # [x₂, u₂]
const nn_output_dim = 1  # coupling scalar
const nn_hidden = 8

coupling_network = Chain(
    Dense(nn_input_dim, nn_hidden, tanh),
    Dense(nn_hidden, nn_hidden, tanh),
    Dense(nn_hidden, nn_output_dim)
)

# Initialize network parameters
rng = Random.default_rng()
Random.seed!(rng, 456)
nn_ps, nn_st = Lux.setup(rng, coupling_network)
nn_params = ComponentArray(nn_ps)

println("Neural network parameters: ", length(nn_params))

# =============================================================================
## HYBRID DYNAMICS - Known Physics + Neural Network
# =============================================================================

# Hybrid dynamics: known parameters + NN for unknown coupling
# UDE form:
# dx₁/dt = -α(x₁ - u₁) + NN(x₂, u₂)
# dx₂/dt = -β(x₂ - u₃)
function hybrid_dynamics_continuous(state, u, params, t)
    x1, x2 = state
    u1, u2, u3 = u[1], u[2], u[3]

    # Neural network predicts coupling term: sin(γ·x₂)·u₂
    nn_input = SA[x2, u2]
    coupling_pred, _ = Lux.apply(coupling_network, nn_input, params, nn_st)
    coupling = coupling_pred[1]

    # Known physics + learned coupling
    dx1 = -α * (x1 - u1) + coupling
    dx2 = -β * (x2 - u3)

    SA[dx1, dx2]
end

# Discretize hybrid dynamics (here with Heun's method)
discrete_hybrid_dynamics = SeeToDee.Heun(hybrid_dynamics_continuous, Ts)

# =============================================================================
## EXTENDED KALMAN FILTER SETUP
# =============================================================================

# System dimensions
const nx = 2  # State dimension [x₁, x₂]
const nu = 3  # Input dimension [u₁, u₂, u₃]
const ny = 2  # Output dimension (measure both states)

# Process and measurement noise covariances
R1 = SMatrix{nx, nx}(Diagonal(Float32[0.001, 0.001]))  # Process noise
R2 = SMatrix{ny, ny}(Diagonal(Float32[0.05^2, 0.05^2]))  # Measurement noise

# Linear measurement model (observe full state)
C = SA[1.0f0 0.0f0; 0.0f0 1.0f0]
measurement_model = LinearMeasurementModel(C, 0, R2; ny)

# =============================================================================
## OPTIMIZATION FRAMEWORK
# =============================================================================

# Cost function: sum of SSE across all time series
function cost_function(θ)
    T = eltype(θ)
    total_sse = zero(T)

    # Convert parameters to correct type
    θ_comp = ComponentArray(θ, getaxes(nn_params))

    # Evaluate cost on each time series
    try
        for data in data_series
            # Create EKF with current parameters
            x0_est = T.(data.x0)
            P0 = T(10.0) * T.(R1)
        
            kf = ExtendedKalmanFilter(
                discrete_hybrid_dynamics,
                measurement_model,
                R1,
                SimpleMvNormal(x0_est, P0);
                p = θ_comp,
                ny, nu, Ts
            )

            # Compute SSE for this series
            sse = LowLevelParticleFilters.sse(kf, data.u, data.y, θ_comp)
            total_sse += sse
        end
    catch
        return T(Inf)
    end

    return total_sse
end

# Initial cost evaluation
println("\nInitial cost: ", cost_function(nn_params))

# =============================================================================
## OPTIMIZATION
# =============================================================================
using Optim.LineSearches
println("\nStarting optimization...")

# Optimization options
opt_options = Optim.Options(
    show_trace = true,
    store_trace = true,
    iterations = 150,
    f_reltol = 1e-4,
)

# Run optimization
@time result = Optim.optimize(
    cost_function,
    nn_params,
    BFGS(alphaguess = LineSearches.InitialStatic(alpha=0.5), linesearch = LineSearches.HagerZhang()),
    opt_options;
    autodiff = :forward
)


# Extract optimized parameters
params_opt = ComponentArray(result.minimizer, getaxes(nn_params))

println("\n=== Optimization Results ===")
println("Converged: ", Optim.converged(result))
println("Iterations: ", Optim.iterations(result))
println("Initial cost: ", cost_function(nn_params))
println("Final cost: ", Optim.minimum(result))
println("Cost reduction: ", round((1 - Optim.minimum(result)/cost_function(nn_params))*100, digits=2), "%")

# =============================================================================
## RESULTS ANALYSIS
# =============================================================================

# Run a filter (here an Unscented Kalman Filter) with the optimized parameters on the first series
data_test = data_series[1]
kf_final = UnscentedKalmanFilter(
    discrete_hybrid_dynamics,
    measurement_model,
    R1,
    SimpleMvNormal(data_test.x0, R1);
    p = params_opt,
    ny=ny, nu=nu, Ts=Ts
)

sol = forward_trajectory(kf_final, data_test.u, data_test.y)

# Extract states
x_est = [sol.xt[i][1] for i in 1:length(sol.xt)]
y_est = [sol.xt[i][2] for i in 1:length(sol.xt)]
x_true = [data_test.x[i][1] for i in 1:length(data_test.x)]
y_true = [data_test.x[i][2] for i in 1:length(data_test.x)]

# Plot state estimation
p1 = plot(x_true, label="True x", lw=2, color=:blue)
plot!(x_est, label="Estimated x", lw=2, ls=:dash, color=:red)
ylabel!("State x")
title!("State Estimation - Series 1")

p2 = plot(y_true, label="True y", lw=2, color=:blue)
plot!(y_est, label="Estimated y", lw=2, ls=:dash, color=:red)
ylabel!("State y")
xlabel!("Time step")

plot(p1, p2, layout=(2,1), size=(900, 600)) |> display

# =============================================================================
## LEARNED COUPLING FUNCTION COMPARISON
# =============================================================================

# Compute learned coupling
function learned_coupling(x2, u2, params)
    nn_input = SA[x2, u2]
    output, _ = Lux.apply(coupling_network, nn_input, params, nn_st)
    return output[1]
end

# Plot 1: Coupling as function of x₂ for different u₂ values
x2_test = LinRange(-1.0f0, 3.0f0, 100)
u2_values = [1.4f0, 2.2f0, 3.5f0]

p1 = plot(title="Learned vs True Coupling: f(x₂, u₂)", xlabel="x₂", ylabel="Coupling",
          legend=:outertopright, size=(900, 500))

for (i, u2_val) in enumerate(u2_values)
    # True coupling
    coupling_true_1d = [true_coupling(x2, u2_val) for x2 in x2_test]
    plot!(x2_test, coupling_true_1d, label="True (u₂=$u2_val)", lw=2, ls=:solid, c=i)

    # Learned coupling
    coupling_learned_1d = [learned_coupling(x2, u2_val, params_opt) for x2 in x2_test]
    plot!(x2_test, coupling_learned_1d, label="Learned (u₂=$u2_val)", lw=2, ls=:dash, alpha=0.7, c=i)
end

# Plot 2: 2D heatmap comparison
x2_grid = LinRange(-1.0f0, 3.0f0, 50)
u2_grid = LinRange(1.4f0, 3.8f0, 50)

coupling_true_2d = [true_coupling(x2, u2) for u2 in u2_grid, x2 in x2_grid]
coupling_learned_2d = [learned_coupling(x2, u2, params_opt) for u2 in u2_grid, x2 in x2_grid]

p2 = heatmap(x2_grid, u2_grid, coupling_true_2d,
             title="True: sin(3·x₂)·u₂", xlabel="x₂", ylabel="u₂",
             c=:viridis, clims=(-2, 2))

p3 = heatmap(x2_grid, u2_grid, coupling_learned_2d,
             title="Learned: NN(x₂, u₂)", xlabel="x₂", ylabel="u₂",
             c=:viridis, clims=(-2, 2))

p4 = heatmap(x2_grid, u2_grid, abs.(coupling_true_2d .- coupling_learned_2d),
             title="Absolute Error", xlabel="x₂", ylabel="u₂",
             c=:hot)

plot(p1, plot(p2, p3, p4, layout=(1,3), size=(1200, 300)), layout=(2,1), size=(1200, 800)) |> display

# =============================================================================
## CONVERGENCE PLOT
# =============================================================================

# Extract cost history
cost_history = [trace.value for trace in result.trace]
iterations = 0:length(cost_history)-1

plot(iterations, cost_history, lw=2, marker=:circle, ms=3,
     xlabel="Iteration", ylabel="Cost (Total SSE)",
     title="Optimization Convergence", legend=false,
     size=(800, 500), yscale=:log10) |> display

println("\n=== Example Complete ===")
println("The neural network has learned to approximate the unknown coupling dynamics")
println("by combining Kalman filtering for state estimation with gradient-based optimization.")

Are you comparing against some baseline approach that isn’t considered slow? How much time does that approach take?

In my experience Zygote can be extremely slow: `Zygote.gradient` is 54000 TIMES slower than `jax.gradient`. I ran into this once and never used Zygote since. ForwardDiff seems to be the fastest, easiest to use and most stable. Just browsing this forum I see people having trouble with Enzyme, for example (usually crashes and unintuitive behavior), but I don’t think I’ve ever seen anyone struggle with ForwardDiff: it usually just works, and it’s often surprisingly fast.
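
One ForwardDiff knob worth checking in a case like this (an aside, not something the original post benchmarks): the chunk size, which controls how many partial derivatives are propagated per pass and can noticeably change timings with a few hundred parameters. A sketch against the loss from the MWE, with an arbitrary example chunk size:

f = θ -> loss_multi(θ, prob_list, dataset, solver, sensealg1)
cfg = ForwardDiff.GradientConfig(f, _θ0, ForwardDiff.Chunk{12}())  # 12 is just an example value to benchmark against the default
ForwardDiff.gradient(f, _θ0, cfg)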

Enzyme from within a Reactant compilation should usually just work and be quite fast/effective on ML code – including compared against JAX, if you want to give it a go.

The Lux docs (which use it by default) should give a good intro on usage (cc @avik-pal).
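
For the UDE in the original post specifically, Enzyme can also be tried just for the inner vector-Jacobian products of the adjoint, without a full Reactant compilation. A sketch (a suggestion, not something benchmarked here; EnzymeVJP generally plays best with the in-place dynamics formulation):

sensealg_enzyme = SMS.GaussAdjoint(autojacvec = SMS.EnzymeVJP())
Zygote.gradient(θ -> loss_multi(θ, prob_list, dataset, solver, sensealg_enzyme), _θ0)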