Hi! I'm new to Julia and SciML, but I am really curious to see how combining ODEs with a neural network could improve the predictive performance of the model we are developing for solvent-based CO2 capture. My current setup includes:
- dynamic states,
- multiple independent time series of real data,
- exogenous inputs,
- unknown dynamics modelled as a neural network (NN).
I use DataInterpolations.jl to handle the inputs. I want to train the NN by minimising a loss function (MSE across all the independent time series).
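Concretely, the loss is the sum over the series of each series' mean squared error between the noisy data $Y_i$ and the simulated states $\hat{X}_i(\theta)$ (matching loss_multi in the code below):

$$\mathcal{L}(\theta) = \sum_{i=1}^{N_\text{series}} \operatorname{mean}\!\left( \big(Y_i - \hat{X}_i(\theta)\big)^2 \right)$$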
I have made an MWE (tried to make it as minimal as possible) with 5 series of data, each with ~200 datapoints.
My initial challenge is the following: computing the gradient of the loss function with AD (using Zygote.jl or ForwardDiff.jl) takes a long time. It seems like I am missing something or doing something wrong.
using BenchmarkTools
@btime Zygote.gradient(θ -> loss_multi(θ, $prob_list, $dataset, $solver, $sensealg1), $_θ0)
@btime ForwardDiff.gradient(θ -> loss_multi(θ, $prob_list, $dataset, $solver, $sensealg1), $_θ0)
outputs:
7.955 s (64378604 allocations: 3.56 GiB)
173.575 ms (5457036 allocations: 734.76 MiB)
which seems like a lot for just computing the gradient?! Below is the MWE; maybe someone can spot something that is clearly set up in a stupid/inefficient way.
The true system (used for data generation) is:

$$\frac{dx_1}{dt} = -\alpha \, (x_1 - u_1) + \sin(\gamma \, x_2) \, u_2$$

$$\frac{dx_2}{dt} = -\beta \, (x_2 - u_3)$$

and the system with the neural network is:

$$\frac{dx_1}{dt} = -\alpha \, (x_1 - u_1) + \mathrm{NN}(x_2, u_2)$$

$$\frac{dx_2}{dt} = -\beta \, (x_2 - u_3)$$

i.e. a system which is partially known and partially unknown. Here $x$ are the states and $u$ are the exogenous inputs.
Below is the code:
Package imports and data generation:
# import packages
import OrdinaryDiffEq as ODE
import SciMLSensitivity as SMS
import ComponentArrays
import DataInterpolations as DI
using Zygote
using ForwardDiff
# Standard Libraries
import Statistics
# External Libraries
import Lux
import StableRNGs
import StaticArrays as SA
using Plots
# --- Type definitions ---
struct Series
    t::Vector{Float64}                          # length T
    Y::Matrix{Float64}                          # (n_y, T)
    U_funcs::Vector{DI.ConstantInterpolation}   # exogenous input functions
    X0::NTuple{2,Float64}                       # initial state (x1, x2)
    tspan::Tuple{Float64,Float64}
end
struct Dataset
    series::Vector{Series}
end
## Generate simulated data ##
# Number of data series
N_series = 5
# Time vector for each series (different lengths)
tf = rand(5.0:10.0, N_series) # final time between 5 and 10 h
dt = 2/60 # time step of 2 min
t = [0.0:dt:tf[i] for i in 1:N_series] # time vector for each series
# Input functions for each series
# input values:
u1 = [[1.0, 0.5, 2.5], [0.8, 0.7, 2.0], [1.2, 0.4, 3.0], [1.1, 0.6, 2.8], [0.9, 0.55, 2.2]]
u2 = [[2.0, 1.5, 3.5], [1.8, 1.7, 3.0], [2.2, 1.4, 4.0], [2.1, 1.6, 3.8], [1.9, 1.55, 3.2]]
u3 = [[3.0, 2.5, 4.5], [2.8, 2.7, 4.0], [3.2, 2.4, 5.0], [3.1, 2.6, 4.8], [2.9, 2.55, 4.2]]
# Time points of steps:
tu1 = [[0.0, tf[i]/2, 4tf[i]/5] for i in 1:N_series]
tu2 = [[0.0, tf[i]/3, 2tf[i]/3] for i in 1:N_series]
tu3 = [[0.0, tf[i]/4, 3tf[i]/4] for i in 1:N_series]
# Create DataInterpolations.jl ConstantInterpolation functions for each series
U_funcs = Vector{DI.ConstantInterpolation}[]
for i in 1:N_series
    push!(U_funcs, [
        DI.ConstantInterpolation(u1[i], tu1[i]; extrapolation = DI.ExtrapolationType.Constant),
        DI.ConstantInterpolation(u2[i], tu2[i]; extrapolation = DI.ExtrapolationType.Constant),
        DI.ConstantInterpolation(u3[i], tu3[i]; extrapolation = DI.ExtrapolationType.Constant)
    ])
end
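# Quick sanity check (illustrative, not part of the pipeline): with the default
# left-continuous ConstantInterpolation, each input holds its level until the
# next step time, and Constant extrapolation holds the last level beyond the grid.
U_funcs[1][1](0.0)    # should return 1.0 (first level of u1 for series 1)
U_funcs[1][1](tf[1])  # should return 2.5 (last level, held past the step at 4tf[1]/5)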
# Create ODE problem and solve to generate data
p_true = [0.5, 0.3, 3.0] # true parameters
# True system equation
function ffun(x, p, t, u)
    alpha, beta, gamma = p
    dx1 = -alpha * (x[1] - u[1](t)) + sin(x[2] * gamma) * u[2](t)
    dx2 = -beta * (x[2] - u[3](t))
    return [dx1, dx2]
end
# Initial conditions for each series
x0_series = [[2.0, 1.0], [1.5, 1.5], [2.5, 0.5], [2.2, 1.2], [1.8, 0.8]]
# Generate data for each series
X = Matrix{Float64}[]
for i in 1:N_series
    prob = ODE.ODEProblem((x, p, t) -> ffun(x, p, t, U_funcs[i]),
                          x0_series[i], (0.0, tf[i]), p_true)
    sol = ODE.solve(prob, ODE.Tsit5(), saveat = t[i])
    push!(X, Array(sol))
end
# Add noise to the data in each series
rng = StableRNGs.StableRNG(1234)
Y = Matrix{Float64}[]
for i in 1:N_series
    ȳ = Statistics.mean(X[i], dims = 2)
    noise_magnitude = 1e-2
    Y_noise = X[i] .+ (noise_magnitude * ȳ) .* randn(rng, eltype(X[i]), size(X[i]))
    push!(Y, Y_noise)
end
# Build Dataset
dataset = Dataset(Series[])
for i in 1:N_series
    s = Series(t[i], Y[i], U_funcs[i], (x0_series[i][1], x0_series[i][2]), (0.0, tf[i]))
    push!(dataset.series, s)
end
Define the model and loss function, and compute the gradient:
# --- Model Definition ---
# Neural Network
neural_net = Lux.Chain(
    Lux.Dense(2, 10, x -> exp.(-(x.^2))),   # Gaussian activation
    Lux.Dense(10, 10, x -> exp.(-(x.^2))),
    Lux.Dense(10, 1)                        # linear head
)
rng = StableRNGs.StableRNG(1234)
θ0, st = Lux.setup(rng, neural_net)   # θ0: initial NN parameters (used by make_problems below)
const _st = st
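# Quick sanity check (illustrative): a single forward pass through the untrained net.
# Lux.Dense(2, 10, ...) expects a length-2 input; the call returns (output, state).
y_check, _ = neural_net([1.0, 2.0], θ0, _st)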
# in-place ude_dynamics!
function ude_dynamics!(dx, x, p, t, U_funcs)
    # NN part
    input_NN = [x[2], U_funcs[2](t)]
    NN, _ = neural_net(input_NN, p, _st)   # Lux forward pass returns (output, state)
    dx[1] = -0.5 * (x[1] - U_funcs[1](t)) + NN[1]
    dx[2] = -0.3 * (x[2] - U_funcs[3](t))
    return nothing
end
# out-of-place ude_dynamics
function ude_dynamics(x, p, t, U_funcs)
    # NN part
    input_NN = [x[2], U_funcs[2](t)]
    NN, _ = neural_net(input_NN, p, _st)   # Lux forward pass
    dx1 = -0.5 * (x[1] - U_funcs[1](t)) + NN[1]
    dx2 = -0.3 * (x[2] - U_funcs[3](t))
    return [dx1, dx2]
end
# Build ODE problems
function make_problems(ds::Dataset, θ0)
    prob_list = ODE.ODEProblem[]
    for s in ds.series
        # in-place ude
        # f!(du, u, θ, t) = ude_dynamics!(du, u, θ, t, s.U_funcs)
        # push!(prob_list, ODE.ODEProblem(f!, collect(s.X0), s.tspan, θ0))
        # out-of-place ude
        f(u, θ, t) = ude_dynamics(u, θ, t, s.U_funcs)
        push!(prob_list, ODE.ODEProblem(f, collect(s.X0), s.tspan, θ0))
    end
    return prob_list
end
# Define loss function over all series
function loss_multi(θ, prob_list, dataset, solver, sensealg1)
    loss = 0.0
    for i in eachindex(prob_list)
        _prob = ODE.remake(prob_list[i], p = θ)
        sol = ODE.solve(_prob, solver;
                        saveat = dataset.series[i].t,
                        save_everystep = false, dense = false,
                        abstol = 1e-6, reltol = 1e-6,
                        sensealg = sensealg1)
        X̂ = Array(sol)   # (2, T)
        loss += Statistics.mean(abs2, dataset.series[i].Y .- X̂)
    end
    return loss
end
# Create ODE problems for all series
prob_list = make_problems(dataset, θ0)
# Solver and sensitivity algorithm
solver = ODE.Tsit5()
sensealg1 = SMS.QuadratureAdjoint(autojacvec = SMS.ZygoteVJP())
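# Other adjoint configurations that could be benchmarked the same way
# (untested sketch; constructors from SciMLSensitivity):
# sensealg2 = SMS.InterpolatingAdjoint(autojacvec = SMS.ReverseDiffVJP())
# sensealg3 = SMS.GaussAdjoint(autojacvec = SMS.ZygoteVJP())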
_θ0 = ComponentArrays.ComponentVector{Float64}(θ0)
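# Sanity check: one forward evaluation of the loss before timing gradients
loss_multi(_θ0, prob_list, dataset, solver, sensealg1)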
using BenchmarkTools
@btime Zygote.gradient(θ -> loss_multi(θ, $prob_list, $dataset, $solver, $sensealg1), $_θ0)
@btime ForwardDiff.gradient(θ -> loss_multi(θ, $prob_list, $dataset, $solver, $sensealg1), $_θ0)
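For completeness, a quick check that the two backends agree before worrying about speed (a sketch; with ComponentArrays both gradients come back as ComponentVectors):

g_zyg = Zygote.gradient(θ -> loss_multi(θ, prob_list, dataset, solver, sensealg1), _θ0)[1]
g_fwd = ForwardDiff.gradient(θ -> loss_multi(θ, prob_list, dataset, solver, sensealg1), _θ0)
maximum(abs, collect(g_zyg) .- collect(g_fwd))   # should be small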

