Hi @ChrisRackauckas, I went ahead to use the Ensemble interface for training the UDE.
Here is a sample code for the same. However, I am facing issues. Kindly help me out
using CSV, DataFrames, XLSX
using Optimization, OptimizationPolyalgorithms,OptimizationOptimisers, OptimizationOptimJL, LineSearches, OptimizationSpeedMapping
using ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity
using DataInterpolations, RecursiveArrayTools, LinearAlgebra, DataDrivenSparse, Statistics
using OrdinaryDiffEq, DifferentialEquations, Plots, DiffEqParamEstim, Sundials
using ForwardDiff, OptimizationOptimJL, OptimizationBBO, OptimizationNLopt
# External Libraries
using ComponentArrays, Lux, Zygote, Plots, StableRNGs
gr()
rng = StableRNG(111)
tspan = (0.0, 12.0)
p_ = [0.5616710187190637, 0.00099999981807704, 5.7506499515119997e-5, 5.8624656433374085,
0.00038678047665933183, 0.10894746625822622]
# Multilayer FeedForward
const U1 = Lux.Chain(Lux.Dense(1, 15, tanh), Lux.Dense(15, 5, tanh), Lux.Dense(5, 5, tanh),
Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p1, st1 = Lux.setup(rng, U1)
const _st1 = st1
# Multilayer FeedForward
const U2 = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 5, tanh), Lux.Dense(5, 5, tanh),
Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p2, st2 = Lux.setup(rng, U2)
const _st2 = st2
# Multilayer FeedForward
const U3 = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 5, tanh), Lux.Dense(5, 5, tanh),
Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p3, st3 = Lux.setup(rng, U3)
const _st3 = st3
p = [p1, p2, p3]
# Define the hybrid model
function ude_dynamics(du, u, p, t)
û1 = U1([u[1]], p[1], _st1)[1] # Network prediction
û2 = U2([u[2]], p[2], _st2)[1] # Network prediction
û3 = U3([u[3]], p[3], _st3)[1] # Network prediction
du[1] = ((p_[1] * u[1]) .+ û1[1]) .+ ((f1(t)./f4(t)).*u[1])
du[2] = û2[1] .- (f1(t)./f4(t))
du[3] = (-û3[1] .* u[1]) .+ ((f2(t)./f4(t))) .+ ((f3(t)/f4(t)))
end
# Closure with the known parameter
nn_dynamics(du, u, p, t) = ude_dynamics(du, u, p, t,)
# Define the problem
u0 = [1.03, 0, 30.0]
prob_nn = ODEProblem(ude_dynamics, u0, tspan)
function train_prob_func(prob_nn, i , repeat)
f1 = ConstantInterpolation(
train_feed_data[i][:, 4], train_feed_data[i][:, 1], extrapolate=true)
f2 = ConstantInterpolation(
train_feed_data[i][:, 3], train_feed_data[i][:, 1], extrapolate=true)
f3 = ConstantInterpolation(
train_feed_data[i][:, 2], train_feed_data[i][:, 1], extrapolate=true)
f4 = ConstantInterpolation(
train_feed_data[i][:, 5], train_feed_data[i][:, 1], extrapolate=true)
tstops = train_data[i][:, 1]
function condition(u,t,integrator)
t in tstops
end
function affect!(integrator)
integrator.u[1] -= 0.020
end
cb = DiscreteCallback(condition, affect!; save_positions=(true,false))
remake(prob_nn; u0 = train_ic[i],
tspan = (train_tspan_data[i][1], train_tspan_data[i][2]),
p = (p1 , p2 , p3 , f1, f2, f3, f4),
tstops = tstops,
callback = cb)
end
t = train_data[1][:, 1]
function predict(p, X = u0, T = t)
newprob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = p)
enprob = EnsembleProblem(newprob, prob_func = train_prob_func)
sim = solve(enprob, Tsit5(), trajectories = size(train_batches, 1), saveat = T,
abstol = 1e-6, reltol = 1e-6,
sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)))
return sim
end
function loss(θ)
sim = predict(θ)
l = mean(abs2, train_data .- sim)
end
losses = Float64[]
function callback(p, l)
push!(losses, l)
if length(losses) % 50 == 0
println("Current loss after $(length(losses)) iterations: $(losses[end])")
end
return false
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, _) -> loss(x), adtype)
p_initial = ComponentArray(p1 = p1, p2 = p2, p3 = p3)
optprob = Optimization.OptimizationProblem(optf, p_initial)
res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(), callback = callback, maxiters = 15000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
Here is the error message that I get running the code. Can you help me how should I pass NN parameters, other model parameters, and data interpolation objects (f1, f2, f3, f4) together in the EnsembleProb function
ERROR: Adjoint sensitivity analysis functionality requires being able to solve
a differential equation defined by the parameter struct `p`. Thus while
DifferentialEquations.jl can support any parameter struct type, usage
with adjoint sensitivity analysis requires that `p` could be a valid
type for being the initial condition `u0` of an array. This means that
many simple types, such as `Tuple`s and `NamedTuple`s, will work as
parameters in normal contexts but will fail during adjoint differentiation.
To work around this issue for complicated cases like nested structs, look
into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl
or ComponentArrays.jl so that `p` is an `AbstractArray` with a concrete element type.