How to fit a FitzHugh-Nagumo model with DiffEqFlux.jl?

I'm trying to apply DiffEqFlux to the FitzHugh-Nagumo model, an excitable neuron model.

Basically, I only replaced the differential equations in this tutorial (Neural Ordinary Differential Equations with sciml_train) with the FitzHugh-Nagumo model and used the Rosenbrock23() solver.

However, I can't get a good result (see the last figure).
I don't know where it goes wrong or how I can improve it.
I would appreciate any suggestions or references. Thank you!

Here is my code:

using DiffEqFlux, DifferentialEquations, Plots


# FitzHugh-Nagumo model: v is the membrane potential, w the recovery variable
function fitzhugh_nagumo!(du, u, p, t)
    v, w = u
    a, b, c = p
    du[1] = v * (a - v) * (v - 1) - w   # dv/dt
    du[2] = b * v - c * w               # dw/dt
end


# ground truth data
u0 = [0.5, 0.0]
p_true = [-0.1, 0.01, 0.02]
tspan = (0.0, 200.0)
datasize = 100
tsteps = range(tspan[1], tspan[2], length=datasize)

prob_trueode = ODEProblem(fitzhugh_nagumo!, u0, tspan, p_true)
ode_data = Array(solve(prob_trueode, Rosenbrock23(), saveat=tsteps))

[figure: ground-truth v and w trajectories]

dudt2 = FastChain(FastDense(2, 32, tanh),
                  FastDense(32, 2))

prob_neuralode = NeuralODE(dudt2, tspan, Rosenbrock23(), saveat=tsteps)


function loss_neuralode(p)
    pred = Array(prob_neuralode(u0, p))
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end


callback(p, l, pred; doplot=true) = begin
    display(l)

    if doplot
        # plot current prediction against data
        plt = plot(tsteps, ode_data', label = "data")
        scatter!(plt, tsteps, pred', label = "prediction")
        display(plt)
    end

    return false
end

result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, prob_neuralode.p,
                                          cb=callback, maxiters=200)

[figure: neural ODE prediction vs. data after training — the fit is poor]
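
For completeness, the tutorial trains in two stages (ADAM first, then an LBFGS refinement). Translated to this setup it would look roughly like the sketch below; the optimizer settings are the tutorial's, not tuned for this model.

using Flux, Optim   # for ADAM and LBFGS (they may already be re-exported by DiffEqFlux)

# Sketch of the tutorial's staged optimization applied to this loss:
# ADAM handles the initial descent, LBFGS then refines the fit.
result_adam  = DiffEqFlux.sciml_train(loss_neuralode, prob_neuralode.p,
                                      ADAM(0.05), cb=callback, maxiters=300)
result_lbfgs = DiffEqFlux.sciml_train(loss_neuralode, result_adam.u,
                                      LBFGS(), cb=callback,
                                      allow_f_increases=false)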

Did you try multiple shooting?

1:
Thank you, Chris! Multiple shooting works! (although most trials fell into local minima, I’m still trying to improve it.)
[animation: fitted trajectory during multiple-shooting training]
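
For reference, a minimal sketch of this kind of multiple-shooting loss, built on DiffEqFlux's multiple_shoot helper; the group_size and continuity_term values are illustrative, and Tsit5 follows the DiffEqFlux multiple-shooting example:

group_size = 5          # illustrative, not tuned
continuity_term = 200   # weight on the continuity penalty between groups

# out-of-place problem wrapping the same FastChain network
prob_node = ODEProblem((u, p, t) -> dudt2(u, p), u0, tspan, prob_neuralode.p)

loss_segment(data, pred) = sum(abs2, data .- pred)

function loss_multiple_shooting(p)
    return multiple_shoot(p, ode_data, tsteps, prob_node, loss_segment,
                          Tsit5(), group_size;
                          continuity_term = continuity_term)
end

# multiple_shoot returns the total loss plus per-group predictions
callback_ms = (p, l, preds...) -> (display(l); false)

result_ms = DiffEqFlux.sciml_train(loss_multiple_shooting, prob_neuralode.p,
                                   cb = callback_ms, maxiters = 200)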

But even in successful trials, the training process still raised an error.
I'm not sure what is wrong, so I've put the error messages here.

1.6615035429439715  # loss from callback function
1.6608717651905747
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/initdt.jl:203
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/GW7GW/src/integrator_interface.jl:325
(the three warnings above repeat many more times)
ERROR: MethodError: Cannot `convert` an object of type Nothing to an object of type Float32
Closest candidates are:
  convert(::Type{T}, ::Static.StaticFloat64{N}) where {N, T<:AbstractFloat} at ~/.julia/packages/Static/pkxBE/src/float.jl:26
  convert(::Type{T}, ::LLVM.GenericValue, ::LLVM.LLVMType) where T<:AbstractFloat at ~/.julia/packages/LLVM/tVv0H/src/execution.jl:39
  convert(::Type{T}, ::LLVM.ConstantFP) where T<:AbstractFloat at ~/.julia/packages/LLVM/tVv0H/src/core/value/constant.jl:103
  ...
Stacktrace:
  [1] fill!(dest::Vector{Float32}, x::Nothing)
    @ Base ./array.jl:351
  [2] copyto!
    @ ./broadcast.jl:921 [inlined]
  [3] materialize!
    @ ./broadcast.jl:871 [inlined]
  [4] materialize!(dest::Vector{Float32}, bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(identity), Tuple{Base.RefValue{Nothing}}})
    @ Base.Broadcast ./broadcast.jl:868
  [5] (::GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}})(::Vector{Float32}, ::Vector{Float32})
    @ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/function/zygote.jl:8
  [6] (::GalacticOptim.var"#144#152"{OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, var"#17#19", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#17#19"}}}}, GalacticOptim.var"#143#151"{OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, var"#17#19", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#17#19"}}}}, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}}, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, 
GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}})(G::Vector{Float32}, θ::Vector{Float32})
    @ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/solve/optim.jl:93
  [7] value_gradient!!(obj::TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}, x::Vector{Float32})
    @ NLSolversBase ~/.julia/packages/NLSolversBase/cfJrN/src/interface.jl:82
  [8] value_gradient!(obj::TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}, x::Vector{Float32})
    @ NLSolversBase ~/.julia/packages/NLSolversBase/cfJrN/src/interface.jl:69
  [9] value_gradient!(obj::Optim.ManifoldObjective{TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}}, x::Vector{Float32})
    @ Optim ~/.julia/packages/Optim/wFOeG/src/Manifolds.jl:50
 [10] (::LineSearches.var"#ϕdϕ#6"{Optim.ManifoldObjective{TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}}, Vector{Float32}, Vector{Float32}, Vector{Float32}})(α::Float32)
    @ LineSearches ~/.julia/packages/LineSearches/Ki4c5/src/LineSearches.jl:84
 [11] (::LineSearches.HagerZhang{Float64, Base.RefValue{Bool}})(ϕ::Function, ϕdϕ::LineSearches.var"#ϕdϕ#6"{Optim.ManifoldObjective{TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}}, Vector{Float32}, Vector{Float32}, Vector{Float32}}, c::Float32, phi_0::Float32, dphi_0::Float32)
    @ LineSearches ~/.julia/packages/LineSearches/Ki4c5/src/hagerzhang.jl:139
 [12] HagerZhang
    @ ~/.julia/packages/LineSearches/Ki4c5/src/hagerzhang.jl:101 [inlined]
 [13] perform_linesearch!(state::Optim.BFGSState{Vector{Float32}, Matrix{Float32}, Float32, Vector{Float32}}, method::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Float64, Flat}, d::Optim.ManifoldObjective{TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}})
    @ Optim ~/.julia/packages/Optim/wFOeG/src/utilities/perform_linesearch.jl:59
 [14] update_state!(d::TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}, state::Optim.BFGSState{Vector{Float32}, Matrix{Float32}, Float32, Vector{Float32}}, method::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Float64, Flat})
    @ Optim ~/.julia/packages/Optim/wFOeG/src/multivariate/solvers/first_order/bfgs.jl:139
 [15] optimize(d::TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}, initial_x::Vector{Float32}, method::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Float64, Flat}, options::Optim.Options{Float64, GalacticOptim.var"#_cb#150"{var"#17#19", BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Float64, Flat}, Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}}}, state::Optim.BFGSState{Vector{Float32}, Matrix{Float32}, Float32, Vector{Float32}})
    @ Optim ~/.julia/packages/Optim/wFOeG/src/multivariate/optimize/optimize.jl:54
 [16] optimize(d::TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}, initial_x::Vector{Float32}, method::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Float64, Flat}, options::Optim.Options{Float64, GalacticOptim.var"#_cb#150"{var"#17#19", BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Float64, Flat}, Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}}})
    @ Optim ~/.julia/packages/Optim/wFOeG/src/multivariate/optimize/optimize.jl:36
 [17] ___solve(prob::OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, var"#17#19", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#17#19"}}}}, opt::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Float64, Flat}, data::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; cb::Function, maxiters::Nothing, maxtime::Nothing, abstol::Nothing, reltol::Nothing, progress::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/solve/optim.jl:129
 [18] #__solve#141
    @ ~/.julia/packages/GalacticOptim/fow0r/src/solve/optim.jl:49 [inlined]
 [19] #solve#480
    @ ~/.julia/packages/SciMLBase/GW7GW/src/solve.jl:3 [inlined]
 [20] sciml_train(::typeof(loss_multiple_shooting), ::Vector{Float32}, ::Nothing, ::Nothing; lower_bounds::Nothing, upper_bounds::Nothing, maxiters::Nothing, kwargs::Base.Pairs{Symbol, var"#17#19", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#17#19"}}})
    @ DiffEqFlux ~/.julia/packages/DiffEqFlux/UPHk9/src/train.jl:111
 [21] top-level scope
    @ ~/project/odenet/test4.jl:115

2:
Can the NeuralODE be trained with several different external inputs?
For example:

function fitzhugh_nagumo!(du, u, p, t)
    v, w = u
    a, b, c, input = p
    du[1] = v * (a - v) * (v - 1) - w + input
    du[2] = b * v - c * w
end

p1 = [-0.1, 0.01, 0.02, 0.0]
p2 = [-0.1, 0.01, 0.02, 0.1]

Then I use p1 to generate ode_data1 and p2 to generate ode_data2 (with the same u0 and tspan).
Is it possible to use both datasets to train the same network? I have something like the sketch below in mind.
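
This is an untested sketch; the names make_rhs and loss_two_inputs are just for illustration. The idea is to feed the external input to the network as an extra feature and sum the losses over both datasets:

# Untested sketch: the external input becomes a third network input, so one
# set of weights has to explain the data from both input conditions.
dudt_in = FastChain(FastDense(3, 32, tanh),   # inputs: v, w, external input
                    FastDense(32, 2))
p0 = initial_params(dudt_in)

make_rhs(input) = (u, p, t) -> dudt_in(vcat(u, input), p)

prob1 = ODEProblem(make_rhs(0.0), u0, tspan, p0)   # input used for ode_data1
prob2 = ODEProblem(make_rhs(0.1), u0, tspan, p0)   # input used for ode_data2

function loss_two_inputs(p)
    pred1 = Array(solve(prob1, Tsit5(), p = p, saveat = tsteps))
    pred2 = Array(solve(prob2, Tsit5(), p = p, saveat = tsteps))
    return sum(abs2, ode_data1 .- pred1) + sum(abs2, ode_data2 .- pred2)
end

result_in = DiffEqFlux.sciml_train(loss_two_inputs, p0, maxiters = 200)

Does that seem like a reasonable way to set it up?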

Sorry for not checking these pages first. The problem and solution are really well documented there; my reading of the fix is sketched below the links.

Just for the record, here are the links:

  1. Handling Divergent and Unstable Trajectories
  2. Stability and Divergence of ODE Solves
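
My reading of those pages, applied to the loss above: check whether the solve actually reached the end of the time span, and return an infinite loss otherwise, so the optimizer rejects that parameter set instead of propagating NaNs. A sketch:

# Sketch of a guarded loss, following the linked docs: if the solve bails out
# early (NaN dt, instability), the prediction has fewer columns than the data,
# so return an infinite loss for this parameter set.
function loss_neuralode_safe(p)
    pred = Array(prob_neuralode(u0, p))
    if size(pred) != size(ode_data)   # solve did not reach the end of tspan
        return Inf, pred
    end
    return sum(abs2, ode_data .- pred), pred
end

(The plotting callback then also needs to skip plotting whenever size(pred) != size(ode_data).)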