How do I debug this in DiffEqFlux?

This is my first attempt at using DiffEq, and I just don’t get much help from the stacktrace.
I’m trying to train a neural network that sits inside a differential equation, similar to a reinforcement learning problem.

I get a warning that says:

**┌ Warning: dt <= dtmin. Aborting. There is either an error in your model specification or the true solution is unstable.**
**└ @ DiffEqBase C:\Users\myUserID\.julia\packages\DiffEqBase\3iigH\src\integrator_interface.jl:343**

Then I get this (truncated):

ERROR: BoundsError: attempt to access 1-element Array{Float64,1} at index [0]
Stacktrace:
 [1] getindex at .\array.jl:809 [inlined]
 [2] (::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Array{Float64,1},Nothing,Base.OneTo{Int64},UnitRange{Int64},UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},DiffEqSensitivity.ZygoteVJP,Bool},Array{Float64,1},DiffEqBase.ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},DiffEqBase.ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},DiffEqBase.ODEFunction{true,typeof(Main.myFunc!),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:callback, :saveat),Tuple{DiffEqBase.CallbackSet{Tuple{DiffEqBase.ContinuousCallback{Main.FlyPronav_vsTagBot.var"#1#2",typeof(DiffEqBase.terminate!),typeof(DiffEqBase.terminate!),typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}},Float64}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Vern9,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{true,typeof(Main.FlyPronav_vsTagBot.flyProNav!),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Vern9Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Vern9Tableau{Float64,Float64}}},DiffEqBase.DEStats},DiffEqSensitivity.CheckpointSolution{DiffEqBase.ODESolution{Float64,2,Array{Array{Float64,1},1},Nothing,Nothing,Array{Float64,1},Array{Array{Array{Float64,1},1},1},DiffEqBase.ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},DiffEqBase.ODEFunction{true,typeof(Main.myFunc!),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:callback, :saveat),Tuple{DiffEqBase.CallbackSet{Tuple{DiffEqBase.ContinuousCallback{Main.FlyPronav_vsTagBot.var"#1#2",typeof(DiffEqBase.terminate!),typeof(DiffEqBase.terminate!),typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}},Float64}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Vern9,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{true,typeof(Main.FlyPronav_vsTagBot.flyProNav!),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float64,1},1},Array{Float64,1},Array{Array{Array{Float64,1},1},1},OrdinaryDiffEq.Vern9Cache{Array{Float64,1},Array{Float64,1},Array{Float64,1},OrdinaryDiffEq.Vern9Tableau{Float64,Float64}}},DiffEqBase.DEStats},Array{Tuple{Float64,Float64},1},NamedTuple{(:reltol, :abstol),Tuple{Float64,Float64}}},DiffEqBase.ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},DiffEqBase.ODEFunction{true,typeof(Main.FlyPronav_vsTagBot.flyProNav!),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:callback, 
:saveat),Tuple{DiffEqBase.CallbackSet{Tuple{DiffEqBase.ContinuousCallback{Main.FlyPronav_vsTagBot.var"#1#2",typeof(DiffEqBase.terminate!),typeof(DiffEqBase.terminate!),typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}},Float64}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{true,typeof(Main.myFunc!),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}})(::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}, ::Float64) at C:\Users\myUserID\.julia\packages\DiffEqSensitivity\WiCRA\src\local_sensitivity\interpolating_adjoint.jl:108

The training command looks like this:
res = DiffEqFlux.sciml_train(loss_adjoint, θ, ADAM(0.1), cb = cb_plot, maxiters = 100)

Should I use something besides ADAM as an optimizer on this?

Is the ODE stable at the initial conditions? From what the warning is saying, it’s probably not.
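A quick way to check is to run a plain forward solve at the initial parameters before any training and look at the return code. A minimal sketch (assuming your problem is prob, your parameters are θ, and you are using an explicit solver like Tsit5):

sol = solve(prob, Tsit5(), p = θ)
@show sol.retcode  # :Success (or :Terminated from a callback) means the forward solve at least gets through at θ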

It’s a good question. It was stable before I switched out the acceleration command from a ProNav-like acceleration input to the output of a neural network. It seems to solve with a randomly initialized neural network before I put it in the trainer. But then when DiffEq is doing its Zygote magic, perhaps the differential equation is no longer solvable. But there are so many ways to bake this cake, and I haven’t stumbled on a clear example in the tutorials that I can use. I’m not even sure that ADAM is what I should be using as the optimizer.

Sometimes you can hit some randomly bad parameters. You can decrease the chance of this by decreasing the learning rate.
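For example, a minimal sketch using your existing call, just lowering the ADAM step size (the value 0.001 here is only a guess to try):

res = DiffEqFlux.sciml_train(loss_adjoint, θ, ADAM(0.001), cb = cb_plot, maxiters = 100)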

Is there a way I can have the loss kick it out to a max value when that happens? Should I put in a simple Euler as an integrator?

if sol.retcode != :Success
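For example, a minimal sketch of guarding the loss that way (predict_adjoint, tgt_miss_distance, and the penalty value are placeholders for whatever you already have):

function loss_guarded(θ)
    sol = predict_adjoint(θ)  # assumed: your wrapper around solve with parameters θ
    if sol.retcode != :Success && sol.retcode != :Terminated
        return 1.0e6          # arbitrary large penalty so the optimizer can keep going
    end
    return tgt_miss_distance(Array(sol)[:, end])  # assumed: your miss-distance metric
end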

Thanks, I’ll try that. I did monkey with my termination criteria. Perhaps they are too strict and are causing it to fail. I’m trying to simplify my problem right now to see if I can get something working.

Should I use a different termination condition for the solver?
cb_easy=CallbackSet(ContinuousCallback(groundhit_condition_msl,terminate!),ContinuousCallback(cpa_condition,terminate!))

I am also wondering if having a continuous neural network in the loop causes a problem. Should it perhaps be a discrete controller to stabilize the solver, or does that cause more problems?

module Example_MLGuide

#using DifferentialEquations
using Plots
using LinearAlgebra
using DiffEqFlux
using Flux
using OrdinaryDiffEq
using DiffEqSensitivity

# Starting reference for setting this up to integrate a neural network as the guidance law
#   https://julialang.org/blog/2019/01/fluxdiffeq/
# Using this example to help set up a neural network
#   https://github.com/SciML/DiffEqFlux.jl/blob/master/docs/src/examples/optimal_control.md
# Emulating a PID controller with Long Short-Term Memory
#   https://towardsdatascience.com/emulating-a-pid-controller-with-long-short-term-memory-part-2-4a37d32e5b47
# Video on differentiable programming
#   https://www.youtube.com/watch?v=LjWzgTPFu14&t=42s
# Optimal control example
#   https://diffeqflux.sciml.ai/dev/examples/optimal_control/

# Had problems with LinearAlgebra package versions interfering with training; wrote these helpers by hand to rule that out

my_norm(v) = sqrt(v[1]^2 + v[2]^2 + v[3]^2)
my_normalize(v) = (mag = my_norm(v); mag > 0 ? v ./ mag : v)

saveDataAt = 0.1
mslSpeed = 500.0
aimPosition = [0.0, 0.0, 0.0]
aimVelocity = [0.0, 0.0, 0.0]
startPosition = [550.0, 0.0, 5000.0]
startVelocity = my_normalize([-338.0, 0.0, -213.0]) * mslSpeed

function calc_los_info(Pmsl, Paim, Vmsl, Vaim)
    R̄ = Paim - Pmsl               # Relative position (line-of-sight vector)
    V̄ᵣ = Vaim - Vmsl              # Relative velocity
    r = my_norm(R̄)                # Range
    ēlos = my_normalize(R̄)        # Unit vector along the line of sight
    Vc = dot(ēlos, V̄ᵣ)            # Velocity along the LOS vector
    Ω = cross(R̄, V̄ᵣ) / dot(R̄, R̄)  # Rotational vector of the line of sight
    V̄ₘ = Vmsl                     # Missile velocity vector for ProNav
    return (; R̄, V̄ᵣ, r, ēlos, Vc, Ω, V̄ₘ)
end

# Function to calculate the NNET inputs
function ann_input(LOSinfo)
    eVm = my_normalize(LOSinfo.V̄ₘ)
    Vm_mag = my_norm(eVm)
    Vr_aim = LOSinfo.V̄ᵣ
    R_aim = LOSinfo.R̄
    Vr_aim_mag = my_norm(Vr_aim)

    # Trying a kitchen-sink type approach for example setup purposes
    [eVm[1], eVm[2], eVm[3],
     LOSinfo.r,
     Vm_mag,
     LOSinfo.Ω[1], LOSinfo.Ω[2], LOSinfo.Ω[3],
     Vr_aim_mag,
     R_aim[1], R_aim[2], R_aim[3]]
end

# Set up the neural network with the correct input size
dummy = calc_los_info([0.0,0.0,0.0], [0.0,0.0,0.0], [1.0,2.0,3.0], [1.0,2.0,3.0])
annInputLen = length(ann_input(dummy))

ann = FastChain(FastDense(annInputLen, 32, tanh),
                FastDense(32, 1, tanh))
θ = initial_params(ann)  # θ is the parameter vector for the neural network

function calc_accel(p, info)
    eVm = my_normalize(info.V̄ₘ)
    Vm_mag = my_norm(eVm)

    # Neural network response
    nnet_input = ann_input(info)
    accel_axial = ann(nnet_input, p)  # Perpendicular acceleration
    println("accel_axial = $accel_axial")
    accel = cross(eVm, [0.0,1.0,0.0]) * accel_axial[1]  # Only apply acceleration in the axial direction
    return accel
end

Gₙ = 6.0  # ProNav gain
maxAccel = 20 * 9.81

# Function for the differential equation solver
function simpleFly!(du, u, p, t)
    Vmsl = u[1:3]
    Pmsl = u[4:6]
    Vaim = u[7:9]
    Paim = u[10:12]
    nnetθ = p

    info = calc_los_info(Pmsl, Paim, Vmsl, Vaim)

    accel_unlimited = calc_accel(nnetθ, info) * maxAccel  # Scale the network output to help it reach useful accelerations
    #accel = limit_accel(accel_unlimited, maxAccel)
    accel = accel_unlimited

    du[1:3] = accel            # Acceleration of the missile (derivative of velocity)
    du[4:6] = Vmsl             # Velocity of the missile (derivative of position)
    du[7:9] = [0.0, 0.0, 0.0]  # Acceleration of the aim point
    du[10:12] = Vaim           # Velocity of the aim point
end

# Initial conditions
u0 = [startVelocity..., startPosition..., aimVelocity..., aimPosition...]
tspan = (0.0, 120.0)

groundhit_condition_msl = function (u, t, integrator)
    # Stay positive to keep going (terminate when the missile altitude u[6] crosses zero)
    u[6]
end

cpa_condition = function (u, t, integrator)
    los = u[10:12] - u[4:6]
    eVmsl = my_normalize(u[1:3])
    elos = my_normalize(los)
    dist = my_norm(los)
    d = dot(elos, eVmsl)
    cone = acosd(abs(d) < 1 ? d : sign(d))  # acos protection
    120.0 - cone
end

tgt_miss_distance = function (u)
    miss = my_norm(u[10:12] - u[4:6])
end

#cb_full=CallbackSet(ContinuousCallback(cpa_condition,terminate!),ContinuousCallback(groundhit_condition_msl,terminate!),ContinuousCallback(lethal_miss,terminate!))
cb_easy = CallbackSet(ContinuousCallback(groundhit_condition_msl, terminate!),
                      ContinuousCallback(cpa_condition, terminate!))

# Using this example to help me: https://github.com/SciML/DiffEqFlux.jl/blob/master/docs/src/examples/optimal_control.md
prob = ODEProblem(simpleFly!, u0, tspan, θ, callback = cb_easy, saveat = saveDataAt)
sol = solve(prob, Tsit5(), abstol = 1e-10, reltol = 1e-10)
@info "Done with first solve"

function plot_result(data)
    p = plot(data[4,:], data[6,:], fontsize = 10, label = ("missile"))
    scatter!(p, data[10,:], data[12,:], label = ("aim point"))
    return p
end

function predict_adjoint(θ)
    #s=solve(prob,Vern9(),p=p,sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)))
    s = solve(prob, Tsit5(), p = θ, sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))
end

const maxMiss = 999999

function loss_adjoint(θ)
    s = predict_adjoint(θ)
    println("retcode = $(s.retcode)")
    if s.retcode == :Terminated
        x = Array(s)
        if size(x, 2) > 1
            miss = tgt_miss_distance(x[:, end])
        else
            miss = maxMiss
        end
    else
        miss = maxMiss
    end
    miss = miss < maxMiss ? miss : maxMiss
    println("miss = $miss")
    return miss
end

@info "Calculating initial loss"
l = loss_adjoint(θ)

cb_plot = function (θ, l)
    println(l)
    s = predict_adjoint(θ)
    if s.retcode == :Success
        data = Array(s)
        p = plot_result(data)
        display(p)
    else
        println("Solution not successful: $(s.retcode)")
    end
    return false
end

# Display the ODE solution with the current parameter values
cb_plot(θ, l)
loss1 = loss_adjoint(θ)

@info "Starting training."
res = DiffEqFlux.sciml_train(loss_adjoint, θ, ADAM(0.0001), cb = cb_plot, maxiters = 100)

end

@ChrisRackauckas – I’m sure the problem is because I am a noob at the tools, but just in case there is a sinister bug in there, here is the code that blows chunks for me.

InterpolatingAdjoint is not compatible with callbacks right now. If you use:

s=solve(prob,Tsit5(),p=θ,sensealg=ReverseDiffAdjoint())

it’s fixed. I’m surprised InterpolatingAdjoint isn’t throwing an error sooner about the callback, but :man_shrugging: that should all be fixed soon by https://github.com/SciML/DiffEqSensitivity.jl/pull/350
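In context, a sketch of how that slots into the predict_adjoint from your module (same call as before, only the sensealg swapped):

function predict_adjoint(θ)
    solve(prob, Tsit5(), p = θ, sensealg = ReverseDiffAdjoint())
end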

Thank you so much. Since I’m rather blindly copying from examples and not really understanding all of the inner workings, this was a HUGE help. I’m not really sure what the difference is between InterpolatingAdjoint and all of the other ones. Hopefully I’ll wade a path through the documentation and tutorials that will help me understand all of that soon.

I do still have an error though.
ERROR: type DiffEqArray has no field retcode

I think I stumbled through it. The solution doesn’t seem to be of the same type all of the time. Not sure I understand why, but I’m moving forward again.

Oh yes, thanks. That’s a minor issue fixed in “make reversediffadjoint return a sensitivity solution” (SciML/SciMLSensitivity.jl PR #356). I noticed it and forgot to mention that. You can just not look at the retcodes for now. I’ll merge that as a patch in a few hours.

So I think my new problem is along those same lines, since the solve function returns an ODESolution type when not run in the train function, but returns a DiffEqArray type when run from the training function. I’m trying to grab the end time from the solution as well as the states at the end. From a brief look at your issue, I think that if I wait for your update it will fix that as well.

That PR got screwy so I fixed it up (https://github.com/SciML/DiffEqSensitivity.jl/pull/358) and this should merge.

I don’t know if this matters, but increment_deriv! fails if I construct my loss function like this:

    #if isa(s,RecursiveArrayTools.DiffEqArray) # Bug fix
        x=s.u[end]
        #@show x
        t=s.t[end]
        #@show t
    # else
    #     x = s[:,end]
    #     #@show x
    #     t=s.t[end]
    #     #@show t
    # end

I have to extract the data with the Array cast for the loss function to work:

        temp = Array(s)
        x=temp[:,end]
        #@show x
        t=s.t[end]

Is that code snippet enough for you to understand what I’m talking about?

Error looked like this:

ERROR: MethodError: no method matching increment_deriv!(::Array{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},2}, ::Base.RefValue{Any}, ::Int64)

Even more bizarre is that temp = Array(s) works fine contributing to the loss function, until I use t in the loss function.
Then it gives me this error:

ERROR: MethodError: no method matching +(::Base.RefValue{Any}, ::Array{Float64,2})

Is it me?

I’m not quite sure I follow your construction. Can you give me a code I can copy/paste?