# How to fit a FitzHugh-Nagumo model with DiffEqFlux.jl?

I am trying to apply DiffEqFlux.jl to the FitzHugh-Nagumo model, an excitable neuron model.

I only replaced the differential equations in the tutorial "Neural Ordinary Differential Equations with sciml_train" with the FitzHugh-Nagumo model and used the `Rosenbrock23()` solver.

However, I can't get a good result (see the last figure).
I don't know where it goes wrong or how I can improve it.
I would appreciate any suggestions or references. Thank you!

Here is my code:

``````julia
using DiffEqFlux, DifferentialEquations, Plots

# FitzHugh-Nagumo model: v is the membrane potential, w the recovery variable
function fitzhugh_nagumo!(du, u, p, t)
    v, w = u
    a, b, c = p
    du[1] = v * (a - v) * (v - 1) - w
    du[2] = b * v - c * w
end

# ground truth data
u0 = [0.5, 0.0]
p_true = [-0.1, 0.01, 0.02]
tspan = (0.0, 200.0)
datasize = 100
tsteps = range(tspan[1], tspan[2], length=datasize)

prob_trueode = ODEProblem(fitzhugh_nagumo!, u0, tspan, p_true)
ode_data = Array(solve(prob_trueode, Rosenbrock23(), saveat=tsteps))
``````
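
Written out, the system that `fitzhugh_nagumo!` implements is the following, where $a$, $b$ and $c$ are the three entries of `p`:

$$
\begin{aligned}
\frac{dv}{dt} &= v\,(a - v)\,(v - 1) - w,\\
\frac{dw}{dt} &= b\,v - c\,w.
\end{aligned}
$$

With the values in `p_true`, $b$ and $c$ are small, so $w$ changes slowly relative to $v$.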

``````julia
# neural network that replaces the right-hand side of the ODE
dudt2 = FastChain(FastDense(2, 32, tanh),
                  FastDense(32, 2))

prob_neuralode = NeuralODE(dudt2, tspan, Rosenbrock23(), saveat=tsteps)

# sum of squared errors between the data and the NeuralODE prediction
function loss_neuralode(p)
    pred = Array(prob_neuralode(u0, p))
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end

callback(p, l, pred; doplot=true) = begin
    display(l)

    # plot current prediction against data
    plt = plot(tsteps, ode_data', label = "data")
    scatter!(plt, tsteps, pred', label = "prediction")
    if doplot
        display(plot(plt))
    end

    return false  # returning false tells the optimizer to continue
end

result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, prob_neuralode.p,
                                          cb=callback, maxiters=200)
``````

Did you try multiple shooting?
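
For readers unfamiliar with the idea: multiple shooting splits the time series into short segments, solves each segment starting from the observed data point, and penalizes the mismatch where consecutive segments meet. Below is a minimal, dependency-free sketch of such a loss. It is not DiffEqFlux's implementation: the fixed-step integrator `rk4_segment`, the helper `multiple_shooting_loss`, the group size and the continuity weight are all hypothetical names and values chosen for illustration, and the known FitzHugh-Nagumo right-hand side stands in for the neural network.

```julia
# Out-of-place form of the model from the question (stand-in for the network).
fhn(u, p) = [u[1] * (p[1] - u[1]) * (u[1] - 1) - u[2],
             p[2] * u[1] - p[3] * u[2]]

# One classical Runge-Kutta (RK4) step of size h.
function rk4_step(f, u, p, h)
    k1 = f(u, p)
    k2 = f(u .+ (h / 2) .* k1, p)
    k3 = f(u .+ (h / 2) .* k2, p)
    k4 = f(u .+ h .* k3, p)
    return u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Integrate from u0 and save at the times ts (hypothetical helper).
function rk4_segment(f, u0, p, ts; substeps = 20)
    sol = zeros(length(u0), length(ts))
    sol[:, 1] = u0
    u = copy(u0)
    for i in 2:length(ts)
        h = (ts[i] - ts[i-1]) / substeps
        for _ in 1:substeps
            u = rk4_step(f, u, p, h)
        end
        sol[:, i] = u
    end
    return sol
end

# Sum of per-segment data misfits plus a penalty on jumps between segments.
function multiple_shooting_loss(f, p, data, ts; group_size = 10, continuity = 100.0)
    n = size(data, 2)
    loss = 0.0
    prev_end = nothing
    # consecutive segments share one data point
    for s in 1:(group_size - 1):(n - 1)
        e = min(s + group_size - 1, n)
        pred = rk4_segment(f, data[:, s], p, ts[s:e])  # restart from the data
        loss += sum(abs2, data[:, s:e] .- pred)
        if prev_end !== nothing
            loss += continuity * sum(abs2, prev_end .- pred[:, 1])
        end
        prev_end = pred[:, end]
    end
    return loss
end

# Synthetic data generated with the same integrator, as in the question's setup.
u0 = [0.5, 0.0]
p_true = [-0.1, 0.01, 0.02]
ts = collect(range(0.0, 200.0, length = 100))
data = rk4_segment(fhn, u0, p_true, ts)

loss_true = multiple_shooting_loss(fhn, p_true, data, ts)
loss_wrong = multiple_shooting_loss(fhn, [-0.1, 0.02, 0.02], data, ts)
```

Because every segment restarts from a data point, errors cannot accumulate over the full time span, which tends to make the loss landscape easier for the optimizer than a single long trajectory. In this sketch the loss at the generating parameters is essentially zero, and a perturbed parameter vector gives a larger value.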

1:
Thank you, Chris! Multiple shooting works, although most trials fell into local minima, so I'm still trying to improve it.

Even in successful trials, however, the training process still raised an error.
I'm not sure what is wrong, so I am posting the full error output here.

``````
1.6615035429439715  # loss from callback function
1.6608717651905747
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/initdt.jl:203
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/GW7GW/src/integrator_interface.jl:325
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/initdt.jl:203
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/GW7GW/src/integrator_interface.jl:325
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/initdt.jl:203
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/GW7GW/src/integrator_interface.jl:325
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/initdt.jl:203
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/GW7GW/src/integrator_interface.jl:325
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/initdt.jl:203
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/GW7GW/src/integrator_interface.jl:325
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/initdt.jl:203
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/GW7GW/src/integrator_interface.jl:325
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/initdt.jl:203
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/GW7GW/src/integrator_interface.jl:325
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/initdt.jl:203
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/GW7GW/src/integrator_interface.jl:325
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/initdt.jl:203
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/GW7GW/src/integrator_interface.jl:325
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/initdt.jl:203
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/GW7GW/src/integrator_interface.jl:325
┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/initdt.jl:203
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/iOlmD/src/solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase ~/.julia/packages/SciMLBase/GW7GW/src/integrator_interface.jl:325
ERROR: MethodError: Cannot `convert` an object of type Nothing to an object of type Float32
Closest candidates are:
convert(::Type{T}, ::Static.StaticFloat64{N}) where {N, T<:AbstractFloat} at ~/.julia/packages/Static/pkxBE/src/float.jl:26
convert(::Type{T}, ::LLVM.GenericValue, ::LLVM.LLVMType) where T<:AbstractFloat at ~/.julia/packages/LLVM/tVv0H/src/execution.jl:39
convert(::Type{T}, ::LLVM.ConstantFP) where T<:AbstractFloat at ~/.julia/packages/LLVM/tVv0H/src/core/value/constant.jl:103
...
Stacktrace:
[1] fill!(dest::Vector{Float32}, x::Nothing)
@ Base ./array.jl:351
[2] copyto!
[3] materialize!
[5] (::GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}})(::Vector{Float32}, ::Vector{Float32})
@ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/function/zygote.jl:8
[6] (::GalacticOptim.var"#144#152"{OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, var"#17#19", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#17#19"}}}}, GalacticOptim.var"#143#151"{OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, var"#17#19", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#17#19"}}}}, OptimizationFunction{false, GalacticOptim.AutoZygote, 
OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}}, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, 
GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}})(G::Vector{Float32}, θ::Vector{Float32})
@ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/solve/optim.jl:93
[7] value_gradient!!(obj::TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}, x::Vector{Float32})
@ NLSolversBase ~/.julia/packages/NLSolversBase/cfJrN/src/interface.jl:82
[8] value_gradient!(obj::TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}, x::Vector{Float32})
@ NLSolversBase ~/.julia/packages/NLSolversBase/cfJrN/src/interface.jl:69
[9] value_gradient!(obj::Optim.ManifoldObjective{TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}}, x::Vector{Float32})
@ Optim ~/.julia/packages/Optim/wFOeG/src/Manifolds.jl:50
[10] (::LineSearches.var"#ϕdϕ#6"{Optim.ManifoldObjective{TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}}, Vector{Float32}, Vector{Float32}, Vector{Float32}})(α::Float32)
@ LineSearches ~/.julia/packages/LineSearches/Ki4c5/src/LineSearches.jl:84
[11] (::LineSearches.HagerZhang{Float64, Base.RefValue{Bool}})(ϕ::Function, ϕdϕ::LineSearches.var"#ϕdϕ#6"{Optim.ManifoldObjective{TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}}, Vector{Float32}, Vector{Float32}, Vector{Float32}}, c::Float32, phi_0::Float32, dphi_0::Float32)
@ LineSearches ~/.julia/packages/LineSearches/Ki4c5/src/hagerzhang.jl:139
[12] HagerZhang
@ ~/.julia/packages/LineSearches/Ki4c5/src/hagerzhang.jl:101 [inlined]
[13] perform_linesearch!(state::Optim.BFGSState{Vector{Float32}, Matrix{Float32}, Float32, Vector{Float32}}, method::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Float64, Flat}, d::Optim.ManifoldObjective{TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}})
@ Optim ~/.julia/packages/Optim/wFOeG/src/utilities/perform_linesearch.jl:59
[14] update_state!(d::TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}, state::Optim.BFGSState{Vector{Float32}, Matrix{Float32}, Float32, Vector{Float32}}, method::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Float64, Flat})
@ Optim ~/.julia/packages/Optim/wFOeG/src/multivariate/solvers/first_order/bfgs.jl:139
[15] optimize(d::TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}, initial_x::Vector{Float32}, method::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Float64, Flat}, options::Optim.Options{Float64, GalacticOptim.var"#_cb#150"{var"#17#19", BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Float64, Flat}, Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}}}, state::Optim.BFGSState{Vector{Float32}, Matrix{Float32}, Float32, Vector{Float32}})
@ Optim ~/.julia/packages/Optim/wFOeG/src/multivariate/optimize/optimize.jl:54
[16] optimize(d::TwiceDifferentiable{Float32, Vector{Float32}, Matrix{Float32}, Vector{Float32}}, initial_x::Vector{Float32}, method::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Float64, Flat}, options::Optim.Options{Float64, GalacticOptim.var"#_cb#150"{var"#17#19", BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Float64, Flat}, Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}}})
@ Optim ~/.julia/packages/Optim/wFOeG/src/multivariate/optimize/optimize.jl:36
[17] ___solve(prob::OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#271#281"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_multiple_shooting)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#276#286", Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, var"#17#19", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#17#19"}}}}, opt::BFGS{LineSearches.InitialStatic{Float64}, LineSearches.HagerZhang{Float64, Base.RefValue{Bool}}, Nothing, Float64, Flat}, data::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; cb::Function, maxiters::Nothing, maxtime::Nothing, abstol::Nothing, reltol::Nothing, progress::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/solve/optim.jl:129
[18] #__solve#141
@ ~/.julia/packages/GalacticOptim/fow0r/src/solve/optim.jl:49 [inlined]
[19] #solve#480
@ ~/.julia/packages/SciMLBase/GW7GW/src/solve.jl:3 [inlined]
[20] sciml_train(::typeof(loss_multiple_shooting), ::Vector{Float32}, ::Nothing, ::Nothing; lower_bounds::Nothing, upper_bounds::Nothing, maxiters::Nothing, kwargs::Base.Pairs{Symbol, var"#17#19", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#17#19"}}})
@ DiffEqFlux ~/.julia/packages/DiffEqFlux/UPHk9/src/train.jl:111
[21] top-level scope
@ ~/project/odenet/test4.jl:115
``````

2:
I wonder whether the NeuralODE can also be trained with different external inputs.
For example:

``````julia
function fitzhugh_nagumo!(du, u, p, t)
    v, w = u
    a, b, c, input = p
    du[1] = v * (a - v) * (v - 1) - w + input
    du[2] = b * v - c * w
end

p1 = [-0.1, 0.01, 0.02, 0.0]
p2 = [-0.1, 0.01, 0.02, 0.1]
``````

I then use `p1` to generate `ode_data1` and `p2` to generate `ode_data2` (with the same `u0` and `tspan`).
Is it possible to train the network on both datasets?
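
One possible way to frame this, shown here only as a hedged sketch, is to keep a single shared parameter vector and sum the loss over all (input, data) pairs. The code below is dependency-free and uses the known right-hand side with a simple Euler integrator in place of a NeuralODE; `fhn_input`, `predict` and `combined_loss` are hypothetical names. With an actual NeuralODE, the input would presumably have to be passed to the network as well (for example as an extra input feature), which is an assumption not covered here.

```julia
# Right-hand side with the external input kept separate from the shared parameters.
function fhn_input(u, theta, input)
    v, w = u
    a, b, c = theta
    return [v * (a - v) * (v - 1) - w + input, b * v - c * w]
end

# Hypothetical fixed-step Euler predictor, saving the state at the times ts.
function predict(theta, input, u0, ts; substeps = 50)
    sol = zeros(length(u0), length(ts))
    sol[:, 1] = u0
    u = copy(u0)
    for i in 2:length(ts)
        h = (ts[i] - ts[i-1]) / substeps
        for _ in 1:substeps
            u = u .+ h .* fhn_input(u, theta, input)
        end
        sol[:, i] = u
    end
    return sol
end

# One shared theta, squared errors summed over every (input, data) pair.
function combined_loss(theta, datasets, u0, ts)
    return sum(sum(abs2, data .- predict(theta, input, u0, ts))
               for (input, data) in datasets)
end

u0 = [0.5, 0.0]
ts = collect(range(0.0, 200.0, length = 100))
theta_true = [-0.1, 0.01, 0.02]

# Two synthetic datasets: same u0 and time grid, different external input.
datasets = [(input, predict(theta_true, input, u0, ts)) for input in (0.0, 0.1)]

loss_at_truth = combined_loss(theta_true, datasets, u0, ts)
loss_perturbed = combined_loss([-0.1, 0.02, 0.02], datasets, u0, ts)
```

The same pattern extends to more than two inputs by adding further pairs to `datasets`.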

Sorry for not checking these pages first. The problem and solution are really well-documented.

Just for the record, I put links here.