AD of UDEs - how to set it up?

Hi! I'm new to Julia and SciML, but I am really curious to see how combining ODEs with neural networks could improve the predictive performance of the model we are developing for solvent-based CO2 capture. My current setup includes

  • dynamic states,
  • multiple independent timeseries of real data
  • exogenous inputs,
  • unknown dynamics modelled as a neural network (NN).

I use DataInterpolations.jl to handle the inputs. I want to train the NN by minimising a loss function (MSE across all the independent time series).
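For reference, here is the objective that `loss_multi` in the MWE below implements, written out. It is the sum of the per-series mean squared errors, so each series contributes its own MSE regardless of how long it is. The notation is mine: N = 5 series, n_y = 2 measured states, T_i samples in series i, Y^{(i)} the noisy data and \hat{X}^{(i)}(\theta) the UDE solution saved at the data times.

$$
L(\theta) = \sum_{i=1}^{N} \frac{1}{n_y T_i} \sum_{k=1}^{T_i} \sum_{j=1}^{n_y} \bigl( Y^{(i)}_{jk} - \hat{X}^{(i)}_{jk}(\theta) \bigr)^2
$$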

I have made an MWE (tried to make it as minimal as possible) with 5 series of data, each with ~200 datapoints.

My initial challenge is the following: computing the gradient of the loss function with AD (using Zygote.jl or ForwardDiff.jl) takes a long time. It seems like I am missing something or doing something wrong.

using BenchmarkTools

@btime Zygote.gradient(θ -> loss_multi(θ, prob_list, dataset, solver, sensealg1), _θ0)

@btime ForwardDiff.gradient(θ -> loss_multi(θ, prob_list, dataset, solver, sensealg1), _θ0)

outputs:

7.955 s (64378604 allocations: 3.56 GiB)
173.575 ms (5457036 allocations: 734.76 MiB)

which seems like a lot for just computing a gradient! Below is the MWE; maybe someone can spot something that is set up in a clearly stupid or inefficient way.

The true system is (for data generation):
$$
\begin{aligned}
\frac{dx_1}{dt} &= -\alpha\,(x_1 - u_1) + \sin(\gamma x_2)\,u_2 \\
\frac{dx_2}{dt} &= -\beta\,(x_2 - u_3)
\end{aligned}
$$

and the system with the neural network is:

$$
\begin{aligned}
\frac{dx_1}{dt} &= -\alpha\,(x_1 - u_1) + \mathrm{NN}(x_2, u_2) \\
\frac{dx_2}{dt} &= -\beta\,(x_2 - u_3)
\end{aligned}
$$

i.e. a system that is partially known and partially unknown. Here x denotes the states and u the exogenous inputs.

Below is the code:

Package imports and data generation:

# import packages
import OrdinaryDiffEq as ODE
import SciMLSensitivity as SMS
import ComponentArrays
import DataInterpolations as DI
using Zygote
using ForwardDiff

# Standard Libraries
import Statistics

# External Libraries
import Lux
import StableRNGs
import StaticArrays as SA
using Plots

# --- Type definitions ---
struct Series
    t::Vector{Float64}                     # length T
    Y::Matrix{Float64}                     # (n_y, T)
    U_funcs::Vector{DI.ConstantInterpolation}
    X0::NTuple{2,Float64}                  # initial state (clA, clD)
    tspan::Tuple{Float64,Float64}
end

struct Dataset
    series::Vector{Series}
end

## Generate simulated data ##

# Number of data series 
N_series = 5

# Time vector for each series (different lengths)
tf = rand(5.0:10.0, N_series) # final time between 5 and 10 h
#dt = 30/3600 # time step of 30 sec
dt = 2/60
t = [0.0:dt:tf[i] for i in 1:N_series] # time vector for each series

# Input functions for each series
# input values:
u1 = [[1.0, 0.5, 2.5], [0.8, 0.7, 2.0], [1.2, 0.4, 3.0], [1.1, 0.6, 2.8], [0.9, 0.55, 2.2]]
u2 = [[2.0, 1.5, 3.5], [1.8, 1.7, 3.0], [2.2, 1.4, 4.0], [2.1, 1.6, 3.8], [1.9, 1.55, 3.2]]
u3 = [[3.0, 2.5, 4.5], [2.8, 2.7, 4.0], [3.2, 2.4, 5.0], [3.1, 2.6, 4.8], [2.9, 2.55, 4.2]]
# Time points of steps:
tu1 = [[0.0, tf[i]/2, 4tf[i]/5] for i in 1:N_series]
tu2 = [[0.0, tf[i]/3, 2tf[i]/3] for i in 1:N_series]
tu3 = [[0.0, tf[i]/4, 3tf[i]/4] for i in 1:N_series]

# Create DataInterpolations.jl ConstantInterpolation functions for each series
U_funcs = Vector{DI.ConstantInterpolation}[]
for i in 1:N_series
    push!(U_funcs, [
        DI.ConstantInterpolation(u1[i], tu1[i]; extrapolation = DI.ExtrapolationType.Constant),
        DI.ConstantInterpolation(u2[i], tu2[i]; extrapolation = DI.ExtrapolationType.Constant),
        DI.ConstantInterpolation(u3[i], tu3[i]; extrapolation = DI.ExtrapolationType.Constant)
    ])
end

# Create ODE problem and solve to generate data
p_true = [0.5, 0.3, 3.0] # true parameters

# True system equation
function ffun(x, p, t, u)
    
    alpha, beta, gamma = p

    dx1 = -alpha * (x[1] - u[1](t)) + sin(x[2]*gamma) * u[2](t)
    dx2 = -beta * (x[2] - u[3](t))
    dx = [dx1, dx2]

    return(dx)
end

# Initial conditions for each series
x0_series = [[2.0, 1.0], [1.5, 1.5], [2.5, 0.5], [2.2, 1.2], [1.8, 0.8]]

# Generate data for each series
X = Matrix{Float64}[]
for i in 1:N_series
    prob = ODE.ODEProblem((x, p, t) -> ffun(x, p, t, [U_funcs[i][1], U_funcs[i][2], U_funcs[i][3]]),
        x0_series[i], (0.0, tf[i]), p_true)
    sol = ODE.solve(prob, ODE.Tsit5(), saveat = t[i])
    push!(X, Array(sol))
end

# Add noise to the data in each series
rng = StableRNGs.StableRNG(1234)
Y = Matrix{Float64}[]
for i in 1:N_series
    ȳ = Statistics.mean(X[i], dims = 2)
    noise_magnitude = 1e-2
    Y_noise = X[i] .+ (noise_magnitude * ȳ) .* randn(rng, eltype(X[i]), size(X[i]))
    push!(Y, Y_noise)
end

# Build Dataset
dataset = Dataset(Series[])
for i in 1:N_series
    s = Series(t[i], Y[i], U_funcs[i], (x0_series[i][1], x0_series[i][2]), (0.0, tf[i]))
    push!(dataset.series, s)
end 

Define model and loss function and compute the gradient:

# --- Model Definition ---
# Neural Network
neural_net = Lux.Chain(
    Lux.Dense(2, 10, x -> exp.(-(x.^2))),
    Lux.Dense(10, 10, x -> exp.(-(x.^2))),
    Lux.Dense(10, 1) # linear head
)

rng = StableRNGs.StableRNG(1234)
θ0, st = Lux.setup(rng, neural_net)  # θ0: initial NN parameters, used by make_problems and _θ0 below
const _st = st


# in-place ude_dynamics!
function ude_dynamics!(dx, x, p, t, U_funcs)

    # NN part
    input_NN = [x[2], U_funcs[2](t)]
    NN, _ = neural_net(input_NN, p, _st)  # Lux forward pass returns (output, state)

    dx[1] = -0.5 * (x[1] - U_funcs[1](t)) + NN[1]
    dx[2] = -0.3 * (x[2] - U_funcs[3](t))
    return nothing
end

# out-of-place ude_dynamics
function ude_dynamics(x, p, t, U_funcs)

    # NN part
    input_NN = [x[2], U_funcs[2](t)]
    NN, _ = neural_net(input_NN, p, _st)  # Lux forward pass

    dx1 = -0.5 * (x[1] - U_funcs[1](t)) + NN[1]
    dx2 = -0.3 * (x[2] - U_funcs[3](t))
    return [dx1, dx2]
end


# Build ODE problems
function make_problems(ds::Dataset, θ0)
    prob_list = ODE.ODEProblem[]
    for s in ds.series

        # in-place ude
        # f!(du, u, θ, t) = ude_dynamics!(du, u, θ, t, s.U_funcs)
        # push!(prob_list, ODE.ODEProblem(f!, collect(s.X0), s.tspan, θ0))

        # out-of-place ude
        f(u, θ, t) = ude_dynamics(u, θ, t, s.U_funcs)
        push!(prob_list, ODE.ODEProblem(f, collect(s.X0), s.tspan, θ0))
    end
    return prob_list
end

# Define loss function over all series
function loss_multi(θ, prob_list, dataset, solver, sensealg1)
    loss = 0.0
    for i in eachindex(prob_list)
        _prob = ODE.remake(prob_list[i], p = θ)
        sol = ODE.solve(_prob, solver;
            saveat = dataset.series[i].t,
            save_everystep = false, dense = false,
            abstol = 1e-6, reltol = 1e-6,
            sensealg = sensealg1)
        X̂ = Array(sol) # (2, T)
        loss += Statistics.mean(abs2, dataset.series[i].Y .- X̂)
    end
    return loss
end

# Create ODE problems for all series
prob_list = make_problems(dataset, θ0)

# Solver and sensitivity algorithm
solver = ODE.Tsit5()
sensealg1 = SMS.QuadratureAdjoint(autojacvec = SMS.ZygoteVJP())

_θ0 = ComponentArrays.ComponentVector{Float64}(θ0)

using BenchmarkTools

@btime Zygote.gradient(θ -> loss_multi(θ, prob_list, dataset, solver, sensealg1), _θ0)
@btime ForwardDiff.gradient(θ -> loss_multi(θ, prob_list, dataset, solver, sensealg1), _θ0)

This is not an answer to your question, but out of curiosity I asked Claude to adapt an existing filtering tutorial to fit your system, and this is what it came up with.

In classic Claude style, it’s a bit verbose in how it generates data etc., but it did at least produce something working with minimal intervention 🙂


# Learning Unknown Dynamics with Neural Networks and Optimization
#
# This example demonstrates how to combine:
# - Known first-principles dynamics
# - Neural networks for unknown coupling terms
# - Extended Kalman Filter for state estimation
# - Optimization to learn neural network parameters
#
# Problem from: https://discourse.julialang.org/t/ad-of-udes-how-to-set-it-up/133461

using LowLevelParticleFilters
using Random, SeeToDee, StaticArrays, Plots, LinearAlgebra, Statistics
using LowLevelParticleFilters: SimpleMvNormal
using Lux, ComponentArrays, Optim

Random.seed!(42)

# =============================================================================
## DATA GENERATION - True System
# =============================================================================

# True system parameters
const α = 0.5f0
const β = 0.3f0
const γ = 3.0f0

# True coupling function (unknown in practice)
# This is what the neural network needs to learn: sin(γ·x₂)·u₂
true_coupling(x2, u2) = sin(γ * x2) * u2

# True system dynamics
# dx₁/dt = -α(x₁ - u₁) + sin(γ·x₂)·u₂
# dx₂/dt = -β(x₂ - u₃)
function true_dynamics(state, u, p, t)
    x1, x2 = state
    u1, u2, u3 = u[1], u[2], u[3]

    coupling = true_coupling(x2, u2)

    dx1 = -α * (x1 - u1) + coupling
    dx2 = -β * (x2 - u3)

    SA[dx1, dx2]
end

# Discretize true system
const Ts = Float32(2/60)
discrete_true_dynamics = SeeToDee.Rk4(true_dynamics, Ts)

# Generate multiple time series with different initial conditions
# Following the forum post setup for inputs and step times (initial conditions are random here)
function generate_data(; n_series=5)
    rng = Random.default_rng()
    Random.seed!(rng, 123)

    # Input step values for each series (3 steps per input, per series)
    u1_vals = [[1.0f0, 0.5f0, 2.5f0], [0.8f0, 0.7f0, 2.0f0], [1.2f0, 0.4f0, 3.0f0],
                [1.1f0, 0.6f0, 2.8f0], [0.9f0, 0.55f0, 2.2f0]]
    u2_vals = [[2.0f0, 1.5f0, 3.5f0], [1.8f0, 1.7f0, 3.0f0], [2.2f0, 1.4f0, 4.0f0],
                [2.1f0, 1.6f0, 3.8f0], [1.9f0, 1.55f0, 3.2f0]]
    u3_vals = [[3.0f0, 2.5f0, 4.5f0], [2.8f0, 2.7f0, 4.0f0], [3.2f0, 2.4f0, 5.0f0],
                [3.1f0, 2.6f0, 4.8f0], [2.9f0, 2.55f0, 4.2f0]]

    # Final times for each series
    tf = rand(5.0:10.0, n_series)

    all_data = []

    for series_idx in 1:n_series
        # Different initial conditions for each series
        x0 = SA[randn(rng) * 0.5f0, randn(rng) * 0.5f0]

        # Time points where inputs change (following forum post pattern)
        tu1 = [0.0f0, tf[series_idx]/2, 4*tf[series_idx]/5]
        tu2 = [0.0f0, tf[series_idx]/3, 2*tf[series_idx]/3]
        tu3 = [0.0f0, tf[series_idx]/4, 3*tf[series_idx]/4]

        # Number of time steps for this series
        n_steps = Int(ceil(tf[series_idx] / Ts))

        # Generate step input vectors
        u = Vector{SVector{3, Float32}}(undef, n_steps)
        for i in 1:n_steps
            t = (i-1) * Ts

            # Determine which step we're in for each input (piecewise constant)
            u1 = t < tu1[2] ? u1_vals[series_idx][1] : (t < tu1[3] ? u1_vals[series_idx][2] : u1_vals[series_idx][3])
            u2 = t < tu2[2] ? u2_vals[series_idx][1] : (t < tu2[3] ? u2_vals[series_idx][2] : u2_vals[series_idx][3])
            u3 = t < tu3[2] ? u3_vals[series_idx][1] : (t < tu3[3] ? u3_vals[series_idx][2] : u3_vals[series_idx][3])

            u[i] = SA[u1, u2, u3]
        end

        # Simulate true system
        x = Vector{SVector{2, Float32}}(undef, n_steps)
        x[1] = x0
        for i in 2:n_steps
            x[i] = discrete_true_dynamics(x[i-1], u[i-1], nothing, (i-2)*Ts)
        end

        # Add measurement noise
        R_meas = 0.05f0
        y = [x_i + R_meas * randn(rng, SVector{2, Float32}) for x_i in x]

        push!(all_data, (x=x, u=u, y=y, x0=x0, tf=tf[series_idx]))
    end

    return all_data
end

println("Generating data...")
data_series = generate_data(n_series=5)

# Visualize first series
let data = data_series[1]
    p1 = plot([x[1] for x in data.x], label="x₁", lw=2)
    plot!([x[2] for x in data.x], label="x₂", lw=2)
    title!("True States (Series 1)")
    ylabel!("State")

    p2 = plot([u[1] for u in data.u], label="u₁", lw=2, seriestype=:steppost)
    plot!([u[2] for u in data.u], label="u₂", lw=2, seriestype=:steppost)
    plot!([u[3] for u in data.u], label="u₃", lw=2, seriestype=:steppost)
    ylabel!("Input")
    xlabel!("Time step")
    title!("Inputs")

    plot(p1, p2, layout=(2,1), size=(800, 600)) |> display
end

# =============================================================================
## NEURAL NETWORK SETUP - Learn Unknown Coupling
# =============================================================================

# Neural network to learn coupling function sin(γ·x₂)·u₂
# Inputs: [x₂, u₂]
# Output: coupling term (scalar)

const nn_input_dim = 2   # [x₂, u₂]
const nn_output_dim = 1  # coupling scalar
const nn_hidden = 8

coupling_network = Chain(
    Dense(nn_input_dim, nn_hidden, tanh),
    Dense(nn_hidden, nn_hidden, tanh),
    Dense(nn_hidden, nn_output_dim)
)

# Initialize network parameters
rng = Random.default_rng()
Random.seed!(rng, 456)
nn_ps, nn_st = Lux.setup(rng, coupling_network)
nn_params = ComponentArray(nn_ps)

println("Neural network parameters: ", length(nn_params))

# =============================================================================
## HYBRID DYNAMICS - Known Physics + Neural Network
# =============================================================================

# Hybrid dynamics: known parameters + NN for unknown coupling
# UDE form:
# dx₁/dt = -α(x₁ - u₁) + NN(x₂, u₂)
# dx₂/dt = -β(x₂ - u₃)
function hybrid_dynamics_continuous(state, u, params, t)
    x1, x2 = state
    u1, u2, u3 = u[1], u[2], u[3]

    # Neural network predicts coupling term: sin(γ·x₂)·u₂
    nn_input = SA[x2, u2]
    coupling_pred, _ = Lux.apply(coupling_network, nn_input, params, nn_st)
    coupling = coupling_pred[1]

    # Known physics + learned coupling
    dx1 = -α * (x1 - u1) + coupling
    dx2 = -β * (x2 - u3)

    SA[dx1, dx2]
end

# Discretize hybrid dynamics using Heun's method
discrete_hybrid_dynamics = SeeToDee.Heun(hybrid_dynamics_continuous, Ts)

# =============================================================================
## EXTENDED KALMAN FILTER SETUP
# =============================================================================

# System dimensions
const nx = 2  # State dimension [x₁, x₂]
const nu = 3  # Input dimension [u₁, u₂, u₃]
const ny = 2  # Output dimension (measure both states)

# Process and measurement noise covariances
R1 = SMatrix{nx, nx}(Diagonal(Float32[0.001, 0.001]))  # Process noise
R2 = SMatrix{ny, ny}(Diagonal(Float32[0.05^2, 0.05^2]))  # Measurement noise

# Linear measurement model (observe full state)
C = SA[1.0f0 0.0f0; 0.0f0 1.0f0]
measurement_model = LinearMeasurementModel(C, 0, R2; ny)

# =============================================================================
## OPTIMIZATION FRAMEWORK
# =============================================================================

# Cost function: sum of SSE across all time series
function cost_function(θ)
    T = eltype(θ)
    total_sse = zero(T)

    # Convert parameters to correct type
    θ_comp = ComponentArray(θ, getaxes(nn_params))

    # Evaluate cost on each time series
    try
        for data in data_series
            # Create EKF with current parameters
            x0_est = T.(data.x0)
            P0 = T(10.0) * T.(R1)
        
            kf = ExtendedKalmanFilter(
                discrete_hybrid_dynamics,
                measurement_model,
                R1,
                SimpleMvNormal(x0_est, P0);
                p = θ_comp,
                ny, nu, Ts
            )

            # Compute SSE for this series
            sse = LowLevelParticleFilters.sse(kf, data.u, data.y, θ_comp)
            total_sse += sse
        end
    catch
        return T(Inf)
    end

    return total_sse
end

# Initial cost evaluation
println("\nInitial cost: ", cost_function(nn_params))

# =============================================================================
## OPTIMIZATION
# =============================================================================
using Optim.LineSearches
println("\nStarting optimization...")

# Optimization options
opt_options = Optim.Options(
    show_trace = true,
    store_trace = true,
    iterations = 150,
    f_reltol = 1e-4,
)

# Run optimization
@time result = Optim.optimize(
    cost_function,
    nn_params,
    BFGS(alphaguess = LineSearches.InitialStatic(alpha=0.5), linesearch = LineSearches.HagerZhang()),
    opt_options;
    autodiff = :forward
)


# Extract optimized parameters
params_opt = ComponentArray(result.minimizer, getaxes(nn_params))

println("\n=== Optimization Results ===")
println("Converged: ", Optim.converged(result))
println("Iterations: ", Optim.iterations(result))
println("Initial cost: ", cost_function(nn_params))
println("Final cost: ", Optim.minimum(result))
println("Cost reduction: ", round((1 - Optim.minimum(result)/cost_function(nn_params))*100, digits=2), "%")

# =============================================================================
## RESULTS ANALYSIS
# =============================================================================

# Run filter with optimized parameters on first series
data_test = data_series[1]
kf_final = UnscentedKalmanFilter(
    discrete_hybrid_dynamics,
    measurement_model,
    R1,
    SimpleMvNormal(data_test.x0, R1);
    p = params_opt,
    ny=ny, nu=nu, Ts=Ts
)

sol = forward_trajectory(kf_final, data_test.u, data_test.y)

# Extract states
x1_est = [sol.xt[i][1] for i in 1:length(sol.xt)]
x2_est = [sol.xt[i][2] for i in 1:length(sol.xt)]
x1_true = [data_test.x[i][1] for i in 1:length(data_test.x)]
x2_true = [data_test.x[i][2] for i in 1:length(data_test.x)]

# Plot state estimation
p1 = plot(x1_true, label="True x₁", lw=2, color=:blue)
plot!(x1_est, label="Estimated x₁", lw=2, ls=:dash, color=:red)
ylabel!("State x₁")
title!("State Estimation - Series 1")

p2 = plot(x2_true, label="True x₂", lw=2, color=:blue)
plot!(x2_est, label="Estimated x₂", lw=2, ls=:dash, color=:red)
ylabel!("State x₂")
xlabel!("Time step")

plot(p1, p2, layout=(2,1), size=(900, 600)) |> display

# =============================================================================
## LEARNED COUPLING FUNCTION COMPARISON
# =============================================================================

# Compute learned coupling
function learned_coupling(x2, u2, params)
    nn_input = SA[x2, u2]
    output, _ = Lux.apply(coupling_network, nn_input, params, nn_st)
    return output[1]
end

# Plot 1: Coupling as function of x₂ for different u₂ values
x2_test = LinRange(-1.0f0, 3.0f0, 100)
u2_values = [1.4f0, 2.2f0, 3.5f0]

p1 = plot(title="Learned vs True Coupling: f(x₂, u₂)", xlabel="x₂", ylabel="Coupling",
          legend=:outertopright, size=(900, 500))

for (i, u2_val) in enumerate(u2_values)
    # True coupling
    coupling_true_1d = [true_coupling(x2, u2_val) for x2 in x2_test]
    plot!(x2_test, coupling_true_1d, label="True (u₂=$u2_val)", lw=2, ls=:solid, c=i)

    # Learned coupling
    coupling_learned_1d = [learned_coupling(x2, u2_val, params_opt) for x2 in x2_test]
    plot!(x2_test, coupling_learned_1d, label="Learned (u₂=$u2_val)", lw=2, ls=:dash, alpha=0.7, c=i)
end

# Plot 2: 2D heatmap comparison
x2_grid = LinRange(-1.0f0, 3.0f0, 50)
u2_grid = LinRange(1.4f0, 3.8f0, 50)

coupling_true_2d = [true_coupling(x2, u2) for u2 in u2_grid, x2 in x2_grid]
coupling_learned_2d = [learned_coupling(x2, u2, params_opt) for u2 in u2_grid, x2 in x2_grid]

p2 = heatmap(x2_grid, u2_grid, coupling_true_2d,
             title="True: sin(3·x₂)·u₂", xlabel="x₂", ylabel="u₂",
             c=:viridis, clims=(-2, 2))

p3 = heatmap(x2_grid, u2_grid, coupling_learned_2d,
             title="Learned: NN(x₂, u₂)", xlabel="x₂", ylabel="u₂",
             c=:viridis, clims=(-2, 2))

p4 = heatmap(x2_grid, u2_grid, abs.(coupling_true_2d .- coupling_learned_2d),
             title="Absolute Error", xlabel="x₂", ylabel="u₂",
             c=:hot)

plot(p1, plot(p2, p3, p4, layout=(1,3), size=(1200, 300)), layout=(2,1), size=(1200, 800)) |> display

# =============================================================================
## CONVERGENCE PLOT
# =============================================================================

# Extract cost history
cost_history = [trace.value for trace in result.trace]
iterations = 0:length(cost_history)-1

plot(iterations, cost_history, lw=2, marker=:circle, ms=3,
     xlabel="Iteration", ylabel="Cost (Total SSE)",
     title="Optimization Convergence", legend=false,
     size=(800, 500), yscale=:log10) |> display

println("\n=== Example Complete ===")
println("The neural network has learned to approximate the unknown coupling dynamics")
println("by combining Kalman filtering for state estimation with gradient-based optimization.")

Are you comparing against some baseline approach that isn’t considered slow? How much time does that approach take?

In my experience Zygote can be extremely slow (e.g. the thread “`Zygote.gradient` is 54000 TIMES slower than `jax.gradient`”). I ran into this once and have not used Zygote since. ForwardDiff seems to be the fastest, easiest to use and most stable. Just browsing this forum, I see people having trouble with Enzyme, for example (usually crashes and unintuitive behavior), but I don’t think I’ve ever seen anyone struggle with ForwardDiff: it usually just works, and it’s often surprisingly fast.

EDIT: recent examples of people having issues with Enzyme:


Enzyme from within a Reactant compilation should usually just work and be quite fast/effective on ML code – including compared with JAX, if you want to give it a go.

The Lux docs (which use it by default) should give a good intro on usage (cc @avik-pal).

@Emil_Martinsen I’ve been working on a similar type of problem as you, though it is very stiff, and I’ve had some success using Lux’s ToSimpleChainsAdaptor() to speed up the gradient calculation since it’s much faster on CPU and uses a lot less memory.

ForwardDiff.jl is also much faster for me than Zygote.jl for gradients, and I have been unable to get the other autodiff packages to work at all with my code. My loss calculation is ~40ms, and the gradient is about 100x slower and 50x more memory expensive with ForwardDiff, and Zygote is even worse.

Another thing that can help is if you can multithread your loss function to run each simulation separately. This has caused problems for me with Zygote, but ForwardDiff has been okay with it. It seems like you can probably modify your loss function to accommodate that if your actual problem is similar to your MWE.
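A minimal sketch of that pattern, with a toy least-squares loss standing in for each per-series ODE solve (`series_loss`, `threaded_loss` and the data are hypothetical, not from the MWE). Julia has to be started with several threads (e.g. `julia -t 4`) for the loop to actually run in parallel. Allocating the per-series results with `eltype(θ)` keeps the code working when ForwardDiff passes Dual numbers, and each task writes only its own slot, so there is no shared accumulator:

```julia
using ForwardDiff

# Hypothetical per-series loss (stand-in for solving one ODE and computing its MSE):
# least-squares fit of y ≈ θ[1] * t + θ[2] for one independent series d = (t, y).
function series_loss(θ, d)
    t, y = d
    return sum(abs2, y .- (θ[1] .* t .+ θ[2])) / length(t)
end

# Series are independent, so evaluate them on separate threads and sum afterwards.
function threaded_loss(θ, series)
    # eltype(θ) is a ForwardDiff.Dual during differentiation, so allocate accordingly
    losses = Vector{eltype(θ)}(undef, length(series))
    Threads.@threads for i in eachindex(series)
        losses[i] = series_loss(θ, series[i])   # each task writes only its own slot
    end
    return sum(losses)
end

# Serial reference version for comparison
serial_loss(θ, series) = sum(series_loss(θ, d) for d in series)

tgrid = collect(0.0:0.1:1.0)
toy_series = [(tgrid, 2.0 .* tgrid .+ 1.0 .+ 0.01 * i) for i in 1:5]
θ_init = [1.5, 0.5]

g_threaded = ForwardDiff.gradient(θ -> threaded_loss(θ, toy_series), θ_init)
g_serial = ForwardDiff.gradient(θ -> serial_loss(θ, toy_series), θ_init)
```

With reverse-mode AD the same loop may need more care, which matches the experience with Zygote described above.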

I’ve noticed that Lux might be dropping support for ForwardDiff though,

which is concerning to me personally for my code, but it also seems like there’s been a significant push to getting people to use Reactant.jl with Lux almost exclusively, which seems problematic for less traditional machine learning applications.

For everyone else, is there a good way to include a Reactant compiled NN model inside of ODEs or other larger models, as in the case of the top level post here? It seems like you’d have to be very careful about how you write your models, and packages like DataInterpolations.jl may just be totally incompatible.

Thanks!


Obligatory warning: Forward-mode AD doesn’t scale to large numbers ($\gg 100$) of parameters, because its cost grows in proportion to the number of parameters (unlike reverse-mode AD).
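To put numbers on this for the MWE at the top of the thread: its network has three `Dense` layers (weights plus biases), giving n parameters. ForwardDiff pushes the n partial derivatives through the loss in chunks of at most c partials per sweep (c = 12 is, as far as I know, ForwardDiff's default upper limit), so the total work still grows roughly linearly in n. Reverse mode needs one forward and one backward pass regardless of n (ignoring memory):

$$
n = \underbrace{(2 \cdot 10 + 10)}_{\mathrm{Dense}(2,10)} + \underbrace{(10 \cdot 10 + 10)}_{\mathrm{Dense}(10,10)} + \underbrace{(10 \cdot 1 + 1)}_{\mathrm{Dense}(10,1)} = 30 + 110 + 11 = 151
$$

$$
\text{ForwardDiff sweeps} = \left\lceil \frac{n}{c} \right\rceil = \left\lceil \frac{151}{12} \right\rceil = 13,
\qquad
\text{work}_{\text{fwd}} \sim n \cdot \text{cost}(L),
\qquad
\text{work}_{\text{rev}} \sim \text{const} \cdot \text{cost}(L)
$$

So at roughly 150 parameters the MWE is still in the range where forward mode can be competitive; a network with thousands of weights would not be.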

(Once we get far enough away from typical ML building blocks and more into scientific modeling, I invariably find that we have to implement custom adjoint methods / VJPs / rrules for at least some components to get efficient reverse-mode derivatives. YMMV. Though in the specific case of ODEs, SciMLSensitivity.jl is supposed to do most of the work for you. The other advantage of doing things more manually is that it is easier to diagnose exactly where the computational time is spent.)
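To illustrate what a hand-written reverse rule looks like, here is a minimal sketch of a ChainRulesCore `rrule` (the interface Zygote picks up automatically) for a hypothetical vectorized version of the MWE's Gaussian activation `exp(-x^2)`. It is only a toy; Zygote does not need help for something this simple, but the same pattern applies to expensive components where you know the adjoint analytically:

```julia
using ChainRulesCore, Zygote

# Hypothetical component: the Gaussian activation from the MWE, applied to a vector.
gauss_act(x::AbstractVector) = exp.(-(x .^ 2))

# Hand-written reverse rule (a VJP): d/dx exp(-x^2) = -2x * exp(-x^2).
function ChainRulesCore.rrule(::typeof(gauss_act), x::AbstractVector)
    y = gauss_act(x)
    function gauss_act_pullback(ȳ)
        x̄ = unthunk(ȳ) .* (-2 .* x .* y)   # reuse the stored forward result y
        return NoTangent(), x̄               # no tangent for the function object itself
    end
    return y, gauss_act_pullback
end

x = [0.5, 1.0, -2.0]
g_rule = Zygote.gradient(v -> sum(gauss_act(v)), x)[1]   # uses the rrule above
g_exact = -2 .* x .* exp.(-(x .^ 2))
```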


Do either of you have error messages from using Enzyme on your code with or without reactant?

I can try to take a look and help.

Tangentially, when you say “which seems problematic for less traditional machine learning applications” I don’t quite follow. Reactant should work on generic code – not just ML?

> which seems problematic for less traditional machine learning applications.

+1 to @wsmoses’s comment that Reactant should work on generic code.

Regarding that particular issue, the change in that case would be to use Enzyme.Forward instead of ForwardDiff (rather than dropping in reverse mode, which might not work well for your case).
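A minimal sketch of what swapping ForwardDiff for Enzyme's forward mode could look like. It assumes you go through DifferentiationInterface.jl (my choice for this illustration, not something suggested above) so that the backend becomes a single argument. The toy loss is hypothetical, and the packages are imported rather than `using`-ed to avoid clashing `gradient` names:

```julia
import DifferentiationInterface as DIF
import ForwardDiff   # loading the package activates DI's ForwardDiff support
import Enzyme        # likewise for Enzyme

# Hypothetical loss standing in for a UDE loss; exact gradient is 2θ .+ cos.(θ)
toy_loss(θ) = sum(abs2, θ) + sum(sin, θ)

θ = [0.1, -0.4, 2.0]

g_forwarddiff = DIF.gradient(toy_loss, DIF.AutoForwardDiff(), θ)
g_enzyme_fwd  = DIF.gradient(toy_loss, DIF.AutoEnzyme(; mode = Enzyme.Forward), θ)
```

Both calls run forward mode, so their cost scaling in the number of parameters is the same; the difference lies in how each package handles the code inside the loss.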

> I’ve noticed that Lux might be dropping support for ForwardDiff though

Lux isn’t dropping ForwardDiff support; rather, it is dropping a direct dependency on it. So you should be able to `]add ForwardDiff` and continue using it with Lux (the breaking part is that the dispatches on `AutoForwardDiff` inside Lux will need a `using ForwardDiff` before that functionality is available).
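A minimal sketch of what that looks like in practice, assuming recent Lux, ComponentArrays and ForwardDiff releases. ForwardDiff is loaded explicitly next to Lux and differentiates the flattened parameters of a model with the MWE's layer sizes. The `tanh` activations and the random input batch are placeholders:

```julia
using Lux, ComponentArrays, Random
using ForwardDiff   # loaded explicitly, since Lux is dropping its direct dependency

# Same layer sizes as the MWE at the top of the thread (tanh used here for simplicity)
model = Chain(Dense(2 => 10, tanh), Dense(10 => 10, tanh), Dense(10 => 1))
ps, st = Lux.setup(Random.Xoshiro(0), model)
θ = ComponentArray(ps)                        # flat parameter vector for ForwardDiff

x = rand(Random.Xoshiro(1), Float32, 2, 16)   # hypothetical batch of 16 inputs
loss(θ) = sum(abs2, first(model(x, θ, st)))   # the model returns (output, state)

g = ForwardDiff.gradient(loss, θ)
```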


@baggepinnen That is really cool, good job Claude. An interesting approach that solves the MWE nicely. Out of curiosity, what is the LowLevelParticleFilters.sse() function doing? I cannot find it in the documentation: API · LowLevelParticleFilters Documentation

The tutorial you linked to, SciML: Adaptive Universal Differential Equation · LowLevelParticleFilters Documentation, seems to take another approach, where the NN parameters are augmented into the state vector and estimated by computing the forward trajectory of the states (i.e. states + parameters). When would you use this approach compared with the approach Claude suggests for solving the MWE in this post?

Thanks for all the replies.

@ForceBru I have also compared with a finite difference approach; the timings for the gradient computations of the loss function are:

FiniteDiff:
  3.136 s (136024065 allocations: 4.48 GiB)
ForwardDiff:
  172.853 ms (5451138 allocations: 733.88 MiB)
Zygote:
  7.980 s (64072531 allocations: 3.54 GiB)

So FiniteDiff takes up the most memory (as I would expect), but Zygote is still the slowest. Based on the replies, it seems like others also find Zygote slow.
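For intuition on why finite differences are expensive here: with central differences (FiniteDiff's default for gradients, as far as I know), each of the n parameters needs two loss evaluations, i.e. 2 · 151 = 302 evaluations of the full multi-series loss for the MWE's network. A small sketch on a hypothetical toy function with the same number of parameters, checking both FiniteDiff and ForwardDiff against the exact gradient:

```julia
using FiniteDiff, ForwardDiff

# Hypothetical smooth test function; exact gradient is cos.(θ) .+ θ
f(θ) = sum(sin, θ) + sum(abs2, θ) / 2

θ = collect(range(-1.0, 1.0; length = 151))   # same parameter count as the MWE's NN

g_fd = FiniteDiff.finite_difference_gradient(f, θ)   # 2 evaluations per parameter (central)
g_ad = ForwardDiff.gradient(f, θ)                     # chunked forward-mode sweeps
g_exact = cos.(θ) .+ θ
```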

Did you think of other baseline approaches to use?

I should have been more specific: I asked Claude to combine the following two tutorials

One adds the parameters of the NN to the state to be estimated online; the other uses an optimizer to find the parameters. I typically only consider adding the parameters to the state if I expect them to be time-varying or I want to detect abrupt changes, etc. I have a summary table here: “Which method should I use?”

and a video about some of this
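To illustrate the state-augmentation idea, here is a minimal, library-free sketch. The model, `Ts_demo`, `step_model` and `augmented_step` are all hypothetical. The unknown parameter is appended to the state and given trivial dynamics, so any filter run on the augmented state estimates it online. In an actual filter you would put process noise on the θ component so that its estimate can move (a random walk); this sketch only shows the augmented dynamics:

```julia
# Hypothetical discrete-time model with one unknown parameter θ (e.g. a NN weight):
# x[k+1] = x[k] + Ts * (-0.5 * (x[k] - u[k]) + θ)
const Ts_demo = 2 / 60
step_model(x, u, θ) = x + Ts_demo * (-0.5 * (x - u) + θ)

# Augmented state z = [x, θ]: θ is carried over unchanged by the dynamics, and the
# filter's process noise on that component lets the estimate of θ drift toward the data.
function augmented_step(z, u)
    x, θ = z[1], z[2]
    return [step_model(x, u, θ), θ]
end

z1 = augmented_step([1.0, 0.2], 0.0)
```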

`sse` = “sum of squared errors”: it performs filtering and sums up the squares of the prediction errors e that the filter makes along the trajectory,
$$e(t \mid t-1) = y(t) - \hat{y}(t \mid t-1),$$
that is, the difference between the measured output $y(t)$ and the predicted output given information up to and including the previous time point only, $\hat{y}(t \mid t-1)$.
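A toy scalar illustration of that definition. This is my own sketch of the idea, not the LowLevelParticleFilters implementation, and the model x[k+1] = a x[k] + w, y[k] = x[k] + v with its noise variances is hypothetical. The error for each sample is computed from the prior estimate, before that sample's measurement update:

```julia
# Sum of squared one-step prediction errors from a scalar linear Kalman filter.
# Hypothetical model: x[k+1] = a*x[k] + w,  y[k] = x[k] + v,  Var(w) = q, Var(v) = r.
function prediction_error_sse(y, a; q = 1e-3, r = 1e-2, x0 = 0.0, P0 = 1.0)
    x̂, P = x0, P0            # prior estimate x̂(t|t-1) and its variance
    sse = 0.0
    for yt in y
        e = yt - x̂           # e(t|t-1) = y(t) - ŷ(t|t-1), with ŷ = x̂ since C = 1
        sse += e^2
        K = P / (P + r)      # Kalman gain
        x̂ += K * e           # measurement update: x̂(t|t)
        P *= (1 - K)
        x̂ = a * x̂            # time update: x̂(t+1|t)
        P = a^2 * P + q
    end
    return sse
end
```

For example, `prediction_error_sse(fill(1.0, 10), 1.0; x0 = 1.0)` is zero because every prediction is exact, while `prediction_error_sse([1.0], 1.0)` is 1.0 (a single error of 1 from the prior guess 0).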

Not really, I was just wondering what your baseline was. BTW, in your timings Zygote is about 2.5 times slower than FiniteDiff (7.98 s vs 3.14 s), which AFAIK is the most basic approach to numerical differentiation…

Really cool. We also work quite a lot with Kalman filtering, maximum likelihood estimation and stochastic differential equations in our section (https://www.compute.dtu.dk/sections/dynsys), but not so much with embedding neural networks, which I am curious about. So it is good to find some good resources elsewhere that I can learn from.

Another Q: I see you only compare the final 1-step model predictions (i.e., the filtered solution) with the data. Wouldn't it be more accurate to compare the non-filtered full-horizon model predictions with the data? Of course it depends on what you want to use the model for, but, e.g., with the sunshine disturbance model, we would like to predict maybe 6-12 hours into the future (here we cannot compute a filtered solution, since we have no information about the unknown future).

If the intention is to use the model for prediction, for instance for MPC, then certainly yes. There’s an ongoing discussion in this thread regarding different approaches to optimizing multi-step prediction performance rather than single-step performance.

In that tutorial in particular, we are interested in learning an unknown function, and since we have access to the ground truth of the unknown function the main comparison made is against this.

The summary of that discussion is that you typically want to perform filtering and, at each point in time, start a multi-step prediction from the current filter estimate. Unless the prediction horizon you care about is very long (such that the information gained from using measurements up until the current time point does not matter for most of the prediction), using a multi-step prediction rather than pure simulation tends to be favorable.
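A library-free sketch of the multi-step part (the function name `predict_ahead` and the scalar model are hypothetical). Given a filter estimate x̂(k|k) and the known future inputs, roll the model forward H steps without measurement updates, then compare the result with y(k+1) through y(k+H). Repeating this from every filter estimate along the data gives a multi-step prediction-error criterion:

```julia
# Roll a discrete-time model f(x, u) forward H steps from a filtered estimate x̂,
# without any further measurement updates (open-loop prediction).
function predict_ahead(f, x̂, us, H)
    x = x̂
    traj = Vector{typeof(x̂)}(undef, H)
    for h in 1:H
        x = f(x, us[h])      # x̂(k+h | k)
        traj[h] = x
    end
    return traj
end

# Hypothetical scalar model and inputs
f_demo(x, u) = 0.9 * x + 0.1 * u
pred = predict_ahead(f_demo, 1.0, zeros(3), 3)
```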
