Universal Differential Equations with Ensemble Problem

I am facing issues trying to develop a UDE model using the Ensemble interface.

Kindly help me out. Here is the sample code and error I get

using CSV, DataFrames, XLSX
using Optimization, OptimizationPolyalgorithms,OptimizationOptimisers, OptimizationOptimJL, LineSearches, OptimizationSpeedMapping
using ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity
using DataInterpolations, RecursiveArrayTools, LinearAlgebra, DataDrivenSparse, Statistics
using OrdinaryDiffEq, DifferentialEquations, Plots, DiffEqParamEstim, Sundials
using ForwardDiff, OptimizationOptimJL, OptimizationBBO, OptimizationNLopt
# External Libraries
using ComponentArrays, Lux, Zygote, Plots, StableRNGs
gr()

# Fixed seed so the network initialisations are reproducible.
rng = StableRNG(111)

tspan = (0.0, 12.0)

# Known (mechanistic) model parameters; only p_[1] is used in ude_dynamics.
p_ = [0.5616710187190637, 0.00099999981807704, 5.7506499515119997e-5, 5.8624656433374085,
0.00038678047665933183, 0.10894746625822622]

# Multilayer FeedForward
const U1 = Lux.Chain(Lux.Dense(1, 15, tanh), Lux.Dense(15, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p1, st1 = Lux.setup(rng, U1)
const _st1 = st1

# Multilayer FeedForward
const U2 = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p2, st2 = Lux.setup(rng, U2)
const _st2 = st2

# Multilayer FeedForward
const U3 = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p3, st3 = Lux.setup(rng, U3)
const _st3 = st3

# NOTE(review): a Vector of Lux parameter NamedTuples is not an AbstractArray
# of numbers — this is exactly the parameter layout the adjoint sensitivity
# error below complains about. Use ComponentArray(p1 = p1, p2 = p2, p3 = p3).
p = [p1, p2, p3]

# Define the hybrid model
# In-place UDE right-hand side: three scalar states, each corrected by its own
# neural network (U1–U3 model the unknown dynamics). `p` must support
# `p[1]`, `p[2]`, `p[3]` — the three NN parameter sets.
# NOTE(review): `f1`–`f4` are referenced as globals here, but they are only
# created locally inside `train_prob_func`, so this function cannot see the
# per-trajectory interpolants as written — they need to travel through `p`
# (or a functor field, as in the later revision).
function ude_dynamics(du, u, p, t)
    û1 = U1([u[1]], p[1], _st1)[1] # Network prediction
    û2 = U2([u[2]], p[2], _st2)[1] # Network prediction
    û3 = U3([u[3]], p[3], _st3)[1] # Network prediction

    # Known physics plus NN corrections; the broadcast dots are no-ops here
    # because every operand is a scalar.
    du[1] = ((p_[1] * u[1]) .+ û1[1]) .+ ((f1(t)./f4(t)).*u[1])
    du[2] = û2[1] .- (f1(t)./f4(t))
    du[3] = (-û3[1] .* u[1]) .+ ((f2(t)./f4(t))) .+ ((f3(t)/f4(t)))
end

# Closure with the known parameter
# NOTE(review): this wrapper closes over nothing and forwards the same four
# arguments unchanged (the trailing comma is redundant); `prob_nn` below is
# built from `ude_dynamics` directly, so `nn_dynamics` is never used.
nn_dynamics(du, u, p, t) = ude_dynamics(du, u, p, t,)
# Define the problem
u0 = [1.03, 0, 30.0]  # promoted to Vector{Float64}
# NOTE(review): no `p` is supplied here; the per-trajectory `remake` inside
# `train_prob_func` is expected to provide it.
prob_nn = ODEProblem(ude_dynamics, u0, tspan)

# Ensemble `prob_func`: customises the base problem for trajectory `i` —
# builds the four feed interpolants from `train_feed_data[i]`, sets the
# trajectory's initial condition and time span, and attaches a dosing
# callback that removes 0.020 from state 1 at every sample time.
# NOTE(review): packing `p` as a Tuple `(p1, …, f4)` is precisely what the
# posted adjoint error rejects — QuadratureAdjoint needs `p` to be an
# AbstractArray with concrete eltype (e.g. a ComponentArray). Also, this
# discards the optimizer's current parameters and re-uses the initial
# `p1`–`p3` on every call, so the networks would never actually train.
# NOTE(review): `train_feed_data`, `train_data`, `train_ic`, and
# `train_tspan_data` are not defined anywhere in this post.
function train_prob_func(prob_nn, i , repeat)
    # Zero-order-hold interpolants over the trajectory's time column (col 1).
    f1 = ConstantInterpolation(
        train_feed_data[i][:, 4], train_feed_data[i][:, 1], extrapolate=true)
    f2 = ConstantInterpolation(
        train_feed_data[i][:, 3], train_feed_data[i][:, 1], extrapolate=true) 
    f3 = ConstantInterpolation(
        train_feed_data[i][:, 2], train_feed_data[i][:, 1], extrapolate=true)
    f4 = ConstantInterpolation(
        train_feed_data[i][:, 5], train_feed_data[i][:, 1], extrapolate=true)

    # Force the solver to step exactly onto the sample times so the
    # discrete callback can fire there.
    tstops = train_data[i][:, 1]

    function condition(u,t,integrator) 
        t in tstops
    end
    
    # Discrete dose: remove 0.020 from the first state.
    function affect!(integrator)
        integrator.u[1] -= 0.020
    end
    
    cb = DiscreteCallback(condition, affect!; save_positions=(true,false))
    remake(prob_nn; u0 = train_ic[i], 
           tspan = (train_tspan_data[i][1], train_tspan_data[i][2]),
           p = (p1 , p2 , p3 , f1, f2, f3, f4), 
           tstops = tstops, 
           callback = cb)
end


# Time grid of the first training trajectory; used as the default `saveat`
# grid in `predict`. NOTE(review): `train_data` is not defined in this post.
t = train_data[1][:, 1]

"""
    predict(p, X = u0, T = t)

Solve the training ensemble for parameters `p`, starting from state `X` and
saving on the time grid `T`. Each trajectory is customised by
`train_prob_func`; the `EnsembleSolution` is returned.
"""
function predict(p, X = u0, T = t)
    base = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = p)
    ensemble = EnsembleProblem(base, prob_func = train_prob_func)
    return solve(ensemble, Tsit5();
                 trajectories = size(train_batches, 1),
                 saveat = T, abstol = 1e-6, reltol = 1e-6,
                 sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true)))
end

"""
    loss(θ)

Mean squared error between the training data and the ensemble prediction
obtained with parameters `θ`.
"""
loss(θ) = mean(abs2, train_data .- predict(θ))

# Running record of the training loss at every optimizer iteration.
losses = Float64[]

"""
    callback(p, l)

Optimization callback: appends the loss `l` to the global `losses` vector and
prints a progress line every 50 iterations. Always returns `false` so the
optimizer never halts early.
"""
function callback(p, l)
    push!(losses, l)
    n = length(losses)
    if iszero(n % 50)
        println("Current loss after $(n) iterations: $(losses[end])")
    end
    return false
end

# Differentiate the loss with Zygote (reverse mode) through the ensemble solve.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, _) -> loss(x), adtype)
# Flat, concretely-typed parameter vector — the layout adjoint sensitivity needs.
p_initial = ComponentArray(p1 = p1, p2 = p2, p3 = p3)
optprob = Optimization.OptimizationProblem(optf, p_initial)


# NOTE(review): `train_prob_func` ignores the optimizer's θ and re-packs the
# initial `p1`–`p3`, so as written Adam cannot actually update the networks.
res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(), callback = callback, maxiters = 15000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

Here is the error message that I get when running the code. Can you help me understand how I should pass the NN parameters, the other model parameters, and the data-interpolation objects (f1, f2, f3, f4) together through the EnsembleProblem's prob_func?

ERROR: Adjoint sensitivity analysis functionality requires being able to solve  
a differential equation defined by the parameter struct `p`. Thus while
DifferentialEquations.jl can support any parameter struct type, usage
with adjoint sensitivity analysis requires that `p` could be a valid
type for being the initial condition `u0` of an array. This means that
many simple types, such as `Tuple`s and `NamedTuple`s, will work as
parameters in normal contexts but will fail during adjoint differentiation.     
To work around this issue for complicated cases like nested structs, look       
into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl
or ComponentArrays.jl so that `p` is an `AbstractArray` with a concrete element type.

An array of named tuples won’t be a valid parameter. Instead, make that into a component array like the tutorials with multiple neural nets show.

I have tried doing that. The issue is that I cannot pack the interpolation objects f1, f2, f3, and f4 along with p1, p2, and p3 into a ComponentArray (refer to train_prob_func).

Hi @ChrisRackauckas, I have modified my code accordingly as we discussed.

However, there are errors relating to gradient calculations because of the abstract type, I guess. Please find the modified code and the error message below:

using CSV, DataFrames, XLSX
using Optimization, OptimizationPolyalgorithms,OptimizationOptimisers, OptimizationOptimJL, LineSearches, OptimizationSpeedMapping
using ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity
using DataInterpolations, RecursiveArrayTools, LinearAlgebra, DataDrivenSparse, Statistics
using OrdinaryDiffEq, DifferentialEquations, Plots, DiffEqParamEstim, Sundials
using ForwardDiff, OptimizationOptimJL, OptimizationBBO, OptimizationNLopt
# External Libraries
using ComponentArrays, Lux, Zygote, Plots, StableRNGs
gr()

rng = StableRNG(111)

"""
    myf(f1, f2, f3, f4)

Callable container bundling the four per-trajectory feed interpolants with
the ODE right-hand side (see the `(f::myf)(du, u, p, t)` method). Each field
gets its own type parameter so the four interpolants do not have to share one
concrete type, while the fields remain fully typed for performance.
"""
struct myf{T1, T2, T3, T4}
    f1::T1
    f2::T2
    f3::T3
    f4::T4
end

tspan = (0.0, 12.0)

# Known (mechanistic) model parameters; only p_[1] is used in the RHS.
p_ = [0.5616710187190637, 0.00099999981807704, 5.7506499515119997e-5, 5.8624656433374085,
0.00038678047665933183, 0.10894746625822622]

# Multilayer FeedForward
const U1 = Lux.Chain(Lux.Dense(1, 15, tanh), Lux.Dense(15, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p1, st1 = Lux.setup(rng, U1)
const _st1 = st1

# Multilayer FeedForward
const U2 = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p2, st2 = Lux.setup(rng, U2)
const _st2 = st2

# Multilayer FeedForward
const U3 = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p3, st3 = Lux.setup(rng, U3)
const _st3 = st3

# Flat, concretely-typed parameter vector — the layout adjoint methods require.
p = ComponentArray(p1 = p1, p2 = p2, p3 = p3)

# Define the hybrid model
# ODE right-hand side as a functor: the interpolants travel on the function
# object `f` (not on `p`), so `p` stays a flat ComponentArray of NN weights.
function (f::myf)(du, u, p, t)
    û1 = U1([u[1]], p.p1, _st1)[1] # Network prediction
    û2 = U2([u[2]], p.p2, _st2)[1] # Network prediction
    û3 = U3([u[3]], p.p3, _st3)[1] # Network prediction

    # Known physics plus NN corrections; f.f1–f.f4 are per-trajectory feeds.
    du[1] = ((p_[1] * u[1]) + û1[1]) + ((f.f1(t)/f.f4(t))*u[1])
    du[2] = û2[1] - (f.f1(t)/f.f4(t))
    du[3] = (-û3[1] * u[1]) + (f.f2(t)/f.f4(t)) + (f.f3(t)/f.f4(t))
end

# Define the problem
u0 = [1.03, 0, 30.0]  # promoted to Vector{Float64}
# NOTE(review): `f` is not defined at this point — `myf` is the type, and an
# instance is only built inside `train_prob_func`. A placeholder instance
# (e.g. `f = myf(zero, zero, zero, one)`) is needed before this line runs.
prob_nn = ODEProblem(f, u0, tspan, p)

"""
    train_prob_func(prob_nn, i, repeat)

Ensemble `prob_func`: builds trajectory `i`'s feed interpolants from the data
table `x[i]`, wraps them in a `myf` right-hand side, and attaches a dosing
callback that removes 0.020 from state 1 at every sample time.
NOTE(review): `x` is not defined anywhere in this post — presumably the list
of per-trajectory data matrices; confirm its layout.
"""
function train_prob_func(prob_nn, i, repeat)
    # Zero-order-hold interpolants of the four feed columns over time (col 5).
    f1 = ConstantInterpolation(x[i][:, 1], x[i][:, 5], extrapolate=true)
    f2 = ConstantInterpolation(x[i][:, 2], x[i][:, 5], extrapolate=true)
    f3 = ConstantInterpolation(x[i][:, 3], x[i][:, 5], extrapolate=true)
    f4 = ConstantInterpolation(x[i][:, 4], x[i][:, 5], extrapolate=true)

    tstops = x[i][:, 6]

    # BUG FIX: the struct is called `myf`; `flows` was an undefined name.
    f = myf(f1, f2, f3, f4)

    # Fire exactly at the sample times (solver must hit `tstops` for this).
    function condition(u, t, integrator)
        t in tstops
    end

    # Discrete dose: remove 0.020 from the first state.
    function affect!(integrator)
        integrator.u[1] -= 0.020
    end

    cb = DiscreteCallback(condition, affect!; save_positions=(true,false))
    return remake(prob_nn; f, tstops=tstops, callback = cb)
end


# Default saveat grid: sample times of the first trajectory.
# NOTE(review): `x` (the training data table) is not defined in this post.
t = x[1][:, 6]

"""
    predict(p, X = u0, T = t)

Solve the training ensemble for parameters `p`, starting from state `X` and
saving on the time grid `T`. Trajectories are customised by
`train_prob_func`; the `EnsembleSolution` is returned.
"""
function predict(p, X = u0, T = t)
    base = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = p)
    ensemble = EnsembleProblem(base, prob_func = train_prob_func)
    return solve(ensemble, Tsit5();
                 trajectories = size(train_batches, 1),
                 saveat = T, abstol = 1e-6, reltol = 1e-6,
                 sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true)))
end

"""
    loss(θ)

Mean squared error between the training data and the ensemble prediction
obtained with parameters `θ`.
"""
loss(θ) = mean(abs2, train_data .- predict(θ))

# Differentiate the loss with Zygote (reverse mode) through the ensemble solve.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
# Flat, concretely-typed parameter vector for the optimizer.
p_initial = ComponentArray(p1 = p1, p2 = p2, p3 = p3)
optprob = Optimization.OptimizationProblem(optf, p_initial)

res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(), maxiters = 15000)

Error Message:

LoadError: MethodError: no method matching (::ODESolution{…})(::Vector{…}, ::Type{…}, ::Nothing, ::Symbol)
The object of type `ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Nothing, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 
15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, 
ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, 
Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, 
Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), 
layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 
1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias 
= 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), 
layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol}, @NamedTuple{tstops::Vector{Any}, callback::CallbackSet{Tuple{}, Tuple{DiscreteCallback{var"#condition#1"{Vector{Any}}, SciMLSensitivity.TrackedAffect{Float64, Vector{Float64}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = 
ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, var"#affect!#2", Nothing, Int64}, typeof(SciMLBase.INITIALIZE_DEFAULT), typeof(SciMLBase.FINALIZE_DEFAULT), Nothing}}}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, OrdinaryDiffEqCore.InterpolationData{ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, 
Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, 
Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias 
= 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), 
layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 
5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias 
= 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 
= ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Nothing, OrdinaryDiffEqTsit5.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Nothing}, SciMLBase.DEStats, Nothing, Nothing, Nothing, Nothing}` exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.

Closest candidates are:
  (::SciMLBase.AbstractODESolution)(::AbstractVector{<:Number}, ::Type{deriv}, ::Nothing, ::Any) where deriv    
   @ SciMLBase C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\solutions\ode_solutions.jl:239
  (::SciMLBase.AbstractODESolution)(::AbstractVector{<:Number}, ::Type{deriv}, ::Any, ::Any) where deriv        
   @ SciMLBase C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\solutions\ode_solutions.jl:327
  (::SciMLBase.AbstractODESolution)(::Number, ::Type{deriv}, ::Nothing, ::Any) where deriv
   @ SciMLBase C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\solutions\ode_solutions.jl:234
  ...

Stacktrace:
  [1] (::ODESolution{…})(t::Vector{…}, ::Type{…}; idxs::Nothing, continuity::Symbol)
    @ SciMLBase C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\solutions\ode_solutions.jl:224
  [2] (::ODESolution{…})(t::Vector{…}, ::Type{…})
    @ SciMLBase C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\solutions\ode_solutions.jl:219
  [3] out_and_ts(_ts::Vector{…}, duplicate_iterator_times::Nothing, sol::ODESolution{…})
    @ SciMLSensitivity C:\Users\AGARWP34\.julia\packages\SciMLSensitivity\ME3jV\src\adjoint_common.jl:668       
  [4] _concrete_solve_adjoint(::ODEProblem{…}, ::Tsit5{…}, ::QuadratureAdjoint{…}, ::Vector{…}, ::ComponentVector{…}, ::SciMLBase.ChainRulesOriginator; save_start::Bool, save_end::Bool, saveat::Vector{…}, save_idxs::Nothing, kwargs::@Kwargs{…})
    @ SciMLSensitivity C:\Users\AGARWP34\.julia\packages\SciMLSensitivity\ME3jV\src\concrete_solve.jl:479       
  [5] _concrete_solve_adjoint
    @ C:\Users\AGARWP34\.julia\packages\SciMLSensitivity\ME3jV\src\concrete_solve.jl:361 [inlined]
  [6] #_solve_adjoint#75
    @ C:\Users\AGARWP34\.julia\packages\DiffEqBase\frOsk\src\solve.jl:1555 [inlined]
  [7] _solve_adjoint
    @ C:\Users\AGARWP34\.julia\packages\DiffEqBase\frOsk\src\solve.jl:1528 [inlined]
  [8] #rrule#4
    @ C:\Users\AGARWP34\.julia\packages\DiffEqBase\frOsk\ext\DiffEqBaseChainRulesCoreExt.jl:26 [inlined]        
  [9] rrule
    @ C:\Users\AGARWP34\.julia\packages\DiffEqBase\frOsk\ext\DiffEqBaseChainRulesCoreExt.jl:22 [inlined]        
 [10] rrule
    @ C:\Users\AGARWP34\.julia\packages\ChainRulesCore\6Pucz\src\rules.jl:144 [inlined]
 [11] chain_rrule_kw
    @ C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\compiler\chainrules.jl:236 [inlined]
 [12] macro expansion
    @ C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0 [inlined]
 [13] _pullback
    @ C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:87 [inlined]
 [14] _apply
    @ .\boot.jl:946 [inlined]
 [15] adjoint
    @ C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\lib\lib.jl:203 [inlined]
 [16] _pullback
    @ C:\Users\AGARWP34\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:67 [inlined]
 [17] #solve#51
    @ C:\Users\AGARWP34\.julia\packages\DiffEqBase\frOsk\src\solve.jl:1015 [inlined]
 [18] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#51", ::QuadratureAdjoint{…}, ::Nothing, ::Nothing, ::Val{…}, ::@Kwargs{…}, ::typeof(solve), ::ODEProblem{…}, ::Tsit5{…})
    @ Zygote C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [19] _apply
    @ .\boot.jl:946 [inlined]
 [20] adjoint
    @ C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\lib\lib.jl:203 [inlined]
 [21] _pullback
    @ C:\Users\AGARWP34\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:67 [inlined]
 [22] solve
    @ C:\Users\AGARWP34\.julia\packages\DiffEqBase\frOsk\src\solve.jl:1005 [inlined]
 [23] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(solve), ::ODEProblem{…}, 
::Tsit5{…})
    @ Zygote C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [24] #batch_func#653
    @ C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\ensemble\basic_ensemble_solve.jl:193 [inlined]      
 [25] _pullback(::Zygote.Context{…}, ::SciMLBase.var"##batch_func#653", ::@Kwargs{…}, ::typeof(SciMLBase.batch_func), ::Int64, ::EnsembleProblem{…}, ::Tsit5{…})
    @ Zygote C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
 [26] batch_func
    @ C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\ensemble\basic_ensemble_solve.jl:180 [inlined]      
 [27] #662
    @ C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\ensemble\basic_ensemble_solve.jl:252 [inlined]      
 [28] (::SciMLBaseZygoteExt.var"#138#141"{Zygote.Context{…}, SciMLBase.var"#662#663"{…}})(args::Int64)
    @ SciMLBaseZygoteExt C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\ext\SciMLBaseZygoteExt.jl:264        
 [29] responsible_map(f::Function, II::UnitRange{Int64})
    @ SciMLBase C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\ensemble\basic_ensemble_solve.jl:245      
 [30] ∇responsible_map
    @ C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\ext\SciMLBaseZygoteExt.jl:264 [inlined]
 [31] adjoint
    @ C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\ext\SciMLBaseZygoteExt.jl:291 [inlined]
 [32] _pullback
    @ C:\Users\AGARWP34\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:67 [inlined]

Lots of the imports aren’t necessary. I tried to clean it up, but it doesn’t look like all of the variables are defined. This is what I have so far:

using Optimization, OptimizationOptimisers, SciMLSensitivity, DataInterpolations
using OrdinaryDiffEq, Plots, ComponentArrays, Lux, Zygote, StableRNGs
using Statistics  # `mean` is used in `loss`; it was dropped in the cleanup
gr()

rng = StableRNG(111)

"""
    myf(f1, f2, f3, f4)

Callable container bundling the four per-trajectory feed interpolants with
the ODE right-hand side (see the `(f::myf)(du, u, p, t)` method). Each field
gets its own type parameter so the four interpolants do not have to share one
concrete type, while the fields remain fully typed for performance.
"""
struct myf{T1, T2, T3, T4}
    f1::T1
    f2::T2
    f3::T3
    f4::T4
end

tspan = (0.0, 12.0)

# Known (mechanistic) model parameters; only p_[1] is used in the RHS.
p_ = [0.5616710187190637, 0.00099999981807704, 5.7506499515119997e-5, 5.8624656433374085,
0.00038678047665933183, 0.10894746625822622]

# Multilayer FeedForward
const U1 = Lux.Chain(Lux.Dense(1, 15, tanh), Lux.Dense(15, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p1, st1 = Lux.setup(rng, U1)
const _st1 = st1

# Multilayer FeedForward
const U2 = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p2, st2 = Lux.setup(rng, U2)
const _st2 = st2

# Multilayer FeedForward
const U3 = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p3, st3 = Lux.setup(rng, U3)
const _st3 = st3

# Flat, concretely-typed parameter vector — the layout adjoint methods require.
p = ComponentArray(p1 = p1, p2 = p2, p3 = p3)

# Define the hybrid model.
# UDE right-hand side: each state gets a mechanistic term plus one scalar
# neural correction; the forcing functions are read off the `myf` functor.
function (m::myf)(du, u, p, t)
    # Scalar network corrections (each net sees a single state component).
    nn1 = first(U1([u[1]], p.p1, _st1))
    nn2 = first(U2([u[2]], p.p2, _st2))
    nn3 = first(U3([u[3]], p.p3, _st3))

    # The f1/f4 ratio feeds both the first and second equation.
    r14 = m.f1(t) / m.f4(t)

    du[1] = p_[1] * u[1] + nn1[1] + r14 * u[1]
    du[2] = nn2[1] - r14
    du[3] = -nn3[1] * u[1] + m.f2(t) / m.f4(t) + m.f3(t) / m.f4(t)
end

# Define the problem
u0 = [1.03, 0.0, 30.0]  # initial states (explicit floats; Int 0 would promote anyway)
# `f` was never constructed at this point — the real interpolations are only
# built inside `train_prob_func`, which remakes the problem per trajectory.
# Build the base ODEProblem with a trivial placeholder `myf` instead.
# `Returns` gives all four fields the same concrete type, so this also
# satisfies a single-type-parameter struct definition.
f_placeholder = myf(Returns(0.0), Returns(0.0), Returns(0.0), Returns(1.0))
prob_nn = ODEProblem(f_placeholder, u0, tspan, p)

# Per-trajectory problem modifier for the EnsembleProblem: builds the four
# interpolations for experiment `i` from the data table `x` and swaps them
# (plus the dosing callback) into the base problem.
# NOTE(review): the original called an undefined `flows(...)`; the container
# struct defined above is named `myf`, so that is what is constructed here.
function train_prob_func(prob_nn, i, repeat)
    # x[i] columns (assumed): 1–4 = forcing values, 5 = their time grid,
    # 6 = event times — TODO confirm against the actual data layout.
    f1 = ConstantInterpolation(x[i][:, 1], x[i][:, 5], extrapolate = true)
    f2 = ConstantInterpolation(x[i][:, 2], x[i][:, 5], extrapolate = true)
    f3 = ConstantInterpolation(x[i][:, 3], x[i][:, 5], extrapolate = true)
    f4 = ConstantInterpolation(x[i][:, 4], x[i][:, 5], extrapolate = true)

    tstops = x[i][:, 6]

    f = myf(f1, f2, f3, f4)

    # Fire exactly at the stored event times (guaranteed to be hit via
    # the `tstops` passed to the integrator below).
    function condition(u, t, integrator)
        t in tstops
    end

    # Instantaneous withdrawal applied to the first state at each event.
    function affect!(integrator)
        integrator.u[1] -= 0.020
    end

    cb = DiscreteCallback(condition, affect!; save_positions = (true, false))
    remake(prob_nn; f, tstops = tstops, callback = cb)
end


# Event-time column of the first experiment (kept for plotting/reference).
t = x[1][:, 6]

# Solve the ensemble of trajectories starting from `X` with parameters `p`;
# `train_prob_func` installs each trajectory's interpolations and callback.
# NOTE(review): tspan here is (0.0, 1.0) while the full horizon above is
# 12.0 — confirm the truncation is intentional.
# NOTE(review): QuadratureAdjoint is generally not compatible with
# callbacks; if the adjoint errors, try a callback-aware sensealg
# (e.g. GaussAdjoint/InterpolatingAdjoint) — confirm with SciMLSensitivity docs.
function predict(p, X = u0)
    newprob = remake(prob_nn, u0 = X, tspan = (0.0, 1.0), p = p)
    enprob = EnsembleProblem(newprob, prob_func = train_prob_func)
    sim = solve(enprob, Tsit5(), trajectories = size(train_batches, 1),  saveat = 0:0.1:1.0,
                 abstol = 1e-6, reltol = 1e-6,
                 sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)))
    return sim
end

# Mean squared error between the training data and the ensemble solution.
# Written with sum/length instead of `mean` because `Statistics` is not
# among this listing's imports, so `mean` would be undefined at runtime.
function loss(θ)
    sim = predict(θ)
    resid = train_data .- sim
    return sum(abs2, resid) / length(resid)
end

# Differentiate the loss with respect to the NN parameters via Zygote.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
# Flat ComponentArray over all three networks' initial parameters.
p_initial = ComponentArray(p1 = p1, p2 = p2, p3 = p3)
optprob = Optimization.OptimizationProblem(optf, p_initial)

# First-stage training with Adam.
res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(), maxiters = 15000)

f is not defined though?

Hi @ChrisRackauckas, thank you for your response.

I have not defined the variable x. It is basically a vector of arrays (each array has 6 columns) containing the experimental data.
Further, f is defined in the same way as inside the function train_prob_func.

I have added a few lines to the above code:

As I understand from the error, Zygote cannot calculate gradients for parameter optimization through the Ensemble interface when the model contains external interpolation functions.

using Optimization, OptimizationOptimisers, SciMLSensitivity, DataInterpolations
using OrdinaryDiffEq, Plots, ComponentArrays, Lux, Zygote, StableRNGs
gr()  # select the GR backend for Plots

# Fixed seed so the Lux parameter initialization below is reproducible.
rng = StableRNG(111)

"""
    myf(f1, f2, f3, f4)

Container for the four time-dependent forcing callables used by the hybrid
ODE right-hand side (they are evaluated as `f.f1(t)/f.f4(t)` etc. there).

Each field gets its own type parameter so the four callables may have
different concrete types (e.g. distinct `ConstantInterpolation`s or plain
functions); a single shared parameter would reject heterogeneous inputs.
`myf(a, b, c, d)` calls with identically-typed arguments still work.
"""
struct myf{T1, T2, T3, T4}
    f1::T1  # numerator of the f1(t)/f4(t) ratio in du[1] and du[2]
    f2::T2  # appears as f2(t)/f4(t) in du[3]
    f3::T3  # appears as f3(t)/f4(t) in du[3]
    f4::T4  # common denominator for the other three
end

tspan = (0.0, 12.0)

# Multilayer FeedForward
# U1: scalar-in/scalar-out tanh net (1 → 15 → 5 → 5 → 1) that provides the
# neural correction for the du[1] equation in the hybrid RHS below.
const U1 = Lux.Chain(Lux.Dense(1, 15, tanh), Lux.Dense(15, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p1, st1 = Lux.setup(rng, U1)
const _st1 = st1  # state captured once as a constant; reused on every forward call

# Multilayer FeedForward
# U2: 1 → 10 → 5 → 5 → 1 tanh net; correction term in du[2].
const U2 = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p2, st2 = Lux.setup(rng, U2)
const _st2 = st2

# Multilayer FeedForward
# U3: 1 → 10 → 5 → 5 → 1 tanh net; correction term in du[3].
const U3 = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p3, st3 = Lux.setup(rng, U3)
const _st3 = st3

# Single flat trainable parameter vector holding all three networks.
p = ComponentArray(p1 = p1, p2 = p2, p3 = p3)

# Initializing f
# Build the interpolations for the first experiment so that a concrete `f`
# exists when the base ODEProblem is constructed below (the per-trajectory
# prob_func rebuilds these for each i).
# NOTE(review): the original called an undefined `flows(...)`; the container
# struct defined above is named `myf`. `x` (the vector of experimental data
# tables) must be loaded before this point.
i = 1
f1 = ConstantInterpolation(x[i][:, 1], x[i][:, 5], extrapolate = true)
f2 = ConstantInterpolation(x[i][:, 2], x[i][:, 5], extrapolate = true)
f3 = ConstantInterpolation(x[i][:, 3], x[i][:, 5], extrapolate = true)
f4 = ConstantInterpolation(x[i][:, 4], x[i][:, 5], extrapolate = true)
f = myf(f1, f2, f3, f4)

# Define the hybrid model
# UDE right-hand side functor: each state gets a mechanistic term plus one
# scalar neural correction; f.f1–f.f4 are the stored time interpolations.
# NOTE(review): `p_` is used below but is never defined in this listing
# (it was defined in the earlier version of the code) — make sure it is in
# scope before running.
function (f::myf)(du, u, p, t)
    û1 = U1([u[1]], p.p1, _st1)[1] # Network prediction
    û2 = U2([u[2]], p.p2, _st2)[1] # Network prediction
    û3 = U3([u[3]], p.p3, _st3)[1] # Network prediction

    du[1] = ((p_[1] * u[1]) + û1[1]) + ((f.f1(t)/f.f4(t))*u[1])
    du[2] = û2[1] - (f.f1(t)/f.f4(t))
    du[3] = (-û3[1] * u[1]) + (f.f2(t)/f.f4(t)) + (f.f3(t)/f.f4(t))
end

# Define the problem
u0 = [1.03, 0, 30.0]  # initial states; the Int 0 promotes to Float64 in the vector
prob_nn = ODEProblem(f, u0, tspan, p)  # uses the `f` constructed above from experiment 1

# Per-trajectory problem modifier for the EnsembleProblem: builds the four
# interpolations for experiment `i` from the data table `x` and swaps them
# (plus the dosing callback) into the base problem.
# NOTE(review): the original called an undefined `flows(...)`; the container
# struct defined above is named `myf`, so that is what is constructed here.
function train_prob_func(prob_nn, i, repeat)
    # x[i] columns (assumed): 1–4 = forcing values, 5 = their time grid,
    # 6 = event times — TODO confirm against the actual data layout.
    f1 = ConstantInterpolation(x[i][:, 1], x[i][:, 5], extrapolate = true)
    f2 = ConstantInterpolation(x[i][:, 2], x[i][:, 5], extrapolate = true)
    f3 = ConstantInterpolation(x[i][:, 3], x[i][:, 5], extrapolate = true)
    f4 = ConstantInterpolation(x[i][:, 4], x[i][:, 5], extrapolate = true)

    tstops = x[i][:, 6]

    f = myf(f1, f2, f3, f4)

    # Fire exactly at the stored event times (guaranteed to be hit via
    # the `tstops` passed to the integrator below).
    function condition(u, t, integrator)
        t in tstops
    end

    # Instantaneous withdrawal applied to the first state at each event.
    function affect!(integrator)
        integrator.u[1] -= 0.020
    end

    cb = DiscreteCallback(condition, affect!; save_positions = (true, false))
    remake(prob_nn; f, tstops = tstops, callback = cb)
end


# Event-time column of the first experiment (kept for plotting/reference).
t = x[1][:, 6]

# Solve the ensemble of trajectories starting from `X` with parameters `p`
# over the full 12-unit horizon; `train_prob_func` installs each
# trajectory's interpolations and callback.
# NOTE(review): QuadratureAdjoint is generally not compatible with
# callbacks; if the adjoint errors, try a callback-aware sensealg
# (e.g. GaussAdjoint/InterpolatingAdjoint) — confirm with SciMLSensitivity docs.
function predict(p, X = u0)
    newprob = remake(prob_nn, u0 = X, tspan = (0.0, 12.0), p = p)
    enprob = EnsembleProblem(newprob, prob_func = train_prob_func)
    sim = solve(enprob, Tsit5(), trajectories = size(train_batches, 1),  saveat = 0:0.1:12.0,
                 abstol = 1e-6, reltol = 1e-6,
                 sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)))
    return sim
end

# Mean squared error between the training data and the ensemble solution.
# Written with sum/length instead of `mean` because `Statistics` is not
# among this listing's imports, so `mean` would be undefined at runtime.
function loss(θ)
    sim = predict(θ)
    resid = train_data .- sim
    return sum(abs2, resid) / length(resid)
end

# Differentiate the loss with respect to the NN parameters via Zygote.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
# Flat ComponentArray over all three networks' initial parameters.
p_initial = ComponentArray(p1 = p1, p2 = p2, p3 = p3)
optprob = Optimization.OptimizationProblem(optf, p_initial)

# First-stage training with Adam.
res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(), maxiters = 15000)

This test case is a bit in the weeds and thus in order to debug it I will need a runnable case. In your current code, flows is undefined and so is x, and so I’m not sure how to recreate what you’re seeing.