Derivative with terminate callback

I’d like to program a differential equation that runs until it reaches a particular state (e.g. one of the state variables is zero), and then train some parameters that define the equation so the system reaches the terminal state in minimum time.

My intuition is that I could use `terminate!` in a callback function to stop the differential equation when it reaches the desired state, and then define the loss as `solution.t[end]` (where `solution` is the solution of the differential equation). But when I try that, the gradient comes back as zero, even though finite differencing shows that it is not zero. And of course I can’t train parameters when the gradient is zero.

How can I train parameters based on T (or f(T) for some function f), where T is the time required for the system defined by a differential equation to reach some state?
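For reference, the derivative being asked for follows from the implicit function theorem. Write the system as $u' = f(u, p, t)$, let $g(u)$ be the event function (here $g(u) = \min(u_1, u_2)$), and let $T(p)$ be the event time, defined by $g\big(u(T(p); p)\big) = 0$. Differentiating that identity with respect to $p$ gives

$$
\frac{\partial g}{\partial u}\left( f\big(u(T), p, T\big)\,\frac{dT}{dp} + \frac{\partial u}{\partial p}\bigg|_{t=T} \right) = 0
\quad\Longrightarrow\quad
\frac{dT}{dp} = -\,\frac{\dfrac{\partial g}{\partial u}\,\dfrac{\partial u}{\partial p}\Big|_{t=T}}{\dfrac{\partial g}{\partial u}\,f\big(u(T), p, T\big)}
$$

where $\partial u / \partial p$ is the forward sensitivity of the trajectory. This assumes the trajectory crosses $g = 0$ transversally (nonzero denominator) and that only one component of the min is zero at $T$, so $g$ is differentiable there. For a loss $\phi(T)$ (the $f(T)$ in the question), the chain rule gives $d\phi/dp = \phi'(T)\,dT/dp$. An AD tool can only return this if derivatives flow through the located event time; if that time is treated as a constant, the gradient comes out as exactly zero.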

Here’s the code for a toy example I’m working with:

```julia
using DifferentialEquations, ForwardDiff
const ∇ = ForwardDiff.gradient

# Stop the integration as soon as the event fires
function terminate_affect!(integrator)
    terminate!(integrator)
end

# The event fires when this crosses zero, i.e. when u[1] or u[2] reaches 0
function terminate_condition(u, t, integrator)
    min(u[1], u[2])
end

terminate_cb = ContinuousCallback(terminate_condition, terminate_affect!)

function damped!(du, u, p, t)
    # Damped harmonic oscillator: p[1] = damping, p[2] = stiffness
    du[1] = u[2]
    du[2] = -p[1]*u[2] - p[2]*u[1]
end

u0 = [5., 10.]    # starting point
tspan = (0., 10.) # time span

# Loss = time at which the terminating event fires
function loss(p)
    temp_prob = ODEProblem(damped!, u0, tspan, p)
    temp_solution = solve(temp_prob, Tsit5(), p = p, callback = terminate_cb)
    return temp_solution.t[end]
end

∇(loss, [2., 5.])
```
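To see what a correct gradient looks like for this toy problem, here is a package-free Julia sketch of the forward-sensitivity approach. It is an illustration under several assumptions, not how DiffEqBase computes it: a fixed-step RK4 with `dt = 1e-3` stands in for `Tsit5()`, bisection on the last step stands in for the `ContinuousCallback` rootfinding, and the sensitivities $S = \partial u/\partial p$ are integrated alongside the state via $S' = (\partial f/\partial u)\,S + \partial f/\partial p$. At the event, $dT/dp = -(\partial u_k/\partial p)/(du_k/dt)$, where $u_k$ is whichever component of `min(u[1], u[2])` reached zero. All function names here are hypothetical.

```julia
# Hypothetical, package-free stand-in for Tsit5 + ContinuousCallback.
# State x = [u1, u2, vec(S)...], where S = ∂u/∂p is a 2×2 sensitivity matrix.
function augmented_rhs(x, p)
    u1, u2 = x[1], x[2]
    S = reshape(x[3:6], 2, 2)
    J = [0.0 1.0; (-p[2]) (-p[1])]   # ∂f/∂u of the damped oscillator
    fp = [0.0 0.0; (-u2) (-u1)]      # ∂f/∂p (columns: p[1], p[2])
    du = [u2, -p[1]*u2 - p[2]*u1]
    dS = J * S + fp                  # forward sensitivity equation S' = J S + ∂f/∂p
    return vcat(du, vec(dS))
end

# One classical RK4 step of size h on the augmented system
function rk4_step(x, p, h)
    k1 = augmented_rhs(x, p)
    k2 = augmented_rhs(x .+ (h / 2) .* k1, p)
    k3 = augmented_rhs(x .+ (h / 2) .* k2, p)
    k4 = augmented_rhs(x .+ h .* k3, p)
    return x .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

event_value(x) = min(x[1], x[2])     # same condition as terminate_condition

# Integrate until the event; return (T, augmented state at T)
function terminal_time(p; u0 = [5.0, 10.0], tmax = 10.0, dt = 1e-3)
    x = vcat(u0, zeros(4))           # S(0) = 0 because u0 does not depend on p
    t = 0.0
    while t < tmax
        xnew = rk4_step(x, p, dt)
        if event_value(xnew) <= 0
            # Bisect on the fraction θ of the last step where the event crosses zero
            lo, hi = 0.0, 1.0
            for _ in 1:60
                mid = (lo + hi) / 2
                if event_value(rk4_step(x, p, mid * dt)) > 0
                    lo = mid
                else
                    hi = mid
                end
            end
            θ = (lo + hi) / 2
            return t + θ * dt, rk4_step(x, p, θ * dt)
        end
        x, t = xnew, t + dt
    end
    error("no terminal event before tmax = $tmax")
end

# Event time and its gradient via the implicit function theorem
function terminal_time_gradient(p; kwargs...)
    T, x = terminal_time(p; kwargs...)
    u = x[1:2]
    S = reshape(x[3:6], 2, 2)
    k = u[1] <= u[2] ? 1 : 2                 # component of min(u[1], u[2]) that hit zero
    f = [u[2], -p[1]*u[2] - p[2]*u[1]]       # du/dt at the event
    return T, -S[k, :] ./ f[k]               # dT/dp = -(∂u_k/∂p) / (du_k/dt)
end

p = [2.0, 5.0]
T, dTdp = terminal_time_gradient(p)
h = 1e-6
fd = [(terminal_time([p[1] + h, p[2]])[1] - terminal_time([p[1] - h, p[2]])[1]) / (2h),
      (terminal_time([p[1], p[2] + h])[1] - terminal_time([p[1], p[2] - h])[1]) / (2h)]
println("T = ", T, "  dT/dp = ", dTdp, "  central differences = ", fd)
```

Both components should come out negative (raising the damping `p[1]` or the stiffness `p[2]` makes `u[2]` reach zero sooner) and should agree with central differences of the same integrator. Because the step is fixed rather than adaptive, the values need not match a `Tsit5()` run digit for digit.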

Which returns `[0.00,0.00]`, whereas a simple finite-difference approximation gives a nonzero gradient:

```julia
h = 1e-8
du1 = (loss([2. + h, 5.]) - loss([2., 5.])) / h   # forward difference in p[1]
du2 = (loss([2., 5. + h]) - loss([2., 5.])) / h   # forward difference in p[2]
[du1, du2]
```

returns `[-0.0252..., -0.0363...]`.
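As an aside on the check itself: a central difference is usually more accurate than a one-sided one for the same step size, and looping over the parameters avoids writing one line per component. A minimal sketch; `fd_gradient` is a hypothetical helper, and scaling the step as `h * max(1, |p_i|)` is one common heuristic rather than the only choice.

```julia
# Hypothetical helper: central-difference gradient of a scalar function f at p.
function fd_gradient(f, p::AbstractVector{<:Real}; h = 1e-6)
    x = collect(Float64, p)                  # 1-based Float64 copy of p
    g = zeros(length(x))
    for i in eachindex(x)
        δ = h * max(1.0, abs(x[i]))          # scale the step to the parameter's size
        xp = copy(x); xp[i] += δ             # p + δ eᵢ
        xm = copy(x); xm[i] -= δ             # p - δ eᵢ
        # divide by the actually representable step to limit rounding error
        g[i] = (f(xp) - f(xm)) / (xp[i] - xm[i])
    end
    return g
end
```

With the `loss` from the toy example this would be called as `fd_gradient(loss, [2.0, 5.0])`. With an adaptive solver such as `Tsit5()`, very small steps can pick up noise from step-size control, so it is worth trying a couple of values of `h`.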

(When I try to use `Zygote.gradient` I get an error message:

```
Zygote.gradient(loss,[2.,5.])
ERROR: BoundsError: attempt to access ()
  at index [0]
```

with a super long stack trace. But I suppose that’s a separate topic.)

So again, how can I use an auto differentiation system to calculate the derivative of a function of the time of a terminal callback, with respect to parameters that affect the time of the termination?

Sorry, that was a small bug. Fixed by https://github.com/SciML/DiffEqBase.jl/pull/584 and it should be released in about 2 hours as DiffEqBase 6.47.1

I’m newish to Julia, so I’m probably missing something, but I can’t seem to find DiffEqBase 6.47.1:

```
(@v1.5) pkg> add DiffEqBase@6.47.1
   Updating registry at `C:\Users\~\.julia\registries\General`
   Updating git-repo `https://github.com/JuliaRegistries/General.git`
  Resolving package versions...
ERROR: Unsatisfiable requirements detected for package DiffEqBase [2b5f629d]:
 DiffEqBase [2b5f629d] log:
 ├─possible versions are: [3.13.2-3.13.3, 4.0.0-4.0.1, 4.1.0, 4.2.0, 4.3.0-4.3.1, 4.4.0, 4.5.0, 4.6.0, 4.7.0, 4.8.0, 4.9.0, 4.10.0-4.10.1, 4.11.0-4.11.1, 4.12.0, 4.13.0, 4.14.0-4.14.1, 4.15.0, 4.16.0, 4.17.0, 4.18.0, 4.19.0, 4.20.0-4.20.3, 4.21.0, 4.21.2-4.21.3, 4.22.0-4.22.2, 4.23.0, 4.23.2-4.23.4, 4.24.0-4.24.3, 4.25.0-4.25.1, 4.26.0-4.26.3, 4.27.0-4.27.1, 4.28.0-4.28.1, 4.29.0-4.29.2, 4.30.0-4.30.2, 4.31.0-4.31.2, 4.32.0, 5.0.0-5.0.1, 5.1.0, 5.2.0-5.2.3, 5.3.0-5.3.2, 5.4.0-5.4.1, 5.5.0-5.5.2, 5.6.0-5.6.4, 5.7.0, 5.8.0-5.8.1, 5.9.0, 5.10.0-5.10.3, 5.11.0-5.11.1, 5.12.0, 5.13.0, 5.14.0-5.14.2, 5.15.0, 5.16.0-5.16.5, 5.17.0-5.17.1, 5.18.0, 5.19.0, 5.20.0-5.20.1, 6.0.0, 6.1.0, 6.2.0-6.2.4, 6.3.0-6.3.6, 6.4.0-6.4.2, 6.5.0-6.5.1, 6.6.0, 6.7.0, 6.8.0, 6.9.0-6.9.4, 6.10.0-6.10.2, 6.11.0, 6.12.0-6.12.5, 6.13.0-6.13.3, 6.14.0-6.14.2, 6.15.0-6.15.2, 6.16.0, 6.17.0-6.17.3, 6.18.0-6.18.1, 6.19.0, 6.20.0, 6.21.0-6.21.1, 6.22.0-6.22.2, 6.23.0, 6.24.0, 6.25.0-6.25.2, 6.26.0, 6.27.0, 6.28.0, 6.29.0-6.29.3, 6.30.0-6.30.4, 6.31.0-6.31.1, 6.32.0-6.32.2, 6.33.0-6.33.1, 6.34.0-6.34.3, 6.35.0-6.35.2, 6.36.0-6.36.4, 6.37.0, 6.38.0-6.38.4, 6.39.0-6.39.1, 6.40.0-6.40.9, 6.41.0-6.41.3, 6.42.0, 6.43.0-6.43.1, 6.44.0-6.44.3, 6.45.0-6.45.1, 6.46.0-6.46.1, 6.47.0] or uninstalled
 └─restricted to versions 6.47.1 by an explicit requirement — no versions left
```

Likewise, a Google search for 6.47.1 just returns this post.

You need to run `]up` first to update the registry and get the newest versions.
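If you’d rather check from a script whether the installed DiffEqBase is new enough, something like the following should work on Julia 1.4 or later, where `Pkg.dependencies()` is available. `installed_version` is a hypothetical helper name, and the `v"6.47.1"` threshold is simply the release number given in the reply above.

```julia
using Pkg

# Hypothetical helper: installed version of package `name` in the active
# environment's manifest, or `nothing` if it isn't there.
function installed_version(name::AbstractString)
    for info in values(Pkg.dependencies())
        info.name == name && return info.version
    end
    return nothing
end

ver = installed_version("DiffEqBase")
if ver === nothing
    println("DiffEqBase is not installed in the active environment")
elseif ver < v"6.47.1"
    println("DiffEqBase $ver is older than 6.47.1; run `]up` again once 6.47.1 is registered")
else
    println("DiffEqBase $ver is 6.47.1 or newer")
end
```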

That didn’t work.

```
(@v1.5) pkg> up
   Updating registry at `C:\Users\~\.julia\registries\General`
   Updating git-repo `https://github.com/JuliaRegistries/General.git`
  Installed EarCut_jll ───────── v2.1.5+0
  Installed OrderedCollections ─ v1.3.1
  Installed GeometryBasics ───── v0.3.1
  Installed Plots ────────────── v1.6.4
Updating `C:\Users\~\.julia\environments\v1.5\Project.toml`
  [91a5bcdd] ↑ Plots v1.6.3 ⇒ v1.6.4
Updating `C:\Users\~\.julia\environments\v1.5\Manifest.toml`
  [5ae413db] + EarCut_jll v2.1.5+0
  [5c1252a2] ↑ GeometryBasics v0.2.15 ⇒ v0.3.1
  [bac558e1] ↑ OrderedCollections v1.3.0 ⇒ v1.3.1
  [91a5bcdd] ↑ Plots v1.6.3 ⇒ v1.6.4
   Building Plots → `C:\Users\~\.julia\packages\Plots\4EfKl\deps\build.log`

(@v1.5) pkg> update DiffEqBase
   Updating registry at `C:\Users\~\.julia\registries\General`
   Updating git-repo `https://github.com/JuliaRegistries/General.git`
No Changes to `C:\Users\~\.julia\environments\v1.5\Project.toml`
No Changes to `C:\Users\~\.julia\environments\v1.5\Manifest.toml`

(@v1.5) pkg> status DiffEqBase
Status `C:\Users\~\.julia\environments\v1.5\Project.toml`
  [2b5f629d] DiffEqBase v6.47.0

(@v1.5) pkg> add DiffEqBase@6.47.1
  Resolving package versions...
ERROR: Unsatisfiable requirements detected for package DiffEqBase [2b5f629d]:
 DiffEqBase [2b5f629d] log:
 ├─possible versions are: [3.13.2-3.13.3, 4.0.0-4.0.1, 4.1.0, 4.2.0, 4.3.0-4.3.1, 4.4.0, 4.5.0, 4.6.0, 4.7.0, 4.8.0, 4.9.0, 4.10.0-4.10.1, 4.11.0-4.11.1, 4.12.0, 4.13.0, 4.14.0-4.14.1, 4.15.0, 4.16.0, 4.17.0, 4.18.0, 4.19.0, 4.20.0-4.20.3, 4.21.0, 4.21.2-4.21.3, 4.22.0-4.22.2, 4.23.0, 4.23.2-4.23.4, 4.24.0-4.24.3, 4.25.0-4.25.1, 4.26.0-4.26.3, 4.27.0-4.27.1, 4.28.0-4.28.1, 4.29.0-4.29.2, 4.30.0-4.30.2, 4.31.0-4.31.2, 4.32.0, 5.0.0-5.0.1, 5.1.0, 5.2.0-5.2.3, 5.3.0-5.3.2, 5.4.0-5.4.1, 5.5.0-5.5.2, 5.6.0-5.6.4, 5.7.0, 5.8.0-5.8.1, 5.9.0, 5.10.0-5.10.3, 5.11.0-5.11.1, 5.12.0, 5.13.0, 5.14.0-5.14.2, 5.15.0, 5.16.0-5.16.5, 5.17.0-5.17.1, 5.18.0, 5.19.0, 5.20.0-5.20.1, 6.0.0, 6.1.0, 6.2.0-6.2.4, 6.3.0-6.3.6, 6.4.0-6.4.2, 6.5.0-6.5.1, 6.6.0, 6.7.0, 6.8.0, 6.9.0-6.9.4, 6.10.0-6.10.2, 6.11.0, 6.12.0-6.12.5, 6.13.0-6.13.3, 6.14.0-6.14.2, 6.15.0-6.15.2, 6.16.0, 6.17.0-6.17.3, 6.18.0-6.18.1, 6.19.0, 6.20.0, 6.21.0-6.21.1, 6.22.0-6.22.2, 6.23.0, 6.24.0, 6.25.0-6.25.2, 6.26.0, 6.27.0, 6.28.0, 6.29.0-6.29.3, 6.30.0-6.30.4, 6.31.0-6.31.1, 6.32.0-6.32.2, 6.33.0-6.33.1, 6.34.0-6.34.3, 6.35.0-6.35.2, 6.36.0-6.36.4, 6.37.0, 6.38.0-6.38.4, 6.39.0-6.39.1, 6.40.0-6.40.9, 6.41.0-6.41.3, 6.42.0, 6.43.0-6.43.1, 6.44.0-6.44.3, 6.45.0-6.45.1, 6.46.0-6.46.1, 6.47.0] or uninstalled
 └─restricted to versions 6.47.1 by an explicit requirement — no versions left
```

Anyway, thank you very much for your answers and time. I really appreciate it!

v6.47.1 doesn’t exist yet. I’ll go make it.

We’re moving to a different town now, and I won’t be able to look at it till next week. (I’d hoped to do more now, but I can’t.) But thank you for your time!