Hybrid ODE with ContinuousCallback for models of changing sizes?

I am solving a system of ODEs with a ContinuousCallback where, once a condition() is met, certain members of the state vector are removed (i.e., elements disappear during the simulation). I would like to solve this system with DiffEqFlux.jl as a Neural Differential Equation, specifically a Universal Differential Equation. I am having trouble getting this to work, and I believe the problem is my custom ContinuousCallback, because I get a “cannot resize array with shared data” error. Does anyone know whether the functionality I’ve described is supported in DiffEqFlux.jl?

Here’s a minimal example which I have adapted from @ChrisRackauckas’s universal_differential_equations repo (hope that’s okay):

using OrdinaryDiffEq, ModelingToolkit, LinearAlgebra, ComponentArrays, Optimization, OptimizationOptimisers, OptimizationOptimJL, Lux,
Plots, DiffEqFlux, JLD2, FileIO, Statistics, Random, Distributions, LinearSolve

# Set a random seed for reproducible behaviour
rng = Random.default_rng()
Random.seed!(1234)

#### NOTE
# Since the recent release of DataDrivenDiffEq v0.6.0 where a complete overhaul of the optimizers took
# place, SR3 has been used. Right now, STLSQ performs better and has been changed.

# Create a name for saving ( basically a prefix )
svname = "Scenario_1_"

# === Add a ContinuousCallback to test behaviour
# Callback condition for removing elements when value is at threshold
function condition(u, t, integrator)
    # Trigger the event when the smallest state variable drops below an arbitrary threshold (chosen so that it will definitely be crossed during the integration)
    minimum(u) < 0.5 ? 0 : 1
end

# Callback to modify the state vector if above condition is met
function affect!(integrator)
    original_size = length(integrator.u)
    idxs = findall(x -> x <= 0.5, integrator.u)
    new_size = original_size - length(idxs)

    # Remove the identified elements that have fallen below our threshold
    deleteat!(integrator.u, idxs)
    deleteat_non_user_cache!(integrator, idxs)

    resize!(integrator, new_size)
    resize_non_user_cache!(integrator, new_size)
    nothing
end

## Data generation
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α*u[1] - β*u[2]*u[1]
    du[2] = γ*u[1]*u[2]  - δ*u[2]
end

# Define the experimental parameter
disappearing_callback = ContinuousCallback(condition, affect!)
tspan = (0.0,3.0)
u0 = [0.44249296,4.6280594]
p_ = [1.3, 0.9, 0.8, 1.8]

# This works fine
prob = ODEProblem(lotka!, u0,tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 0.1, callback=disappearing_callback)

# Ideal data
X = Array(solution)
t = solution.t
DX = Array(solution(solution.t, Val{1}))

# Add noise in terms of the mean
x̄ = mean(X, dims = 2)
noise_magnitude = 5e-3
Xₙ = X .+ (noise_magnitude*x̄) .* randn(eltype(X), size(X))

plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
display(scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing]))

## Define the network
# Gaussian RBF as activation
rbf(x) = exp.(-(x.^2))

# Multilayer FeedForward
U = Lux.Chain(
    Lux.Dense(2,5,rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,2)
)
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)

# Define the hybrid model
function ude_dynamics!(du,u, p, t, p_true)
    û = U(u, p, st)[1] # Network prediction
    du[1] = p_true[1]*u[1] + û[1]
    du[2] = -p_true[4]*u[2] + û[2]
end

# Closure with the known parameter
nn_dynamics!(du,u,p,t) = ude_dynamics!(du,u,p,t,p_)
# Define the problem
disappearing_callback = ContinuousCallback(condition, affect!)
prob_nn = ODEProblem(nn_dynamics!,Xₙ[:, 1], tspan, p)

## Function to train the network
# Define a predictor
# Add the callbacks
function predict!(θ, X = Xₙ[:,1], T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Rosenbrock23(linsolve = LUFactorization()), saveat = T,
                abstol=1e-6, reltol=1e-6, callback=disappearing_callback))
end

# Simple L2 loss
function loss(θ)
    X̂ = predict!(θ)
    sum(abs2, Xₙ .- X̂)
end

# Container to track the losses
losses = Float64[]

callback = function (p, l)
  push!(losses, l)
  if length(losses)%50==0
      println("Current loss after $(length(losses)) iterations: $(losses[end])")
  end
  return false
end

## Training

# First train with ADAM for better convergence -> move the parameters into a
# favourable starting position for BFGS
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

# This throws the error
res1 = Optimization.solve(optprob, ADAM(0.1), callback=callback, maxiters = 5)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

I don’t think DiffEqFlux will natively work with changing state sizes. A typical solution on the ML side for such problems is to mask your states, i.e., multiply the removed states by zero instead of deleting them.
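A minimal, Base-only sketch of that masking idea, assuming an in-place right-hand side `f!(du, u, p, t)` like the ones in this thread. The helpers `masked_rhs` and `deactivate!` are hypothetical names, not part of DiffEqFlux or OrdinaryDiffEq: the state vector keeps its length, a `Bool` mask marks which entries are still "alive", and removing an entry means flipping its mask bit and pinning it to zero instead of calling `deleteat!`/`resize!`. The derivative is selected with `ifelse` rather than multiplied by `0.0`, because `0.0 * Inf` is `NaN` and a model with a `1/u` term produces `Inf` once a state reaches zero.

```julia
# Hypothetical helpers (not a DiffEqFlux/OrdinaryDiffEq API): keep the state
# length fixed and "remove" entries by masking them instead of resizing.

"""
Wrap an in-place RHS `f!(du, u, p, t)` so that inactive states have zero derivative.
The closure captures `active`, so later changes to the mask take effect immediately.
"""
function masked_rhs(f!, active::AbstractVector{Bool})
    return function (du, u, p, t)
        f!(du, u, p, t)
        # Select rather than multiply: 0.0 * Inf would give NaN for a 1/u term.
        du .= ifelse.(active, du, zero(eltype(du)))
        return nothing
    end
end

"""
Event logic that replaces deleteat!/resize!: deactivate every still-active state
at or below `threshold` and pin it to zero. Returns the number of states removed.
"""
function deactivate!(u, active::AbstractVector{Bool}, threshold)
    removed = 0
    for i in eachindex(u, active)
        if active[i] && u[i] <= threshold
            active[i] = false
            u[i] = zero(eltype(u))
            removed += 1
        end
    end
    return removed
end

# Usage on a toy 3-state decay model
f!(du, u, p, t) = (du .= p .* u; nothing)
active = trues(3)
g! = masked_rhs(f!, active)
u = [1.0, 0.05, 2.0]
du = similar(u)
deactivate!(u, active, 0.1)           # removes the second state
g!(du, u, [-1.0, -1.0, -1.0], 0.0)    # du == [-1.0, 0.0, -2.0]
```

In a real solve, the wrapped RHS would go into the `ODEProblem` and the `deactivate!` logic into the callback's `affect!` (acting on `integrator.u`); whether that combination differentiates cleanly with the sensitivity algorithms discussed below is untested here. Because `active` is shared mutable state, it would also need resetting (for example `fill!(active, true)`) before each forward solve in a training loop.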

Direct reverse mode should do it though. Try sensealg=TrackerAdjoint.

Thanks so much for the advice, and so quickly! I added sensealg=SciMLSensitivity.TrackerAdjoint to the solve call inside the predict! function, but now I see the following error:

ERROR: Compiling Tuple{typeof(FunctionWrappers.gen_fptr), Type{Nothing}, Type{Tuple{Vector{Float64}, Vector{Float64}, ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(layer_1 = ViewAxis(1:15, Axis(weight = ViewAxis(1:10, ShapedAxis((5, 2))), bias = ViewAxis(11:15, ShapedAxis((5, 1))))), layer_2 = ViewAxis(16:45, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))), layer_3 = ViewAxis(46:75, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5))), bias = ViewAxis(26:30, ShapedAxis((5, 1))))), layer_4 = ViewAxis(76:87, Axis(weight = ViewAxis(1:10, ShapedAxis((2, 5))), bias = ViewAxis(11:12, ShapedAxis((2, 1))))))}}}, Float64}}, Type{SciMLBase.Void{typeof(nn_dynamics!)}}}: UndefVarError: `spvals` not defined

(I could also share the full stack trace, but it’s longer than the character limit for replies.)

Did you happen to see the same thing, or is this a “me” problem?

Good to know! I see here in the docs that there are some methods compatible with hybrid ODEs with events.

Try doing ODEProblem{true, SciMLBase.FullSpecialize}(...) to work around that.

Thanks! I tried that! New error, which I think means we are making progress 🙂

ArgumentError: new: too few arguments (expected 3)

I am trying to track down where that error is coming from. I think it’s Zygote?

Stack trace?

Of course! Truncated because the full stack trace is over the character limit.

ERROR: ArgumentError: new: too few arguments (expected 3)
Stacktrace:
  [1] __new__
    @ ~/.julia/packages/Zygote/nsBv0/src/tools/builtins.jl:9 [inlined]
  [2] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:296 [inlined]
  [3] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [4] BitArray
    @ ./bitarray.jl:39 [inlined]
  [5] _pullback(::Zygote.Context{false}, ::Type{BitMatrix}, ::UndefInitializer, ::Int64, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [6] generate_chunked_partials
    @ ~/.julia/packages/SparseDiffTools/CPCma/src/differentiation/compute_jacobian_ad.jl:84 [inlined]
  [7] _pullback(::Zygote.Context{…}, ::typeof(SparseDiffTools.generate_chunked_partials), ::Vector{…}, ::UnitRange{…}, ::Val{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [8] #ForwardColorJacCache#13
    @ ~/.julia/packages/SparseDiffTools/CPCma/src/differentiation/compute_jacobian_ad.jl:37 [inlined]
  [9] _pullback(::Zygote.Context{…}, ::SparseDiffTools.var"##ForwardColorJacCache#13", ::Nothing, ::Type{…}, ::UnitRange{…}, ::Nothing, ::Type{…}, ::SciMLBase.UJacobianWrapper{…}, ::Vector{…}, ::Val{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [10] ForwardColorJacCache
    @ ~/.julia/packages/SparseDiffTools/CPCma/src/differentiation/compute_jacobian_ad.jl:22 [inlined]
 [11] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::Type{…}, ::SciMLBase.UJacobianWrapper{…}, ::Vector{…}, ::Val{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [12] build_jac_config
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/derivative_wrappers.jl:292 [inlined]
 [13] _pullback(::Zygote.Context{…}, ::typeof(OrdinaryDiffEq.build_jac_config), ::Rosenbrock23{…}, ::ODEFunction{…}, ::SciMLBase.UJacobianWrapper{…}, ::Vector{…}, ::Vector{…}, ::Vector{…}, ::Vector{…}, ::Vector{…}, ::Val{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [14] alg_cache
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/caches/rosenbrock_caches.jl:108 [inlined]
 [15] _pullback(::Zygote.Context{…}, ::typeof(alg_cache), ::Rosenbrock23{…}, ::Vector{…}, ::Vector{…}, ::Type{…}, ::Type{…}, ::Type{…}, ::Vector{…}, ::Vector{…}, ::ODEFunction{…}, ::Float64, ::Float64, ::Float64, ::ComponentVector{…}, ::Bool, ::Val{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [16] #__init#806
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:344 [inlined]
 [17] _pullback(::Zygote.Context{…}, ::OrdinaryDiffEq.var"##__init#806", ::Vector{…}, ::Tuple{}, ::Tuple{}, ::Nothing, ::Bool, ::Bool, ::Bool, ::Nothing, ::ContinuousCallback{…}, ::Bool, ::Bool, ::Float64, ::Float64, ::Float64, ::Bool, ::Bool, ::Rational{…}, ::Float64, ::Float64, ::Rational{…}, ::Int64, ::Int64, ::Rational{…}, ::Nothing, ::Nothing, ::Rational{…}, ::Nothing, ::Bool, ::Int64, ::Int64, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Symbol, ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::OrdinaryDiffEq.DefaultInit, ::@Kwargs{}, ::typeof(SciMLBase.__init), ::ODEProblem{…}, ::Rosenbrock23{…}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [18] __init
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:11 [inlined]
 [19] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(SciMLBase.__init), ::ODEProblem{…}, ::Rosenbrock23{…}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [20] __init (repeats 4 times)
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:11 [inlined]
 [21] _apply
    @ ./boot.jl:838 [inlined]
 [22] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [23] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [24] #__solve#805
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:6 [inlined]
 [25] _pullback(::Zygote.Context{…}, ::OrdinaryDiffEq.var"##__solve#805", ::@Kwargs{…}, ::typeof(SciMLBase.__solve), ::ODEProblem{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [26] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [27] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [28] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [29] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:1 [inlined]
 [30] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(SciMLBase.__solve), ::ODEProblem{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [31] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [32] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [33] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [34] #solve_call#44
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:612 [inlined]
 [35] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve_call#44", ::Bool, ::Nothing, ::@Kwargs{…}, ::typeof(DiffEqBase.solve_call), ::ODEProblem{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [36] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [37] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [38] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [39] solve_call
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:569 [inlined]
 [40] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(DiffEqBase.solve_call), ::ODEProblem{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [41] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1080 [inlined]
 [42] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve_up#53", ::@Kwargs{…}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{…}, ::Type{…}, ::Vector{…}, ::ComponentVector{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [43] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [44] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [45] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [46] solve_up
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1066 [inlined]
 [47] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{…}, ::Type{…}, ::Vector{…}, ::ComponentVector{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [48] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [49] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [50] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [51] #solve#51
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1003 [inlined]
 [52] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#51", ::Type{…}, ::Nothing, ::Nothing, ::Val{…}, ::@Kwargs{…}, ::typeof(solve), ::ODEProblem{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [53] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [54] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [55] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [56] solve
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:993 [inlined]
 [57] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(solve), ::ODEProblem{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [58] predict!
    @ ~/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/ACCELERATE/sintering_kinetic_model/example.jl:101 [inlined]
 [59] _pullback(::Zygote.Context{…}, ::typeof(predict!), ::ComponentVector{…}, ::Vector{…}, ::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [60] predict!
    @ ~/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/ACCELERATE/sintering_kinetic_model/example.jl:96 [inlined]
 [61] _pullback(ctx::Zygote.Context{false}, f::typeof(predict!), args::ComponentVector{Float64, Vector{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [62] loss
    @ ~/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/ACCELERATE/sintering_kinetic_model/example.jl:107 [inlined]
 [63] _pullback(ctx::Zygote.Context{false}, f::typeof(loss), args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [64] #33
    @ ~/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/ACCELERATE/sintering_kinetic_model/example.jl:127 [inlined]
 [65] _pullback(::Zygote.Context{…}, ::var"#33#34", ::ComponentVector{…}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [66] _apply
    @ ./boot.jl:838 [inlined]
 [67] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [68] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [69] OptimizationFunction
    @ ~/.julia/packages/SciMLBase/JUp1I/src/scimlfunctions.jl:3762 [inlined]
 [70] _pullback(::Zygote.Context{…}, ::OptimizationFunction{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [71] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [72] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [73] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [74] #37
    @ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:90 [inlined]
 [75] _pullback(ctx::Zygote.Context{…}, f::OptimizationZygoteExt.var"#37#55"{…}, args::ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [76] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [77] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [78] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [79] #39
    @ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93 [inlined]
 [80] _pullback(ctx::Zygote.Context{…}, f::OptimizationZygoteExt.var"#39#57"{…}, args::ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [81] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90
 [82] pullback
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88 [inlined]
 [83] gradient(f::Function, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:147
 [84] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…})
    @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93
 [85] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [86] macro expansion
    @ ~/.julia/packages/Optimization/5DEdF/src/utils.jl:32 [inlined]
 [87] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [88] solve!(cache::OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/JUp1I/src/solve.jl:188
 [89] solve(::OptimizationProblem{…}, ::Optimisers.Adam; kwargs::@Kwargs{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/JUp1I/src/solve.jl:96
Some type information was truncated. Use `show(err)` to see complete types.

Do you need a method for stiff equations? Switch that to Vern7 and see how it does.

In my actual system I am solving a kinetic model, so the stiff solver is very much needed. I did try switching to Vern7, though, and it still threw the same error, but with a different number!

ERROR: ArgumentError: new: too few arguments (expected 49)

That’s (expected 49), as opposed to (expected 3) with Rosenbrock23, which is interesting!

What’s the error message?

The error message is:

ERROR: ArgumentError: new: too few arguments (expected 49)

No, the full message. You’re cutting off the part with all of the information.

Sorry!

ERROR: ArgumentError: new: too few arguments (expected 49)
Stacktrace:
  [1] __new__(::Type, ::ODESolution{…}, ::Vararg{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/tools/builtins.jl:9
  [2] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:296 [inlined]
  [3] adjoint(::Zygote.Context{…}, ::typeof(Zygote.__new__), ::Type, ::ODESolution{…}, ::Vector{…}, ::Nothing, ::Vector{…}, ::Vararg{…})
    @ Zygote ./none:0
  [4] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [5] ODEIntegrator
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/integrators/type.jl:168 [inlined]
  [6] _pullback(::Zygote.Context{…}, ::Type{…}, ::ODESolution{…}, ::Vector{…}, ::Nothing, ::Vector{…}, ::Float64, ::Float64, ::ODEFunction{…}, ::ComponentVector{…}, ::Vector{…}, ::Vector{…}, ::Nothing, ::Float64, ::Vern7{…}, ::Float64, ::Bool, ::Float64, ::Float64, ::Float64, ::Float64, ::Float64, ::Float64, ::Float64, ::Float64, ::Int64, ::Int64, ::Int64, ::Int64, ::OrdinaryDiffEq.Vern7Cache{…}, ::Nothing, ::Int64, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::Int64, ::Float64, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::OrdinaryDiffEq.DEOptions{…}, ::SciMLBase.DEStats, ::OrdinaryDiffEq.DefaultInit, ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [7] #__init#806
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:468 [inlined]
  [8] _pullback(::Zygote.Context{…}, ::OrdinaryDiffEq.var"##__init#806", ::Vector{…}, ::Tuple{}, ::Tuple{}, ::Nothing, ::Bool, ::Bool, ::Bool, ::Nothing, ::ContinuousCallback{…}, ::Bool, ::Bool, ::Float64, ::Float64, ::Float64, ::Bool, ::Bool, ::Rational{…}, ::Float64, ::Float64, ::Rational{…}, ::Int64, ::Int64, ::Int64, ::Nothing, ::Nothing, ::Rational{…}, ::Nothing, ::Bool, ::Int64, ::Int64, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Symbol, ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::OrdinaryDiffEq.DefaultInit, ::@Kwargs{}, ::typeof(SciMLBase.__init), ::ODEProblem{…}, ::Vern7{…}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [9] __init
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:11 [inlined]
 [10] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(SciMLBase.__init), ::ODEProblem{…}, ::Vern7{…}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [11] __init (repeats 4 times)
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:11 [inlined]
 [12] _apply
    @ ./boot.jl:838 [inlined]
 [13] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [14] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [15] #__solve#805
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:6 [inlined]
 [16] _pullback(::Zygote.Context{…}, ::OrdinaryDiffEq.var"##__solve#805", ::@Kwargs{…}, ::typeof(SciMLBase.__solve), ::ODEProblem{…}, ::Vern7{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [17] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [18] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [19] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [20] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:1 [inlined]
 [21] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(SciMLBase.__solve), ::ODEProblem{…}, ::Vern7{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [22] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [23] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [24] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [25] #solve_call#44
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:612 [inlined]
 [26] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve_call#44", ::Bool, ::Nothing, ::@Kwargs{…}, ::typeof(DiffEqBase.solve_call), ::ODEProblem{…}, ::Vern7{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [27] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [28] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [29] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [30] solve_call
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:569 [inlined]
 [31] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(DiffEqBase.solve_call), ::ODEProblem{…}, ::Vern7{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [32] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1080 [inlined]
 [33] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve_up#53", ::@Kwargs{…}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{…}, ::Type{…}, ::Vector{…}, ::ComponentVector{…}, ::Vern7{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [34] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [35] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [36] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [37] solve_up
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1066 [inlined]
 [38] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{…}, ::Type{…}, ::Vector{…}, ::ComponentVector{…}, ::Vern7{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [39] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [40] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [41] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [42] #solve#51
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1003 [inlined]
 [43] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#51", ::Type{…}, ::Nothing, ::Nothing, ::Val{…}, ::@Kwargs{…}, ::typeof(solve), ::ODEProblem{…}, ::Vern7{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [44] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [45] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [46] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [47] solve
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:993 [inlined]
 [48] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(solve), ::ODEProblem{…}, ::Vern7{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [49] predict!
    @ ~/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/ACCELERATE/sintering_kinetic_model/example.jl:101 [inlined]
 [50] _pullback(::Zygote.Context{…}, ::typeof(predict!), ::ComponentVector{…}, ::Vector{…}, ::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [51] predict!
    @ ~/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/ACCELERATE/sintering_kinetic_model/example.jl:96 [inlined]
 [52] _pullback(ctx::Zygote.Context{false}, f::typeof(predict!), args::ComponentVector{Float64, Vector{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [53] loss
    @ ~/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/ACCELERATE/sintering_kinetic_model/example.jl:107 [inlined]
 [54] _pullback(ctx::Zygote.Context{false}, f::typeof(loss), args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [55] #39
    @ ~/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/ACCELERATE/sintering_kinetic_model/example.jl:127 [inlined]
 [56] _pullback(::Zygote.Context{…}, ::var"#39#40", ::ComponentVector{…}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [57] _apply
    @ ./boot.jl:838 [inlined]
 [58] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [59] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [60] OptimizationFunction
    @ ~/.julia/packages/SciMLBase/JUp1I/src/scimlfunctions.jl:3762 [inlined]
 [61] _pullback(::Zygote.Context{…}, ::OptimizationFunction{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [62] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [63] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [64] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [65] #37
    @ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:90 [inlined]
 [66] _pullback(ctx::Zygote.Context{…}, f::OptimizationZygoteExt.var"#37#55"{…}, args::ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [67] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [68] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [69] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [70] #39
    @ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93 [inlined]
 [71] _pullback(ctx::Zygote.Context{…}, f::OptimizationZygoteExt.var"#39#57"{…}, args::ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [72] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90
 [73] pullback
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88 [inlined]
 [74] gradient(f::Function, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:147
 [75] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…})
    @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93
 [76] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [77] macro expansion
    @ ~/.julia/packages/Optimization/5DEdF/src/utils.jl:32 [inlined]
 [78] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [79] solve!(cache::OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/JUp1I/src/solve.jl:188
 [80] solve(::OptimizationProblem{…}, ::Optimisers.Adam; kwargs::@Kwargs{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/JUp1I/src/solve.jl:96
Some type information was truncated. Use `show(err)` to see complete types.

Can you make a more minimal example? That would help with debugging it. It’s due to a constructor adjoint of sorts. I don’t think you need neural networks or anything, just the changing size callbacks.

Solving this more minimal example as a normal ODE with changing-size callbacks works (as far as I can tell):

using DifferentialEquations
using Random
using Distributions
using Plots
using LinearSolve
rng = Random.default_rng()
Random.seed!(1010)

# ODE Function
function particle_radius_change_dynamics(du, u, p, t)
    rate_constants = p

    R = u

    for i in eachindex(R)
        du[i] = rate_constants[i] * (1.0/u[i] - 1/mean_size) 
    end

    return nothing

end

# Callback condition for removing particles when radius is at/close to 0
function condition(u, t, integrator)
    # Trigger the event if the order of magnitude of the radius goes below the threshold
    minimum(u) < 1e-4 ? 0 : 1
end

# Callback to modify the state vector if above condition is met
function affect!(integrator)
    println("You've entered the callback")
    original_size = length(integrator.u)
    idxs = findall(r -> r <= 1e-4, integrator.u)
    new_size = original_size - length(idxs)

    # Remove the identified radii that have effectively gone to zero
    deleteat!(integrator.u, idxs)
    deleteat_non_user_cache!(integrator, idxs)

    resize!(integrator, new_size)
    resize_non_user_cache!(integrator, new_size)
    nothing
end

# === PROBLEM SETUP === 
# Number of particles, and the mean size and standard deviation of particle sizes
num_particles = 11
mean_size = 1.2
std = 0.1*mean_size

# Generate random initial particle size distribution 
initial_radii = rand(Normal(mean_size, std), num_particles)

# Generate random rate constants for each particle (growing or shrinking; these would usually follow conservation laws, but this is just an example)
random_rates = rand(Uniform(-1,1),num_particles)

# Make input vectors
u0 = initial_radii
p_true = random_rates
tspan = (0, 100)

# Make ODE problem
prob = ODEProblem(particle_radius_change_dynamics, u0, tspan, p_true)

# Callback to remove a particle from the simulation if it effectively "disappears" due to shrinking
disappearing_callback = ContinuousCallback(condition, affect!)

# === SOLVE ===
@time sol = solve(prob, Rosenbrock23(linsolve = LUFactorization()),
    maxiters=1e6, abstol=1e-5, reltol=1e-5, isoutofdomain=(u,p,t)->any(x->x<0.0, u), callback=disappearing_callback)

println("You've solved the ODE")
# See number of particles decrease over time 
display(plot(sol.t, map(length, sol.u), lw = 3,
    ylabel = "Number of Nanoparticles", xlabel = "Time"))
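The removal logic in `condition` and `affect!` can be sanity-checked on a plain `Vector`, outside any integrator. The sketch below is an illustration only; `THRESHOLD`, `signed_condition`, and `remove_small!` are hypothetical names, not part of the script or of DifferentialEquations.jl. It also shows a signed-distance condition, `minimum(u) - THRESHOLD`, as an alternative to the `0`/`1` step: the callback documentation describes a `ContinuousCallback` as firing where the condition hits zero, located by root-finding, and a continuous function gives that root-finder something to refine. Whether it changes anything in this particular model is untested.

```julia
# Hypothetical helpers mirroring the logic of `condition` and `affect!` above,
# applied to a plain Vector so they can be checked without an integrator.

const THRESHOLD = 1e-4

# Signed-distance condition: positive while every radius is above the
# threshold, zero at the event, negative once a radius has dropped below it.
signed_condition(u) = minimum(u) - THRESHOLD

# Remove every radius at or below the threshold, in place, and return the
# removed indices (the same `findall` + `deleteat!` pattern used in `affect!`).
function remove_small!(u::Vector{Float64})
    idxs = findall(r -> r <= THRESHOLD, u)
    deleteat!(u, idxs)
    return idxs
end

radii = [0.8, 5e-5, 1.3, 2e-5, 0.9]
println(signed_condition(radii) < 0)  # true: at least one radius is below the threshold
removed = remove_small!(radii)
println(removed)                      # [2, 4]
println(radii)                        # [0.8, 1.3, 0.9]
```

One caveat with the signed form: at the located event the smallest radius sits essentially at the threshold, so the `r <= 1e-4` test in `affect!` may need a small tolerance to pick it up.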

But when I add the next step, training this as a UDE, I run into the aforementioned error.

using DifferentialEquations
using Random
using Distributions
using Plots
using LinearSolve
rng = Random.default_rng()
Random.seed!(1010)

# ODE Function
function particle_radius_change_dynamics(du, u, p, t)
    rate_constants = p

    R = u

    for i in eachindex(R)
        du[i] = rate_constants[i] * (1.0/u[i] - 1/mean_size) 
    end

    return nothing

end

# Callback condition for removing particles when radius is at/close to 0
function condition(u, t, integrator)
    # Trigger the event if the order of magnitude of the radius goes below the threshold
    minimum(u) < 1e-4 ? 0 : 1
end

# Callback to modify the state vector if above condition is met
function affect!(integrator)
    println("You've entered the callback")
    original_size = length(integrator.u)
    idxs = findall(r -> r <= 1e-4, integrator.u)
    new_size = original_size - length(idxs)

    # Remove the identified radii that have effectively gone to zero
    deleteat!(integrator.u, idxs)
    deleteat_non_user_cache!(integrator, idxs)

    resize!(integrator, new_size)
    resize_non_user_cache!(integrator, new_size)
    nothing
end

# === PROBLEM SETUP === 
# Number of particles, and the mean size and standard deviation of particle sizes
num_particles = 11
mean_size = 1.2
std = 0.1*mean_size

# Generate random initial particle size distribution 
initial_radii = rand(Normal(mean_size, std), num_particles)

# Generate random rate constants for each particle (growing or shrinking; these would usually follow conservation laws, but this is just an example)
random_rates = rand(Uniform(-1,1),num_particles)

# Make input vectors
u0 = initial_radii
p_true = random_rates
tspan = (0, 100)

# Make ODE problem
prob = ODEProblem(particle_radius_change_dynamics, u0, tspan, p_true)

# Callback to remove a particle from the simulation if it effectively "disappears" due to shrinking
disappearing_callback = ContinuousCallback(condition, affect!)

# === SOLVE ===
@time sol = solve(prob, Rosenbrock23(linsolve = LUFactorization()),
    maxiters=1e6, abstol=1e-5, reltol=1e-5, isoutofdomain=(u,p,t)->any(x->x<0.0, u), callback=disappearing_callback)

println("You've solved the ODE")
# See number of particles decrease over time 
display(plot(sol.t, map(length, sol.u), lw = 3,
    ylabel = "Number of Nanoparticles", xlabel = "Time"))

# ==== TEST THE UDE SET UP
using OrdinaryDiffEq, ModelingToolkit, LinearAlgebra, ComponentArrays, Optimization, OptimizationOptimisers, OptimizationOptimJL, Lux,
    DiffEqFlux, JLD2, FileIO, Statistics, SciMLSensitivity

new_u = Array(sol.u)

# == Fill the missing entries with NaN so that the data form a matrix we can plot and use for data generation
function pad_vectors_with_NaN(vectors::Vector{Vector{Float64}}, M::Int)
    N = length(vectors)
    matrix = fill(NaN, N, M)
    
    for i in 1:N
        length_vec = min(length(vectors[i]), M)
        matrix[i, 1:length_vec] = vectors[i][1:length_vec]
    end
    
    return matrix
end

X = pad_vectors_with_NaN(new_u, length(sol.t))
t = sol.t

relative_factor = 1e-4
magnitude = relative_factor * abs.(X)
noise = rand(Normal(0,1), size(X))
scaled_noise = magnitude .* noise
Xₙ = X .+ scaled_noise

plt = plot(t, X[:,1], alpha = 0.75, color = :black, label = ["True Data" nothing], title="Single trajectory")
scatter!(plt, t, Xₙ[:,1], color = :red, label = ["Noisy Data" nothing], markersize=2)
#display(plt)

## Define the network
# Gaussian RBF as activation
rbf(x) = exp.(-(x.^2))

# Multilayer FeedForward
U = Lux.Chain(
    Lux.Dense(num_particles,5,rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,num_particles)
)
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)

# Define the hybrid model
function ude_dynamics!(du,u, p, t, p_true)
    û = U(u, p, st)[1] # Network prediction

    rate_constants = p_true

    R = u

    for i in eachindex(R)
        du[i] = rate_constants[i] * (1.0/u[i]) + û[i]
    end

    return nothing
end

# Closure with the known parameter
nn_dynamics!(du,u,p,t) = ude_dynamics!(du,u,p,t,p_true)
# Define the problem
disappearing_callback = ContinuousCallback(condition, affect!)
prob_nn = ODEProblem{true, SciMLBase.FullSpecialize}(nn_dynamics!, Xₙ[:, 1], tspan, p)

## Function to train the network
# Define a predictor
function predict!(θ, X = Xₙ[:,1], T = t)
    #_prob = ODEProblem{true, SciMLBase.FullSpecialize}(nn_dynamics!, X, (T[1], T[end]), θ)
    println("You're trying to make a prediction for the UDE...")
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    pred = Array(solve(_prob, Rosenbrock23(linsolve = LUFactorization()), saveat = T,
                abstol=1e-6, reltol=1e-6, sensealg=SciMLSensitivity.TrackerAdjoint, callback=disappearing_callback))
    println("You've solved the UDE")
    return pred  # without this, predict! returns the `nothing` from println
end

# Simple L2 loss
function loss(θ)
    X̂ = predict!(θ)
    sum(abs2, Xₙ .- X̂)
end

# Container to track the losses
losses = Float64[]

callback = function (p, l)
  push!(losses, l)
  if length(losses)%50==0
      println("Current loss after $(length(losses)) iterations: $(losses[end])")
  end
  return false
end

## Training
# First train with ADAM for better convergence -> move the parameters into a
# favourable starting position for BFGS
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))
res1 = Optimization.solve(optprob, ADAM(0.1), callback=callback, maxiters = 5)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
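If I'm reading the data preparation above correctly, `pad_vectors_with_NaN(new_u, length(sol.t))` pads each saved state to the number of time points rather than the number of particles, so `X` is (time points) × (time points), and `Xₙ[:, 1]` is particle 1's trajectory over time rather than the initial state that `prob_nn` and `predict!` need. A minimal sketch, assuming the intended layout is one row per saved time and one column per particle: pad to the longest state and take the first row as `u0`. `pad_with_nan` is a hypothetical helper, not from the original script.

```julia
# Hypothetical helper: pad a ragged list of state vectors (one per saved time)
# into a matrix with one row per time point and one column per particle.
function pad_with_nan(states::Vector{Vector{Float64}})
    ncols = maximum(length, states)          # widest state = initial particle count
    out = fill(NaN, length(states), ncols)   # rows: time points, cols: particles
    for (i, s) in enumerate(states)
        out[i, 1:length(s)] .= s             # missing (removed) particles stay NaN
    end
    return out
end

# Toy trajectory: three particles, the last of which is removed after the first save.
states = [[1.0, 2.0, 3.0], [1.1, 2.2], [1.2, 2.4]]
X = pad_with_nan(states)
u0 = X[1, :]      # initial state: first row, all particles present
println(size(X))  # (3, 3)
println(u0)       # [1.0, 2.0, 3.0]
```

Note that `deleteat!` shifts the surviving particles to the left, so column j only follows the same particle when removals happen at the end of the vector; keeping particle identity would need extra index bookkeeping, which this sketch does not attempt. Separately, `Array(solve(...))` stacks states as columns (particles × time, at least when all saved states have the same length), so the loss would also need the two layouts to agree.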

Here are the output and the error:

You've entered the callback
You've entered the callback
You've entered the callback
You've entered the callback
  0.501317 seconds (1.60 M allocations: 97.729 MiB, 3.67% gc time, 98.97% compilation time: 3% of which was recompilation)
You've solved the ODE
You're trying to make a prediction for the UDE...
ERROR: ArgumentError: new: too few arguments (expected 3)
Stacktrace:
  [1] __new__
    @ ~/.julia/packages/Zygote/nsBv0/src/tools/builtins.jl:9 [inlined]
  [2] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:296 [inlined]
  [3] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
  [4] BitArray
    @ ./bitarray.jl:39 [inlined]
  [5] _pullback(::Zygote.Context{false}, ::Type{BitMatrix}, ::UndefInitializer, ::Int64, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [6] generate_chunked_partials
    @ ~/.julia/packages/SparseDiffTools/CPCma/src/differentiation/compute_jacobian_ad.jl:84 [inlined]
  [7] _pullback(::Zygote.Context{…}, ::typeof(SparseDiffTools.generate_chunked_partials), ::Vector{…}, ::UnitRange{…}, ::Val{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
  [8] #ForwardColorJacCache#13
    @ ~/.julia/packages/SparseDiffTools/CPCma/src/differentiation/compute_jacobian_ad.jl:37 [inlined]
  [9] _pullback(::Zygote.Context{…}, ::SparseDiffTools.var"##ForwardColorJacCache#13", ::Nothing, ::Type{…}, ::UnitRange{…}, ::Nothing, ::Type{…}, ::SciMLBase.UJacobianWrapper{…}, ::Vector{…}, ::Val{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [10] ForwardColorJacCache
    @ ~/.julia/packages/SparseDiffTools/CPCma/src/differentiation/compute_jacobian_ad.jl:22 [inlined]
 [11] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::Type{…}, ::SciMLBase.UJacobianWrapper{…}, ::Vector{…}, ::Val{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [12] build_jac_config
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/derivative_wrappers.jl:292 [inlined]
 [13] _pullback(::Zygote.Context{…}, ::typeof(OrdinaryDiffEq.build_jac_config), ::Rosenbrock23{…}, ::ODEFunction{…}, ::SciMLBase.UJacobianWrapper{…}, ::Vector{…}, ::Vector{…}, ::Vector{…}, ::Vector{…}, ::Vector{…}, ::Val{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [14] alg_cache
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/caches/rosenbrock_caches.jl:108 [inlined]
 [15] _pullback(::Zygote.Context{…}, ::typeof(alg_cache), ::Rosenbrock23{…}, ::Vector{…}, ::Vector{…}, ::Type{…}, ::Type{…}, ::Type{…}, ::Vector{…}, ::Vector{…}, ::ODEFunction{…}, ::Float64, ::Float64, ::Float64, ::ComponentVector{…}, ::Bool, ::Val{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [16] #__init#806
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:344 [inlined]
 [17] _pullback(::Zygote.Context{…}, ::OrdinaryDiffEq.var"##__init#806", ::Vector{…}, ::Tuple{}, ::Tuple{}, ::Nothing, ::Bool, ::Bool, ::Bool, ::Nothing, ::ContinuousCallback{…}, ::Bool, ::Bool, ::Float64, ::Float64, ::Float64, ::Bool, ::Bool, ::Rational{…}, ::Float64, ::Float64, ::Rational{…}, ::Int64, ::Int64, ::Rational{…}, ::Nothing, ::Nothing, ::Rational{…}, ::Nothing, ::Bool, ::Int64, ::Int64, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Symbol, ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::OrdinaryDiffEq.DefaultInit, ::@Kwargs{}, ::typeof(SciMLBase.__init), ::ODEProblem{…}, ::Rosenbrock23{…}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [18] __init
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:11 [inlined]
 [19] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(SciMLBase.__init), ::ODEProblem{…}, ::Rosenbrock23{…}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [20] __init (repeats 4 times)
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:11 [inlined]
 [21] _apply
    @ ./boot.jl:838 [inlined]
 [22] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [23] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [24] #__solve#805
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:6 [inlined]
 [25] _pullback(::Zygote.Context{…}, ::OrdinaryDiffEq.var"##__solve#805", ::@Kwargs{…}, ::typeof(SciMLBase.__solve), ::ODEProblem{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [26] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [27] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [28] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [29] __solve
    @ ~/.julia/packages/OrdinaryDiffEq/tAI61/src/solve.jl:1 [inlined]
 [30] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(SciMLBase.__solve), ::ODEProblem{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [31] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [32] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [33] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [34] #solve_call#44
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:612 [inlined]
 [35] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve_call#44", ::Bool, ::Nothing, ::@Kwargs{…}, ::typeof(DiffEqBase.solve_call), ::ODEProblem{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [36] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [37] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [38] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [39] solve_call
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:569 [inlined]
 [40] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(DiffEqBase.solve_call), ::ODEProblem{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [41] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1080 [inlined]
 [42] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve_up#53", ::@Kwargs{…}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{…}, ::Type{…}, ::Vector{…}, ::ComponentVector{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [43] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [44] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [45] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [46] solve_up
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1066 [inlined]
 [47] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{…}, ::Type{…}, ::Vector{…}, ::ComponentVector{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [48] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [49] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [50] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [51] #solve#51
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:1003 [inlined]
 [52] _pullback(::Zygote.Context{…}, ::DiffEqBase.var"##solve#51", ::Type{…}, ::Nothing, ::Nothing, ::Val{…}, ::@Kwargs{…}, ::typeof(solve), ::ODEProblem{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [53] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [54] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [55] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [56] solve
    @ ~/.julia/packages/DiffEqBase/yM6LF/src/solve.jl:993 [inlined]
 [57] _pullback(::Zygote.Context{…}, ::typeof(Core.kwcall), ::@NamedTuple{…}, ::typeof(solve), ::ODEProblem{…}, ::Rosenbrock23{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [58] predict!
    @ ~/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/ACCELERATE/sintering_kinetic_model/minimal_example.jl:147 [inlined]
 [59] _pullback(::Zygote.Context{…}, ::typeof(predict!), ::ComponentVector{…}, ::Vector{…}, ::Vector{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [60] predict!
    @ ~/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/ACCELERATE/sintering_kinetic_model/minimal_example.jl:145 [inlined]
 [61] _pullback(ctx::Zygote.Context{false}, f::typeof(predict!), args::ComponentVector{Float64, Vector{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [62] loss
    @ ~/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/ACCELERATE/sintering_kinetic_model/minimal_example.jl:154 [inlined]
 [63] _pullback(ctx::Zygote.Context{false}, f::typeof(loss), args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [64] #41
    @ ~/Library/CloudStorage/OneDrive-PNNL/Documents/Projects/ACCELERATE/sintering_kinetic_model/minimal_example.jl:173 [inlined]
 [65] _pullback(::Zygote.Context{…}, ::var"#41#42", ::ComponentVector{…}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [66] _apply
    @ ./boot.jl:838 [inlined]
 [67] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [68] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [69] OptimizationFunction
    @ ~/.julia/packages/SciMLBase/JUp1I/src/scimlfunctions.jl:3762 [inlined]
 [70] _pullback(::Zygote.Context{…}, ::OptimizationFunction{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [71] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [72] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [73] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [74] #37
    @ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:90 [inlined]
 [75] _pullback(ctx::Zygote.Context{…}, f::OptimizationZygoteExt.var"#37#55"{…}, args::ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [76] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [77] adjoint
    @ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:203 [inlined]
 [78] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [79] #39
    @ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93 [inlined]
 [80] _pullback(ctx::Zygote.Context{…}, f::OptimizationZygoteExt.var"#39#57"{…}, args::ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
 [81] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:90
 [82] pullback
    @ ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:88 [inlined]
 [83] gradient(f::Function, args::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:147
 [84] (::OptimizationZygoteExt.var"#38#56"{…})(::ComponentVector{…}, ::ComponentVector{…})
    @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93
 [85] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [86] macro expansion
    @ ~/.julia/packages/Optimization/5DEdF/src/utils.jl:32 [inlined]
 [87] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [88] solve!(cache::OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/JUp1I/src/solve.jl:188
 [89] solve(::OptimizationProblem{…}, ::Optimisers.Adam; kwargs::@Kwargs{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/JUp1I/src/solve.jl:96
Some type information was truncated. Use `show(err)` to see complete types.

You didn’t put the sensealg in there?

@time sol = solve(prob, Rosenbrock23(linsolve = LUFactorization()),
    maxiters=1e6, abstol=1e-5, reltol=1e-5, isoutofdomain=(u,p,t)->any(x->x<0.0, u), callback=disappearing_callback, sensealg = ReverseDiffAdjoint())

or

@time sol = solve(prob, Rosenbrock23(linsolve = LUFactorization()),
    maxiters=1e6, abstol=1e-5, reltol=1e-5, isoutofdomain=(u,p,t)->any(x->x<0.0, u), callback=disappearing_callback, sensealg = TrackerAdjoint())

?

I didn’t include the sensealg in the first ODE solve (for data generation) because that solve works without issue. It’s the solve call inside the predict!() function that causes the error, and I do have the sensealg in that solve:

function predict!(θ, X = Xₙ[:,1], T = t)
    #_prob = ODEProblem{true, SciMLBase.FullSpecialize}(nn_dynamics!, X, (T[1], T[end]), θ)
    println("You're trying to make a prediction for the UDE...")
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    pred = Array(solve(_prob, Rosenbrock23(linsolve = LUFactorization()), saveat = T,
                abstol=1e-6, reltol=1e-6, sensealg=SciMLSensitivity.TrackerAdjoint, callback=disappearing_callback))
    println("You've solved the UDE")
    return pred  # without this, predict! returns the `nothing` from println
end

And running this example leads to the aforementioned error.
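Looking at the trace, frame [42] shows `solve_up` being called with `::Type{…}` in the position where the sensitivity algorithm goes, and `predict!` passes `sensealg=SciMLSensitivity.TrackerAdjoint` without parentheses, i.e. the type itself rather than an instance like the `TrackerAdjoint()` in the reply above. If the adjoint rules only match instances, Zygote would fall through to differentiating the solver internals (`__init`, `alg_cache`, `build_jac_config`, down to the `BitArray` constructor), which is consistent with the `new: too few arguments` error. That is a reading of the trace, not a verified diagnosis. The stdlib-only sketch below shows the Julia dispatch behaviour in question; `AbstractSensAlg`, `FakeTrackerAdjoint`, and `pick_path` are hypothetical stand-ins, not SciMLSensitivity types.

```julia
# Hypothetical stand-ins: these are NOT SciMLSensitivity types; they only
# illustrate how Julia dispatch treats a type versus an instance of it.
abstract type AbstractSensAlg end
struct FakeTrackerAdjoint <: AbstractSensAlg end

# A method constrained to instances of the abstract type...
pick_path(sensealg::AbstractSensAlg) = :custom_adjoint
# ...and a generic fallback for anything else.
pick_path(sensealg) = :differentiate_through_solver

println(pick_path(FakeTrackerAdjoint()))         # instance: prints custom_adjoint
println(pick_path(FakeTrackerAdjoint))           # the type itself: prints differentiate_through_solver
println(FakeTrackerAdjoint isa AbstractSensAlg)  # false: it is a DataType, not an instance
```

In the script, the corresponding change would be `sensealg = SciMLSensitivity.TrackerAdjoint()`; whether `TrackerAdjoint` then copes with the resizing callback is a separate question.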

Sorry if this is confusing; let me know what I can clarify.

Share the most recent version of the code in a way that I can copy-paste. I was going to pull it locally today to debug it, but I have kind of lost track of which version is the latest one to test.