Hi guys!
I've been breaking my neck trying to get a Julia alternative to this JAX neural ODE with exogenous input script working, but without luck!
The Julia version is painfully slow, whereas the JAX version runs in a couple of minutes!
ONE MAJOR DISTINCTION: the JAX version solves the ODE densely and uses an interpolation of those points at tvec for the loss, so it makes sense that the Julia version is slower. But I simply couldn't get the dense=true version without saveat=tvec to work with any AD backend, with or without sensealgs (see the sketch right after the MWE for the saveat variant I'm contrasting against).
Here's the MWE (or at least an attempt at one) for the Julia version:
using Lux, Optimization, OptimizationOptimisers, OrdinaryDiffEq, SciMLSensitivity, SciMLBase
using ComponentArrays, Random, Statistics, MLUtils
#constants
const M = 1.0f0
const F_N = 5.0f0
const ZETA = 0.1f0
const OMEGA_N = 2f0 * Float32(pi) * F_N
const C = 2f0 * ZETA * M * OMEGA_N
const K0 = M * OMEGA_N^2
const K_CUBIC = 1000.0f0
# Sampling
const FS = 96_000f0
const dt = 1f0 / FS
const T1 = 1.0f0
const tspan = (0.0f0, T1)
const tvec = collect(tspan[1]:dt:tspan[2])
const N_T = length(tvec)
const DT_INV = 1f0 / dt
# Data generation
const BATCH_SIZE = 16
const N_SAMPLES = BATCH_SIZE
Random.seed!(0)
const U_data = randn(Float32, N_T, N_SAMPLES)
function duffing!(du, u, p, t)
idx = clamp(round(Int, t * DT_INV) + 1, 1, N_T)
x, v = u[1], u[2]
k_eff = K0 * (1f0 + K_CUBIC * x^2)
du[1] = v
du[2] = (U_data[idx, Int(p)] - C * v - k_eff * x) / M
nothing
end
const Y_data = Float32.(cat([Array(solve(ODEProblem(duffing!, zeros(Float32, 2), tspan, i), Tsit5(); saveat=tvec, reltol=1e-6, abstol=1e-6)) for i in 1:N_SAMPLES]..., dims=3))
# Normalize targets per trajectory and state (improves training stability)
for i in 1:N_SAMPLES, s in 1:2
σ = std(@view Y_data[s, :, i])  # Y_data is (state, time, sample)
σ > 0 && (Y_data[s, :, i] ./= σ)
end
const U_list = [@view U_data[:, i] for i in 1:N_SAMPLES]
const Y_list = [@view Y_data[:, :, i] for i in 1:N_SAMPLES]
const loader = DataLoader((U_list, Y_list); batchsize=BATCH_SIZE, shuffle=true)
# Model (same width/depth as JAX version)
model = Chain(Dense(3 => 64, tanh), Dense(64 => 64, tanh), Dense(64 => 2))
ps, st = Lux.setup(Random.default_rng(), model)
ps_ca = ComponentArray{Float32}(ps)
smodel = StatefulLuxLayer{true}(model, nothing, st)
const CURRENT_U = Ref{Vector{Float32}}(zeros(Float32, N_T))
function dudt!(du, u, p, t)
@inbounds @fastmath begin
t_idx = clamp(round(Int, t * DT_INV) + 1, 1, N_T)
input_vec = [u[1], u[2], CURRENT_U[][t_idx]]
out = smodel(input_vec, p)
du[1] = out[1]
du[2] = out[2]
end
end
# Training setup
const prob = ODEProblem(dudt!, zeros(Float32, 2), Float32.(tspan), ps_ca)
function predict(p, u_vec)
CURRENT_U[] = u_vec
# dense solve, then evaluate the interpolant on the loss grid (2×N_T matrix, same layout as Y_list[i])
Array(solve(remake(prob, p=p), Tsit5(); dense=true, sensealg=ReverseDiffAdjoint())(tvec))
end
const mse_loss = Lux.MSELoss()
function loss(p, (u_list, y_list))
ŷ_list = predict.(Ref(p), u_list)
mean(mse_loss.(ŷ_list, y_list))
end
println("Training minimal Neural ODE (Duffing, fs=96 kHz, T=1s)...")
optf = OptimizationFunction(loss, Optimization.AutoReverseDiff(compile=true))
Optimization.solve(
OptimizationProblem(optf, ps_ca, loader),
Optimisers.Adam(1f-3);
epochs=50,
callback=(s, l) -> (println("L: $l"); false),
)
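For reference, this is the saveat=tvec variant I'm contrasting against (just a sketch reusing prob, tvec and CURRENT_U from the MWE above, with the same sensealg as in predict):
# Sketch: save directly on the loss grid instead of interpolating a dense solution.
# Array(sol) is then a 2×N_T matrix with the same layout as Y_list[i].
function predict_saveat(p, u_vec)
    CURRENT_U[] = u_vec
    Array(solve(remake(prob, p=p), Tsit5(); saveat=tvec, sensealg=ReverseDiffAdjoint()))
end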
And here is the JAX version:
import time
import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import optax
# ============================================================================
# Configuration (keep small, run fast)
# ============================================================================
N_STATES = 2 # Duffing: [x, v]
FS = 96_000.0 # Hz
DT = 1.0 / FS
T1 = 1.0
BATCH_SIZE = 16
SEED = 0
F_LOW = 0.5
F_HIGH = 200.0
U_SCALE = 10.0
HIDDEN_SIZE = 64
DEPTH = 2
N_EPOCHS = 50
LR = 1e-3
# Duffing parameters (chosen to be nonlinear but not stiff at this dt)
M = 1.0
F_N = 5.0 # Hz
ZETA = 0.1
OMEGA_N = 2 * np.pi * F_N
C = 2 * ZETA * M * OMEGA_N
K0 = M * OMEGA_N**2
K_CUBIC = 1_000.0
# ============================================================================
# Signals (JAX FFT bandpass; batch-friendly)
# ============================================================================
def make_ts(t1: float = T1) -> jnp.ndarray:
n_time = int(t1 / DT) + 1
return jnp.arange(n_time, dtype=jnp.float32) * jnp.array(DT, dtype=jnp.float32)
def make_forcing_batch(key: jax.Array, *, n_batch: int, ts: jnp.ndarray) -> jnp.ndarray:
"""Generate white noise excitation u(t) with per-trajectory normalization."""
u = jr.normal(key, (n_batch, int(ts.shape[0])), dtype=jnp.float32)
u = u / jnp.std(u, axis=-1, keepdims=True)
return u * jnp.array(U_SCALE, dtype=jnp.float32)
# ============================================================================
# Duffing ground truth (stacked batched solve + dense output)
# ============================================================================
def duffing_vector_field(t, state, u_interp):
"""Forced Duffing oscillator, supports (n_states,) and (batch, n_states)."""
u = u_interp.evaluate(t)
x = state[..., 0]
v = state[..., 1]
k_eff = K0 * (1.0 + K_CUBIC * x**2)
dxdt = v
dvdt = (u - C * v - k_eff * x) / M
return jnp.stack([dxdt, dvdt], axis=-1)
@eqx.filter_jit
def simulate_duffing_batch(ts: jnp.ndarray, y0_batch: jnp.ndarray, u_coeffs_batch: tuple) -> jnp.ndarray:
u_interp = diffrax.CubicInterpolation(ts, u_coeffs_batch)
sol = diffrax.diffeqsolve(
diffrax.ODETerm(duffing_vector_field),
diffrax.Tsit5(),
ts[0],
ts[-1],
dt0=None,
y0=y0_batch,
args=u_interp,
saveat=diffrax.SaveAt(dense=True),
stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-6),
max_steps=1_000_000,
)
ys = jax.vmap(sol.evaluate)(ts) # (time, batch, state)
return jnp.transpose(ys, (1, 2, 0)) # (batch, state, time)
# ============================================================================
# Neural ODE (dense output, stacked batch solve)
# ============================================================================
def _time_first_coeffs(ts: jnp.ndarray, coeffs: tuple) -> tuple:
nseg = int(ts.shape[0] - 1)
def _time_first(c: jnp.ndarray) -> jnp.ndarray:
if c.shape[0] == nseg:
return c
if c.ndim >= 2 and c.shape[1] == nseg:
return jnp.swapaxes(c, 0, 1)
raise ValueError(
"CubicInterpolation coeffs must have leading dimension (times-1) "
"or second dimension (times-1) for batched coeffs."
)
return jax.tree.map(_time_first, coeffs)
class NeuralODEFunc(eqx.Module):
"""MLP vector field: [state, u] -> dstate/dt. Supports stacked (batch, state)."""
mlp: eqx.nn.MLP
def __init__(self, *, n_states: int, hidden_size: int, depth: int, key: jax.Array):
self.mlp = eqx.nn.MLP(
in_size=n_states + 1,
out_size=n_states,
width_size=hidden_size,
depth=depth,
activation=jax.nn.gelu,
final_activation=jax.nn.identity,
key=key,
)
def __call__(self, t, state, u_interp):
u = u_interp.evaluate(t)
if jnp.ndim(state) == 1:
xu = jnp.concatenate([state, jnp.atleast_1d(u)])
return self.mlp(xu)
u = jnp.asarray(u)
xu = jnp.concatenate([state, u[..., None]], axis=-1) # (batch, in)
return jax.vmap(self.mlp)(xu)
class NeuralODE(eqx.Module):
func: NeuralODEFunc
n_states: int = eqx.field(static=True)
rtol: float = eqx.field(static=True)
atol: float = eqx.field(static=True)
max_steps: int = eqx.field(static=True)
def __init__(
self,
*,
n_states: int,
hidden_size: int,
depth: int,
rtol: float,
atol: float,
max_steps: int,
key: jax.Array,
):
self.n_states = int(n_states)
self.rtol = float(rtol)
self.atol = float(atol)
self.max_steps = int(max_steps)
self.func = NeuralODEFunc(n_states=n_states, hidden_size=hidden_size, depth=depth, key=key)
def batch_call(self, ts: jnp.ndarray, y0_batch: jnp.ndarray, u_coeffs_batch: tuple) -> jnp.ndarray:
coeffs_time_first = _time_first_coeffs(ts, u_coeffs_batch)
u_interp = diffrax.CubicInterpolation(ts, coeffs_time_first)
sol = diffrax.diffeqsolve(
diffrax.ODETerm(self.func),
diffrax.Tsit5(),
ts[0],
ts[-1],
dt0=None,
y0=y0_batch,
args=u_interp,
saveat=diffrax.SaveAt(dense=True),
stepsize_controller=diffrax.PIDController(rtol=self.rtol, atol=self.atol),
max_steps=self.max_steps,
adjoint=diffrax.RecursiveCheckpointAdjoint(),
)
ys = jax.vmap(sol.evaluate)(ts) # (time, batch, state)
return jnp.transpose(ys, (1, 2, 0)) # (batch, state, time)
# ============================================================================
# Training
# ============================================================================
@eqx.filter_jit
def predict(model: NeuralODE, ts: jnp.ndarray, u_coeffs_batch: tuple, y0_batch: jnp.ndarray) -> jnp.ndarray:
return model.batch_call(ts, y0_batch, u_coeffs_batch)
@eqx.filter_jit
@eqx.filter_value_and_grad
def loss_and_grad(
model: NeuralODE, ts: jnp.ndarray, u_coeffs_batch: tuple, y_target: jnp.ndarray
) -> jnp.ndarray:
y0_batch = y_target[:, :, 0]
y_pred = predict(model, ts, u_coeffs_batch, y0_batch)
return jnp.mean((y_pred - y_target) ** 2)
@eqx.filter_jit
def train_step(
model: NeuralODE,
opt_state,
ts: jnp.ndarray,
u_coeffs_batch: tuple,
y_target: jnp.ndarray,
optimizer: optax.GradientTransformation,
):
loss, grads = loss_and_grad(model, ts, u_coeffs_batch, y_target)
updates, opt_state = optimizer.update(grads, opt_state, model)
model = eqx.apply_updates(model, updates)
return loss, model, opt_state
# ============================================================================
# Benchmark
# ============================================================================
def run_mwe(*, n_warmup: int = 2, n_runs: int = 5) -> dict[str, float]:
devices = jax.devices()
gpu_devices = [d for d in devices if d.platform != "cpu"]
ts = make_ts()
n_time = int(ts.shape[0])
print("\n" + "=" * 72)
print("Batched Neural ODE MWE (Duffing) — JAX/Diffrax")
print("=" * 72)
print(f"Batch size: {BATCH_SIZE}")
print(f"Duration: {T1:.3f}s")
print(f"Sampling Frequency: {FS:.1f} Hz")
print(f"Time points: {n_time}")
print(f"Backend: {jax.default_backend()}")
print(f"Device: {gpu_devices[0] if gpu_devices else devices[0]}")
print("=" * 72)
key = jr.PRNGKey(SEED)
key, data_key, model_key = jr.split(key, 3)
# Data (batch of trajectories)
u_batch = make_forcing_batch(data_key, n_batch=BATCH_SIZE, ts=ts) # (B, T)
u_coeffs_batch = diffrax.backward_hermite_coefficients(ts, u_batch.T) # time-first
y0_batch = jnp.zeros((BATCH_SIZE, N_STATES), dtype=jnp.float32)
y_target = simulate_duffing_batch(ts, y0_batch, u_coeffs_batch)
jax.block_until_ready(y_target)
# Normalize targets per-trajectory, per-state (helps training stability)
sigma_y = jnp.std(y_target, axis=-1, keepdims=True)
y_target = y_target / sigma_y
model = NeuralODE(
n_states=N_STATES,
hidden_size=HIDDEN_SIZE,
depth=DEPTH,
rtol=1e-3,
atol=1e-6,
max_steps=1_000_000,
key=model_key,
)
n_params = sum(x.size for x in jax.tree.leaves(eqx.filter(model, eqx.is_array)))
print(f"Parameters: {n_params}")
# Warmup compile
for _ in range(n_warmup):
y0_batch = y_target[:, :, 0]
y_pred = predict(model, ts, u_coeffs_batch, y0_batch)
jax.block_until_ready(y_pred)
loss_val, _grads = loss_and_grad(model, ts, u_coeffs_batch, y_target)
jax.block_until_ready(loss_val)
# Benchmark
forward_times_s: list[float] = []
for _ in range(n_runs):
y0_batch = y_target[:, :, 0]
t0 = time.perf_counter()
y_pred = predict(model, ts, u_coeffs_batch, y0_batch)
jax.block_until_ready(y_pred)
forward_times_s.append(time.perf_counter() - t0)
grad_times_s: list[float] = []
for _ in range(n_runs):
t0 = time.perf_counter()
loss_val, _grads = loss_and_grad(model, ts, u_coeffs_batch, y_target)
jax.block_until_ready(loss_val)
grad_times_s.append(time.perf_counter() - t0)
forward_ms = 1000 * float(np.mean(forward_times_s))
grad_ms = 1000 * float(np.mean(grad_times_s))
print(f"Forward: {forward_ms:.2f} ms/batch")
print(f"Fwd+Bwd: {grad_ms:.2f} ms/batch")
# Training (full batch)
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.nadam(LR))
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
t_train0 = time.perf_counter()
losses: list[float] = []
for epoch in range(1, N_EPOCHS + 1):
t0 = time.perf_counter()
loss, model, opt_state = train_step(model, opt_state, ts, u_coeffs_batch, y_target, optimizer)
jax.block_until_ready(loss)
losses.append(float(loss))
if epoch == 1 or epoch % 10 == 0:
print(
f"Epoch {epoch:3d}: loss={float(loss):.6g}, "
f"NRMSE={np.sqrt(float(loss)):.6g}, time={time.perf_counter() - t0:.3f}s"
)
total_s = time.perf_counter() - t_train0
print(f"Train total: {total_s:.2f}s ({total_s / N_EPOCHS:.3f}s/epoch)")
final_loss = float(losses[-1])
return {"forward_ms": forward_ms, "grad_ms": grad_ms, "final_nrmse": float(np.sqrt(final_loss))}
if __name__ == "__main__":
run_mwe()
JAX output:
========================================================================
Batch size: 16
Duration: 1.000s
Sampling Frequency: 96000.0 Hz
Time points: 96001
Backend: cpu
Device: TFRT_CPU_0
========================================================================
Parameters: 4546
Forward: 184.39 ms/batch
Fwd+Bwd: 673.00 ms/batch
Epoch 1: loss=1.03887, NRMSE=1.01925, time=1.390s
Epoch 10: loss=1.01628, NRMSE=1.00811, time=1.781s
Epoch 20: loss=1.01607, NRMSE=1.008, time=2.712s
Epoch 30: loss=1.01617, NRMSE=1.00805, time=2.297s
Epoch 40: loss=1.01596, NRMSE=1.00795, time=2.992s
Epoch 50: loss=1.01573, NRMSE=1.00784, time=3.419s
Train total: 119.58s (2.392s/epoch)
Sidenote: I'm no Enzyme or Reactant wizard, so I couldn't get the benefit of either! Enzyme didn't produce any nonzero gradients (so the loss stayed constant), and Reactant didn't play well with the ODE solver (as far as I understood it).
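For completeness, this is roughly the combination I mean by the Enzyme attempt (a sketch; the exact adjoint/autojacvec pairing is my guess and may well be where I went wrong):
# Sketch of the Enzyme setup I'm describing (assumed combination, not a working recipe):
optf_enzyme = OptimizationFunction(loss, Optimization.AutoEnzyme())  # outer AD via Enzyme
# and inside predict, an Enzyme VJP for the adjoint, e.g.
# sensealg = GaussAdjoint(autojacvec = EnzymeVJP())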
Also: the Julia example works fine for smaller sampling rates, which indicates the problem is the 96k sampled points in tvec!
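For scale, here is how I time a single full-batch forward pass to compare against the JAX forward number and to see how it grows with FS (a sketch using BenchmarkTools, which is already in my environment):
# Sketch: time one forward evaluation of the loss on the full batch;
# rerun the script with a smaller FS to see how this scales with the number of points in tvec.
using BenchmarkTools
batch = (U_list, Y_list)  # here the full data set is one loader batch
@btime loss($ps_ca, $batch)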
If anybody has any feedback, ideas, experience, or working MWEs, I'm all ears and it would be awesome!
Here's my Julia version info (for the Reactant and Enzyme stuff I tried it on v1.11.7):
julia> versioninfo()
Julia Version 1.12.2
Commit ca9b6662be4 (2025-11-20 16:25 UTC)
Build Info:
Official https://julialang.org release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 16 × AMD Ryzen AI 7 350 w/ Radeon 860M
WORD_SIZE: 64
LLVM: libLLVM-18.1.7 (ORCJIT, znver5)
GC: Built with stock GC
Threads: 1 default, 1 interactive, 1 GC (on 16 virtual cores)
pkg> status
Project Loudspeaker v0.1.0
Status `~/loudspeaker-thesis/Loudspeaker.jl/Project.toml`
[47edcb42] ADTypes v1.20.0
[21141c5a] AMDGPU v2.1.4
[c7e460c6] ArgParse v1.2.0
[6e4b80f9] BenchmarkTools v1.6.3
[336ed68f] CSV v0.10.15
[052768ef] CUDA v5.9.5
[13f3f980] CairoMakie v0.15.8
[d360d2e6] ChainRulesCore v1.26.0
[b0b7db55] ComponentArrays v0.15.30
[992eb4ea] CondaPkg v0.2.33
⌃ [3abffc1c] ControlSystemIdentification v2.11.0
[a6e380b2] ControlSystems v1.15.1
⌅ [717857b8] DSP v0.7.10
[2445eb08] DataDrivenDiffEq v1.11.0
[5b588203] DataDrivenSparse v0.1.3
[a93c6f00] DataFrames v1.8.1
[82cc6244] DataInterpolations v8.8.0
⌃ [aae7a2af] DiffEqFlux v4.4.0
[071ae1c0] DiffEqGPU v3.9.0
[163ba53b] DiffResults v1.1.0
[a0c0ee7d] DifferentiationInterface v0.7.12
[06fc5a27] DynamicQuantities v1.10.0
⌃ [7da242da] Enzyme v0.13.108
[7a1cc6ca] FFTW v1.10.0
⌅ [f6369f11] ForwardDiff v0.10.39
[f67ccb44] HDF5 v0.17.2
[40713840] IncompleteLU v0.2.1
⌃ [5903a43b] Infiltrator v1.9.4
⌅ [033835bb] JLD2 v0.5.15
⌃ [7ed4a6bd] LinearSolve v3.48.1
[bdcacae8] LoopVectorization v0.12.173
[c22f76e6] LoudspeakerModels v0.1.0 `lib/LoudspeakerModels`
[b2108857] Lux v1.27.1
[d0bbae9a] LuxCUDA v0.3.4
⌃ [eb30cadb] MLDatasets v0.7.18
[f1d291b0] MLUtils v0.4.8
⌅ [961ee093] ModelingToolkit v10.31.1
[f162e290] ModelingToolkitNeuralNets v2.2.0
[16a59e39] ModelingToolkitStandardLibrary v2.25.0
⌃ [872c559c] NNlib v0.9.31
[6fd5a793] Octavian v0.3.29
[08131aa3] OpenCL v0.10.8
[429524aa] Optim v1.13.3
⌃ [3bd65402] Optimisers v0.4.6
[7f7a1694] Optimization v5.2.0
[dfa73e59] OptimizationODE v0.1.3
[36348300] OptimizationOptimJL v0.4.8
[42dfb2eb] OptimizationOptimisers v0.3.15
[500b13db] OptimizationPolyalgorithms v0.3.4
[892fee11] OptimizationSophia v1.2.1
[1dea7af3] OrdinaryDiffEq v6.103.0
⌃ [5960d6e9] OrdinaryDiffEqFIRK v1.16.0
[b1df2697] OrdinaryDiffEqTsit5 v1.5.0
[d7d3b36b] ParameterSchedulers v0.4.3
[91a5bcdd] Plots v1.41.2
[c3e4b0f8] Pluto v0.20.21
[98d1487c] PolyesterForwardDiff v0.1.3
[d236fae5] PreallocationTools v0.4.34
[6099a3de] PythonCall v0.9.30
⌃ [3c362404] Reactant v0.2.181
[731186ca] RecursiveArrayTools v3.39.0
[189a3867] Reexport v1.2.2
⌃ [7c2d2b1e] ReservoirComputing v0.12.3
[37e2e3b7] ReverseDiff v1.16.1
[295af30f] Revise v3.12.3
[1bc83da4] SafeTestsets v0.1.0
[0bca4576] SciMLBase v2.128.0
[1ed8b502] SciMLSensitivity v7.90.0
[53ae85a6] SciMLStructures v1.7.0
[df1fea92] SignalAnalysis v0.10.4
[de6bee2f] SimpleChains v0.4.7
[90137ffa] StaticArrays v1.9.15
[2913bbd2] StatsBase v0.34.9
[f3b207a7] StatsPlots v0.15.8
[2efcf032] SymbolicIndexingInterface v0.3.46
⌅ [0c5d862f] Symbolics v6.58.0
[b1e4fcf2] TestSignals v0.1.0 `lib/TestSignals`
[ac1d9e8a] ThreadsX v0.1.12
[1986cc42] Unitful v1.27.0
[e88e6eb3] Zygote v0.7.10
[ade2ca70] Dates v1.11.0
[37e2e46d] LinearAlgebra v1.12.0
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated`