DiffEqFlux NaN Bug with Neural ODE: "BoundsError: attempt to access 1-element Vector{Float64} at index [2]"

I’m adapting some code from Paul Shen’s blog post “Differentiable Programming and Neural ODEs for Accelerating Model Based Reinforcement Learning and Optimal Control” (The Startup, on Medium).

I’ve run into an odd bug when running this specific call:

```julia
result = DiffEqFlux.sciml_train(
  loss_neuralode,
  pinit,
  ADAM(0.05),
  cb = callback,
  maxiters = 1500,
)
```

The notebook is Julia Notebook.ipynb in the mhr/diffeqflux-bug repository on GitHub (main branch).

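According to the warnings, the very first evaluation of the ODE right-hand side produces NaNs, the solver exits early, and the loss function then fails with a BoundsError. The full output and stack trace are below: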

┌ Warning: First function call produced NaNs. Exiting.
└ @ OrdinaryDiffEq C:\Users\mretc\.julia\packages\OrdinaryDiffEq\PIjOZ\src\initdt.jl:95
┌ Warning: Automatic dt set the starting dt as NaN, causing instability.
└ @ OrdinaryDiffEq C:\Users\mretc\.julia\packages\OrdinaryDiffEq\PIjOZ\src\solve.jl:510
┌ Warning: NaN dt detected. Likely a NaN value in the state, parameters, or derivative value caused this outcome.
└ @ SciMLBase C:\Users\mretc\.julia\packages\SciMLBase\cA7Re\src\integrator_interface.jl:325
BoundsError: attempt to access 1-element Vector{Float64} at index [2]

Stacktrace:
  [1] getindex
    @ .\array.jl:801 [inlined]
  [2] adjoint
    @ C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\lib\array.jl:31 [inlined]
  [3] _pullback
    @ C:\Users\mretc\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:57 [inlined]
  [4] _pullback
    @ C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\tools\builtins.jl:15 [inlined]
  [5] _pullback(::Zygote.Context, ::typeof(Zygote.literal_getindex), ::Vector{Float64}, ::Val{2})
    @ Zygote C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
  [6] _pullback
    @ .\In[9]:15 [inlined]
  [7] _pullback(ctx::Zygote.Context, f::typeof(format), args::ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float32}, ODEFunction{true, typeof(cartpole_controlled), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(cartpole_controlled), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}}, DiffEqBase.DEStats})
    @ Zygote C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
  [8] _pullback
    @ .\In[9]:29 [inlined]
  [9] _pullback(ctx::Zygote.Context, f::typeof(loss_neuralode), args::Vector{Float32})
    @ Zygote C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
 [10] _pullback
    @ C:\Users\mretc\.julia\packages\DiffEqFlux\GkMjX\src\train.jl:84 [inlined]
 [11] _pullback(::Zygote.Context, ::DiffEqFlux.var"#82#87"{typeof(loss_neuralode)}, ::Vector{Float32}, ::Nothing)
    @ Zygote C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
 [12] _apply
    @ .\boot.jl:804 [inlined]
 [13] adjoint
    @ C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\lib\lib.jl:200 [inlined]
 [14] _pullback
    @ C:\Users\mretc\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:57 [inlined]
 [15] _pullback
    @ C:\Users\mretc\.julia\packages\SciMLBase\cA7Re\src\problems\basic_problems.jl:107 [inlined]
 [16] _pullback(::Zygote.Context, ::OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{typeof(loss_neuralode)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::Vector{Float32}, ::Nothing)
    @ Zygote C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
 [17] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core .\boot.jl:804
 [18] adjoint
    @ C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\lib\lib.jl:200 [inlined]
 [19] adjoint(::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{typeof(loss_neuralode)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::Tuple{Vector{Float32}, Nothing}, ::Tuple{})
    @ Zygote .\none:0
 [20] _pullback(::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{typeof(loss_neuralode)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::Tuple{Vector{Float32}, Nothing}, ::Tuple{})
    @ Zygote C:\Users\mretc\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:57
 [21] _pullback
    @ C:\Users\mretc\.julia\packages\GalacticOptim\bEh06\src\function\zygote.jl:6 [inlined]
 [22] _pullback(ctx::Zygote.Context, f::GalacticOptim.var"#229#239"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{typeof(loss_neuralode)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}, args::Vector{Float32})
    @ Zygote C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
 [23] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core .\boot.jl:804
 [24] adjoint
    @ C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\lib\lib.jl:200 [inlined]
 [25] _pullback
    @ C:\Users\mretc\.julia\packages\ZygoteRules\OjfTt\src\adjoint.jl:57 [inlined]
 [26] _pullback
    @ C:\Users\mretc\.julia\packages\GalacticOptim\bEh06\src\function\zygote.jl:8 [inlined]
 [27] _pullback(ctx::Zygote.Context, f::GalacticOptim.var"#232#242"{Tuple{}, GalacticOptim.var"#229#239"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{typeof(loss_neuralode)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, args::Vector{Float32})
    @ Zygote C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\compiler\interface2.jl:0
 [28] _pullback(f::Function, args::Vector{Float32})
    @ Zygote C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\compiler\interface.jl:34
 [29] pullback(f::Function, args::Vector{Float32})
    @ Zygote C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\compiler\interface.jl:40
 [30] gradient(f::Function, args::Vector{Float32})
    @ Zygote C:\Users\mretc\.julia\packages\Zygote\TaBlo\src\compiler\interface.jl:75
 [31] (::GalacticOptim.var"#230#240"{GalacticOptim.var"#229#239"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{typeof(loss_neuralode)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}})(::Vector{Float32}, ::Vector{Float32})
    @ GalacticOptim C:\Users\mretc\.julia\packages\GalacticOptim\bEh06\src\function\zygote.jl:8
 [32] macro expansion
    @ C:\Users\mretc\.julia\packages\GalacticOptim\bEh06\src\solve\flux.jl:43 [inlined]
 [33] macro expansion
    @ C:\Users\mretc\.julia\packages\GalacticOptim\bEh06\src\solve\solve.jl:35 [inlined]
 [34] __solve(prob::OptimizationProblem{false, OptimizationFunction{false, GalacticOptim.AutoZygote, OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{typeof(loss_neuralode)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, GalacticOptim.var"#230#240"{GalacticOptim.var"#229#239"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{typeof(loss_neuralode)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#233#243"{GalacticOptim.var"#229#239"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#82#87"{typeof(loss_neuralode)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}}, GalacticOptim.var"#238#248", Nothing, Nothing, Nothing}, Vector{Float32}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Iterators.Pairs{Symbol, var"#5#7", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#5#7"}}}}, opt::ADAM, data::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ GalacticOptim C:\Users\mretc\.julia\packages\GalacticOptim\bEh06\src\solve\flux.jl:41
 [35] #solve#474
    @ C:\Users\mretc\.julia\packages\SciMLBase\cA7Re\src\solve.jl:3 [inlined]
 [36] sciml_train(::typeof(loss_neuralode), ::Vector{Float32}, ::ADAM, ::Nothing; lower_bounds::Nothing, upper_bounds::Nothing, maxiters::Int64, kwargs::Base.Iterators.Pairs{Symbol, var"#5#7", Tuple{Symbol}, NamedTuple{(:cb,), Tuple{var"#5#7"}}})
    @ DiffEqFlux C:\Users\mretc\.julia\packages\DiffEqFlux\GkMjX\src\train.jl:89
 [37] top-level scope
    @ .\timing.jl:210 [inlined]
 [38] top-level scope
    @ .\In[11]:0
 [39] eval
    @ .\boot.jl:360 [inlined]
 [40] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base .\loading.jl:1116

WARNING: both Flux and Iterators export "flatten"; uses of it in module DiffEqFlux must be qualified
WARNING: both Flux and Distributions export "params"; uses of it in module DiffEqFlux must be qualified

What’s strange is that the first time I run that line, I get the error above, but if I run it again, that call and every later one succeed. I can reproduce this every time I run the notebook. It’s not a big problem in practice, since I can just rerun the cell, but I would like to understand why it happens and how to avoid it, particularly once I run the code outside the notebook as a module.
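Since the first warning says the very first function call produced NaNs, one way to narrow this down is to evaluate the ODE right-hand side once, outside the solver and outside Zygote, at the initial state and parameters. The sketch below is a minimal, hypothetical helper (`rhs_is_finite` is not a DiffEqFlux or OrdinaryDiffEq function). It assumes an in-place right-hand side with the `f!(du, u, p, t)` signature, which matches the in-place `ODEFunction{true, typeof(cartpole_controlled)}` in the stack trace, and it uses a toy `decay!` function in place of the notebook’s model.

```julia
# Hypothetical helper (not a library function): call an in-place ODE
# right-hand side f!(du, u, p, t) once and report whether every component
# of the derivative is finite. A `false` here means the solver would see
# NaN/Inf on its very first function call.
function rhs_is_finite(f!, u0, p, t0)
    du = similar(u0)        # buffer for the derivative, same shape as the state
    f!(du, u0, p, t0)       # single evaluation at the initial point
    return all(isfinite, du)
end

# Toy in-place right-hand side used only to demonstrate the call.
function decay!(du, u, p, t)
    du .= -p[1] .* u
    return nothing
end

rhs_is_finite(decay!, [1.0, 2.0], [0.5], 0.0)   # true
rhs_is_finite(decay!, [1.0, 2.0], [NaN], 0.0)   # false
```

Calling it on `cartpole_controlled` with the notebook’s own initial state, `pinit`, and the start of the time span, once before the first `sciml_train` call and once before the second, would show whether the NaN comes from the parameters the first run starts from.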

Thank you for your time.

Maybe it’s just the random initialization you got.

Yes, but how can this happen every time I run the notebook? If it were random, wouldn’t the bug show up only some of the time, not every time? And how do I set the random seed?

Are you talking about the initialization of the neural network, or something else?
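To make the initialization reproducible, Julia’s standard `Random` library seeds the global RNG with `Random.seed!`. The sketch below assumes the network’s layers draw their initial weights from the global RNG (the default for Flux layers of that era); `randn` stands in for the network’s initializer, and the names `w1` and `w2` are illustrative only.

```julia
using Random

# Seed the global RNG, then draw "initial weights" (randn stands in for the
# network's initializer, which is assumed to use the global RNG).
Random.seed!(1234)
w1 = randn(Float32, 4)

# Reseeding with the same value reproduces exactly the same draw.
Random.seed!(1234)
w2 = randn(Float32, 4)

@assert w1 == w2   # same seed, same starting parameters
```

Placing the `Random.seed!` call at the top of the notebook, before the network and `pinit` are created, should give the same starting parameters on every run, which makes it easier to tell whether the failure follows a particular initialization.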

Not sure. Try to isolate it?
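One hedged way to isolate it is to sweep over several seeds and record which ones give a bad start. `failing_seeds` below is a hypothetical helper, not a library function: `make_params` would rebuild the network parameters, and `check` would evaluate something cheap (such as the right-hand side or the loss) once and return `true` when the result is finite. If every seed fails, the cause is probably not the random initialization.

```julia
using Random

# Hypothetical helper: for each seed, reseed the global RNG, rebuild the
# parameters with `make_params()`, and run `check(p)`, which must return a
# Bool. Returns the seeds for which `check` returned false. A `check` that
# throws (for example the BoundsError from the question) is not caught here.
function failing_seeds(make_params, check; seeds = 1:20)
    bad = Int[]
    for s in seeds
        Random.seed!(s)
        p = make_params()
        check(p) || push!(bad, s)
    end
    return bad
end

# Toy usage: randn never produces NaN/Inf, so no seed fails.
failing_seeds(() -> randn(Float32, 8), p -> isfinite(sum(p)))   # Int[]
```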