Hi @ChrisRackauckas, I have modified my code accordingly as we discussed.
However, there are errors relating to gradient calculations because of the abstract type, I guess. Please find the modified code and the error message below:
using CSV, DataFrames, XLSX
using Optimization, OptimizationPolyalgorithms,OptimizationOptimisers, OptimizationOptimJL, LineSearches, OptimizationSpeedMapping
using ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity
using DataInterpolations, RecursiveArrayTools, LinearAlgebra, DataDrivenSparse, Statistics
using OrdinaryDiffEq, DifferentialEquations, Plots, DiffEqParamEstim, Sundials
using ForwardDiff, OptimizationOptimJL, OptimizationBBO, OptimizationNLopt
# External Libraries
using ComponentArrays, Lux, Zygote, Plots, StableRNGs
gr()
rng = StableRNG(111)
struct myf{T}
f1::T
f2::T
f3::T
f4::T
end
tspan = (0.0, 12.0)
p_ = [0.5616710187190637, 0.00099999981807704, 5.7506499515119997e-5, 5.8624656433374085,
0.00038678047665933183, 0.10894746625822622]
# Multilayer FeedForward
const U1 = Lux.Chain(Lux.Dense(1, 15, tanh), Lux.Dense(15, 5, tanh), Lux.Dense(5, 5, tanh),
Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p1, st1 = Lux.setup(rng, U1)
const _st1 = st1
# Multilayer FeedForward
const U2 = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 5, tanh), Lux.Dense(5, 5, tanh),
Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p2, st2 = Lux.setup(rng, U2)
const _st2 = st2
# Multilayer FeedForward
const U3 = Lux.Chain(Lux.Dense(1, 10, tanh), Lux.Dense(10, 5, tanh), Lux.Dense(5, 5, tanh),
Lux.Dense(5, 1))
# Get the initial parameters and state variables of the model
p3, st3 = Lux.setup(rng, U3)
const _st3 = st3
p = ComponentArray(p1 = p1, p2 = p2, p3 = p3)
# Define the hybrid model
function (f::myf)(du, u, p, t)
û1 = U1([u[1]], p.p1, _st1)[1] # Network prediction
û2 = U2([u[2]], p.p2, _st2)[1] # Network prediction
û3 = U3([u[3]], p.p3, _st3)[1] # Network prediction
du[1] = ((p_[1] * u[1]) + û1[1]) + ((f.f1(t)/f.f4(t))*u[1])
du[2] = û2[1] - (f.f1(t)/f.f4(t))
du[3] = (-û3[1] * u[1]) + (f.f2(t)/f.f4(t)) + (f.f3(t)/f.f4(t))
end
# Define the problem
u0 = [1.03, 0, 30.0]
prob_nn = ODEProblem(f, u0, tspan, p)
function train_prob_func(prob_nn, i , repeat)
f1 = ConstantInterpolation(x[i][:, 1], x[i][:, 5], extrapolate=true)
f2 = ConstantInterpolation(x[i][:, 2], x[i][:, 5], extrapolate=true)
f3 = ConstantInterpolation(x[i][:, 3], x[i][:, 5], extrapolate=true)
f4 = ConstantInterpolation(x[i][:, 4], x[i][:, 5], extrapolate=true)
tstops = x[i][:, 6]
f = flows(f1, f2, f3, f4)
function condition(u,t,integrator)
t in tstops
end
function affect!(integrator)
integrator.u[1] -= 0.020
end
cb = DiscreteCallback(condition, affect!; save_positions=(true,false))
remake(prob_nn; f, tstops=tstops, callback = cb)
end
t = x[1][:, 6]
function predict(p, X = u0, T = t)
newprob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = p)
enprob = EnsembleProblem(newprob, prob_func = train_prob_func)
sim = solve(enprob, Tsit5(), trajectories = size(train_batches, 1), saveat = T,
abstol = 1e-6, reltol = 1e-6,
sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)))
return sim
end
function loss(θ)
sim = predict(θ)
l = mean(abs2, train_data .- sim)
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
p_initial = ComponentArray(p1 = p1, p2 = p2, p3 = p3)
optprob = Optimization.OptimizationProblem(optf, p_initial)
res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(), maxiters = 15000)
Error Message:
LoadError: MethodError: no method matching (::ODESolution{…})(::Vector{…}, ::Type{…}, ::Nothing, ::Symbol)
The object of type `ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Nothing, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5,
15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50,
ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75,
Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30,
Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)),
layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64,
1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias
= 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)),
layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol}, @NamedTuple{tstops::Vector{Any}, callback::CallbackSet{Tuple{}, Tuple{DiscreteCallback{var"#condition#1"{Vector{Any}}, SciMLSensitivity.TrackedAffect{Float64, Vector{Float64}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, var"#affect!#2", Nothing, Int64}, typeof(SciMLBase.INITIALIZE_DEFAULT), typeof(SciMLBase.FINALIZE_DEFAULT), Nothing}}}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, OrdinaryDiffEqCore.InterpolationData{ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368,
Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias
= 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)),
layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(p1 = ViewAxis(1:146, Axis(layer_1 = ViewAxis(1:30, Axis(weight = ViewAxis(1:15, ShapedAxis((15, 1))), bias = 16:30)), layer_2 = ViewAxis(31:110, Axis(weight = ViewAxis(1:75, ShapedAxis((5, 15))), bias = 76:80)), layer_3 = ViewAxis(111:140, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(141:146, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p2 = ViewAxis(147:257, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias = 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3 = ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5,
5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))), p3 = ViewAxis(258:368, Axis(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1))), bias
= 11:20)), layer_2 = ViewAxis(21:75, Axis(weight = ViewAxis(1:50, ShapedAxis((5, 10))), bias = 51:55)), layer_3
= ViewAxis(76:105, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = 26:30)), layer_4 = ViewAxis(106:111, Axis(weight = ViewAxis(1:5, ShapedAxis((1, 5))), bias = 6:6)))))}}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing, Nothing, Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, Nothing, OrdinaryDiffEqTsit5.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, Nothing}, SciMLBase.DEStats, Nothing, Nothing, Nothing, Nothing}` exists, but no method is defined for this combination of argument types when trying to treat it as a callable object.
Closest candidates are:
(::SciMLBase.AbstractODESolution)(::AbstractVector{<:Number}, ::Type{deriv}, ::Nothing, ::Any) where deriv
@ SciMLBase C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\solutions\ode_solutions.jl:239
(::SciMLBase.AbstractODESolution)(::AbstractVector{<:Number}, ::Type{deriv}, ::Any, ::Any) where deriv
@ SciMLBase C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\solutions\ode_solutions.jl:327
(::SciMLBase.AbstractODESolution)(::Number, ::Type{deriv}, ::Nothing, ::Any) where deriv
@ SciMLBase C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\solutions\ode_solutions.jl:234
...
Stacktrace:
[1] (::ODESolution{…})(t::Vector{…}, ::Type{…}; idxs::Nothing, continuity::Symbol)
@ SciMLBase C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\solutions\ode_solutions.jl:224
[2] (::ODESolution{…})(t::Vector{…}, ::Type{…})
@ SciMLBase C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\solutions\ode_solutions.jl:219
[3] out_and_ts(_ts::Vector{…}, duplicate_iterator_times::Nothing, sol::ODESolution{…})
@ SciMLSensitivity C:\Users\AGARWP34\.julia\packages\SciMLSensitivity\ME3jV\src\adjoint_common.jl:668
[4] _concrete_solve_adjoint(::ODEProblem{…}, ::Tsit5{…}, ::QuadratureAdjoint{…}, ::Vector{…}, ::ComponentVector{…}, ::SciMLBase.ChainRulesOriginator; save_start::Bool, save_end::Bool, saveat::Vector{…}, save_idxs::Nothing, kwargs::@Kwargs{…})
@ SciMLSensitivity C:\Users\AGARWP34\.julia\packages\SciMLSensitivity\ME3jV\src\concrete_solve.jl:479
[5] _concrete_solve_adjoint
@ C:\Users\AGARWP34\.julia\packages\SciMLSensitivity\ME3jV\src\concrete_solve.jl:361 [inlined]
[6] #_solve_adjoint#75
@ C:\Users\AGARWP34\.julia\packages\DiffEqBase\frOsk\src\solve.jl:1555 [inlined]
[7] _solve_adjoint
@ C:\Users\AGARWP34\.julia\packages\DiffEqBase\frOsk\src\solve.jl:1528 [inlined]
[8] #rrule#4
@ C:\Users\AGARWP34\.julia\packages\DiffEqBase\frOsk\ext\DiffEqBaseChainRulesCoreExt.jl:26 [inlined]
[9] rrule
@ C:\Users\AGARWP34\.julia\packages\DiffEqBase\frOsk\ext\DiffEqBaseChainRulesCoreExt.jl:22 [inlined]
[10] rrule
@ C:\Users\AGARWP34\.julia\packages\ChainRulesCore\6Pucz\src\rules.jl:144 [inlined]
[11] chain_rrule_kw
@ C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\compiler\chainrules.jl:236 [inlined]
[12] macro expansion
@ C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0 [inlined]
[13] _pullback
@ C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:87 [inlined]
[14] _apply
@ .\boot.jl:946 [inlined]
[15] adjoint
@ C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\lib\lib.jl:203 [inlined]
[16] _pullback
@ C:\Users\AGARWP34\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:67 [inlined]
[17] #solve#51
@ C:\Users\AGARWP34\.julia\packages\DiffEqBase\frOsk\src\solve.jl:1015 [inlined]
[18] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#51", ::QuadratureAdjoint{…}, ::Nothing, ::Nothing, ::Val{…}, ::@Kwargs{…}, ::typeof(solve), ::ODEProblem{…}, ::Tsit5{…})
@ Zygote C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
[19] _apply
@ .\boot.jl:946 [inlined]
[20] adjoint
@ C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\lib\lib.jl:203 [inlined]
[21] _pullback
@ C:\Users\AGARWP34\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:67 [inlined]
[22] solve
@ C:\Users\AGARWP34\.julia\packages\DiffEqBase\frOsk\src\solve.jl:1005 [inlined]
[23] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(solve), ::ODEProblem{…},
::Tsit5{…})
@ Zygote C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
[24] #batch_func#653
@ C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\ensemble\basic_ensemble_solve.jl:193 [inlined]
[25] _pullback(::Zygote.Context{…}, ::SciMLBase.var"##batch_func#653", ::@Kwargs{…}, ::typeof(SciMLBase.batch_func), ::Int64, ::EnsembleProblem{…}, ::Tsit5{…})
@ Zygote C:\Users\AGARWP34\.julia\packages\Zygote\nyzjS\src\compiler\interface2.jl:0
[26] batch_func
@ C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\ensemble\basic_ensemble_solve.jl:180 [inlined]
[27] #662
@ C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\ensemble\basic_ensemble_solve.jl:252 [inlined]
[28] (::SciMLBaseZygoteExt.var"#138#141"{Zygote.Context{…}, SciMLBase.var"#662#663"{…}})(args::Int64)
@ SciMLBaseZygoteExt C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\ext\SciMLBaseZygoteExt.jl:264
[29] responsible_map(f::Function, II::UnitRange{Int64})
@ SciMLBase C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\src\ensemble\basic_ensemble_solve.jl:245
[30] ∇responsible_map
@ C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\ext\SciMLBaseZygoteExt.jl:264 [inlined]
[31] adjoint
@ C:\Users\AGARWP34\.julia\packages\SciMLBase\CMjVZ\ext\SciMLBaseZygoteExt.jl:291 [inlined]
[32] _pullback
@ C:\Users\AGARWP34\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:67 [inlined]