When training a neural network together with a DiffEq solver, does the neural network have to be inside the function passed to the solver, with that whole solve call then wrapped by `sciml_train`? Or can the neural network be defined outside the solver, generating the parameters that are then passed to it, so that the network lives outside the solver call?
I keep hitting roadblock after roadblock with every sensitivity algorithm I try, and I can't debug any of them. Zygote fails with try/catch errors even though I don't think my code contains any try/catch blocks, so I suspect they are somewhere inside the DiffEq solver. ReverseDiff fails with a `TrackedReal` conversion error on a variable that is part of the integrator composite type. So I'm wondering whether putting the neural network outside the solver call is simply not supported.
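To make the question concrete, here is a minimal sketch of the two arrangements. It is not my actual model: the hand-written network `nn`, the decay model in `rhs_outside!`, the input `x`, the initial weights `θ0`, and the zero `target` are all hypothetical stand-ins, and only OrdinaryDiffEq is assumed. It only shows the forward pass, not the gradient.

```julia
using OrdinaryDiffEq

# Hypothetical one-hidden-layer network, written by hand so the sketch does
# not depend on a particular Flux/Lux version. θ holds all 22 weights.
function nn(θ, x)
    W1 = reshape(θ[1:8], 4, 2)
    b1 = θ[9:12]
    W2 = reshape(θ[13:20], 2, 4)
    b2 = θ[21:22]
    return W2 * tanh.(W1 * x .+ b1) .+ b2
end

u0 = [1.0, 1.0]
tspan = (0.0, 1.0)                               # plain Float64 constants
θ0 = collect(range(-0.5, 0.5; length = 22))      # hypothetical initial weights

# Arrangement 1: the network *is* the right-hand side, so it is called
# inside every solver step and θ is passed as the ODE parameters.
function rhs_inside!(du, u, θ, t)
    du .= nn(θ, u)
end
prob_inside = ODEProblem(rhs_inside!, u0, tspan, θ0)

# Arrangement 2: the network runs once, outside the solver, and only
# produces the ODE parameters p (here: two decay rates).
function rhs_outside!(du, u, p, t)
    du[1] = -p[1] * u[1]
    du[2] = -p[2] * u[2]
end
prob_outside = ODEProblem(rhs_outside!, u0, tspan, [1.0, 1.0])

x = [1.0, 2.0]                                   # hypothetical network input

function predict_outside(θ)
    p = exp.(nn(θ, x))                           # exp keeps the rates positive
    sol = solve(remake(prob_outside; p = p), Tsit5(); saveat = 0.1)
    return Array(sol)                            # 2×11 matrix of saved states
end

target = zeros(2, 11)                            # hypothetical training data
loss(θ) = sum(abs2, predict_outside(θ) .- target)
```

Calling `loss(θ0)` just runs one forward solve. The training loop would then differentiate `loss` with respect to `θ`, and whether that works when the network sits outside the solve (arrangement 2) is exactly what I'm unsure about.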
Here is the first part of the latest ReverseDiff error:
```
ERROR: ArgumentError: Converting an instance of ReverseDiff.TrackedReal{Float64, Float64, Nothing} to Float32 is not defined. Please use `ReverseDiff.value` instead.
Stacktrace:
  [1] convert(#unused#::Type{Float32}, t::ReverseDiff.TrackedReal{Float64, Float64, Nothing})
    @ ReverseDiff ~\.julia\packages\ReverseDiff\iHmB4\src\tracked.jl:261
  [2] fastpow
    @ ~\.julia\packages\DiffEqBase\rN9Px\src\fastpow.jl:92 [inlined]
  [3] PI_stepsize_controller!
    @ ~\.julia\packages\OrdinaryDiffEq\5egkj\src\integrators\controllers.jl:44 [inlined]
  [4] stepsize_controller!
    @ ~\.julia\packages\OrdinaryDiffEq\5egkj\src\integrators\controllers.jl:58 [inlined]
  [5] _loopfooter!(integrator::OrdinaryDiffEq.ODEIntegrator{OrdinaryDiffEq.Tsit5, true, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Nothing, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, Vector{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}, SciMLBase.ODESolution{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, 2, Vector{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}, Nothing, Nothing, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Vector{Vector{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}}, SciMLBase.ODEProblem{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Tuple{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, true, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, SciMLBase.ODEFunction{true, Main.MLSAGA.SAGA.var"#17#18", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Symbol, DiffEqBase.CallbackSet{Tuple{DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#10#11", typeof(SciMLBase.terminate!), typeof(SciMLBase.terminate!), typeof(DiffEqBase.INITIALIZE_DEFAULT), typeof(DiffEqBase.FINALIZE_DEFAULT), Float64, Int64, Nothing, Int64}, DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#8#9", typeof(SciMLBase.terminate!), typeof(SciMLBase.terminate!), typeof(DiffEqBase.INITIALIZE_DEFAULT), typeof(DiffEqBase.FINALIZE_DEFAULT), Float64, Int64, Nothing, Int64}}, Tuple{}}, Tuple{Symbol}, NamedTuple{(:callback,), Tuple{DiffEqBase.CallbackSet{Tuple{DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#10#11", typeof(SciMLBase.terminate!), typeof(SciMLBase.terminate!), typeof(DiffEqBase.INITIALIZE_DEFAULT), typeof(DiffEqBase.FINALIZE_DEFAULT), Float64, Int64, Nothing, Int64}, DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#8#9", typeof(SciMLBase.terminate!), typeof(SciMLBase.terminate!), typeof(DiffEqBase.INITIALIZE_DEFAULT), typeof(DiffEqBase.FINALIZE_DEFAULT), Float64, Int64, Nothing, Int64}}, Tuple{}}}}}, SciMLBase.StandardODEProblem}, OrdinaryDiffEq.Tsit5, OrdinaryDiffEq.InterpolationData{SciMLBase.ODEFunction{true, Main.MLSAGA.SAGA.var"#17#18", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Vector{Vector{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}}, OrdinaryDiffEq.Tsit5Cache{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, OrdinaryDiffEq.Tsit5ConstantCache{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}}, DiffEqBase.DEStats}, SciMLBase.ODEFunction{true, Main.MLSAGA.SAGA.var"#17#18", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5Cache{Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, OrdinaryDiffEq.Tsit5ConstantCache{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}, OrdinaryDiffEq.DEOptions{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, Float64, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, DiffEqBase.CallbackSet{Tuple{DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#10#11", typeof(SciMLBase.terminate!), typeof(SciMLBase.terminate!), typeof(DiffEqBase.INITIALIZE_DEFAULT), typeof(DiffEqBase.FINALIZE_DEFAULT), Float64, Int64, Nothing, Int64}, DiffEqBase.ContinuousCallback{Main.MLSAGA.SAGA.var"#8#9", typeof(SciMLBase.terminate!), typeof(SciMLBase.terminate!), typeof(DiffEqBase.INITIALIZE_DEFAULT), typeof(DiffEqBase.FINALIZE_DEFAULT), Float64, Int64, Nothing, Int64}}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryMinHeap{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, DataStructures.BinaryMinHeap{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, Vector{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}, Nothing, OrdinaryDiffEq.DefaultInit})
    @ OrdinaryDiffEq ~\.julia\packages\OrdinaryDiffEq\5egkj\src\integrators\integrator_utils.jl:188
  [6] loopfooter!
    @ ~\.julia\packages\OrdinaryDiffEq\5egkj\src\integrators\integrator_utils.jl:168 [inlined]
```