Universal Differential Equations with Ensemble Problem

I am facing issues trying to develop UDE model using the Ensemble interface

Kindly help me out. Here is the sample code and error I get

using CSV, DataFrames, XLSX
using Optimization, OptimizationPolyalgorithms,OptimizationOptimisers, OptimizationOptimJL, LineSearches, OptimizationSpeedMapping
using ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity
using DataInterpolations, RecursiveArrayTools, LinearAlgebra, DataDrivenSparse, Statistics
using OrdinaryDiffEq, DifferentialEquations, Plots, DiffEqParamEstim, Sundials
using ForwardDiff, OptimizationOptimJL, OptimizationBBO, OptimizationNLopt
# External Libraries
using ComponentArrays, Lux, Zygote, Plots, StableRNGs
gr()

rng = StableRNG(111)

tspan = (0.0, 12.0)

p_ = [0.5616710187190637, 0.00099999981807704, 5.7506499515119997e-5, 5.8624656433374085,
0.00038678047665933183, 0.10894746625822622]

# Multilayer FeedForward
const U1 = Lux.Chain(Lux.Dense(1, 15, tanh), Lux.Dense(15, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p1, st1 = Lux.setup(rng, U1)
const _st1 = st1

# Multilayer FeedForward
const U2 = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p2, st2 = Lux.setup(rng, U2)
const _st2 = st2

# Multilayer FeedForward
const U3 = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 5, tanh), Lux.Dense(5, 5, tanh),
              Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p3, st3 = Lux.setup(rng, U3)
const _st3 = st3

p = [p1, p2, p3]

# Define the hybrid model
function ude_dynamics(du, u, p, t)
    û1 = U1([u[1]], p[1], _st1)[1] # Network prediction
    û2 = U2([u[2]], p[2], _st2)[1] # Network prediction
    û3 = U3([u[3]], p[3], _st3)[1] # Network prediction

    du[1] = ((p_[1] * u[1]) .+ û1[1]) .+ ((f1(t)./f4(t)).*u[1])
    du[2] = û2[1] .- (f1(t)./f4(t))
    du[3] = (-û3[1] .* u[1]) .+ ((f2(t)./f4(t))) .+ ((f3(t)/f4(t)))
end

# Closure with the known parameter
nn_dynamics(du, u, p, t) = ude_dynamics(du, u, p, t,)
# Define the problem
u0 = [1.03, 0, 30.0]
prob_nn = ODEProblem(ude_dynamics, u0, tspan)

function train_prob_func(prob_nn, i , repeat)
    f1 = ConstantInterpolation(
        train_feed_data[i][:, 4], train_feed_data[i][:, 1], extrapolate=true)
    f2 = ConstantInterpolation(
        train_feed_data[i][:, 3], train_feed_data[i][:, 1], extrapolate=true) 
    f3 = ConstantInterpolation(
        train_feed_data[i][:, 2], train_feed_data[i][:, 1], extrapolate=true)
    f4 = ConstantInterpolation(
        train_feed_data[i][:, 5], train_feed_data[i][:, 1], extrapolate=true)

    tstops = train_data[i][:, 1]

    function condition(u,t,integrator) 
        t in tstops
    end
    
    function affect!(integrator)
        integrator.u[1] -= 0.020
    end
    
    cb = DiscreteCallback(condition, affect!; save_positions=(true,false))
    remake(prob_nn; u0 = train_ic[i], 
           tspan = (train_tspan_data[i][1], train_tspan_data[i][2]),
           p = (p1 , p2 , p3 , f1, f2, f3, f4), 
           tstops = tstops, 
           callback = cb)
end


t = train_data[1][:, 1]

function predict(p, X = u0, T = t)
    newprob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = p)
    enprob = EnsembleProblem(newprob, prob_func = train_prob_func)
    sim = solve(enprob, Tsit5(), trajectories = size(train_batches, 1),  saveat = T,
                 abstol = 1e-6, reltol = 1e-6,
                 sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)))
    return sim
end

function loss(θ)
    sim = predict(θ)
    l = mean(abs2, train_data .- sim)
end

losses = Float64[]

function callback(p, l)
    push!(losses, l)
    if length(losses) % 50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, _) -> loss(x), adtype)
p_initial = ComponentArray(p1 = p1, p2 = p2, p3 = p3)
optprob = Optimization.OptimizationProblem(optf, p_initial)


res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(), callback = callback, maxiters = 15000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

Here is the error message that I get running the code. Can you help me how should I pass NN parameters, other model parameters, and data interpolation objects (f1, f2, f3, f4) together in the EnsembleProb function

ERROR: Adjoint sensitivity analysis functionality requires being able to solve  
a differential equation defined by the parameter struct `p`. Thus while
DifferentialEquations.jl can support any parameter struct type, usage
with adjoint sensitivity analysis requires that `p` could be a valid
type for being the initial condition `u0` of an array. This means that
many simple types, such as `Tuple`s and `NamedTuple`s, will work as
parameters in normal contexts but will fail during adjoint differentiation.     
To work around this issue for complicated cases like nested structs, look       
into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl
or ComponentArrays.jl so that `p` is an `AbstractArray` with a concrete element type.

An array of named tuples won’t be a valid parameter. Instead, make that into a component array like the tutorials with multiple neural nets show.

I have tried doing that. The issue is I cannot pack interpolated objects f1,f2,f3 and f4 long with p1,p2 and p3 in an component array (refer train_prob_func)