I’d like to program a differential equation that runs until it reaches a particular state (e.g. one of the space variables is zero), and then train some of the parameters that define the equation so that the system reaches the terminal state in minimum time.
My intuition is that I could use `terminate!` in a callback function to terminate the differential equation when it reaches the desired state, and then define a loss as `solution.t[end]` (where `solution` is the solution to the differential equation). But when I try to do that, the gradient comes back as zero, even though I can see with finite differencing that the gradient is not zero. And of course I can’t train parameters when the gradient is zero.
How can I train parameters based on `f(T)` for some function `f`, where `T` is the time required for the system defined by a differential equation to reach some state?
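To state the problem a bit more precisely, here is a sketch in my own notation (the symbols $F$, $g$, and $u(t;p)$ are assumptions for illustration, not anything from a library). Let $u(t;p)$ solve $\dot u = F(u,p,t)$ and let $g(u)$ be the terminal condition, so that $T(p)$ is defined implicitly by $g(u(T(p);p)) = 0$. Differentiating that identity with respect to $p$, and assuming the trajectory crosses the terminal set transversally ($\nabla g^\top F \neq 0$), would give:

$$
\nabla g^\top \left( F\big(u(T),p,T\big)\,\frac{dT}{dp} + \frac{\partial u}{\partial p}(T;p) \right) = 0
\quad\Longrightarrow\quad
\frac{dT}{dp} = -\,\frac{\nabla g^\top \,\dfrac{\partial u}{\partial p}(T;p)}{\nabla g^\top F\big(u(T),p,T\big)},
\qquad
\frac{d\,f(T)}{dp} = f'(T)\,\frac{dT}{dp}.
$$

So the quantity I am after is generally nonzero whenever the state at the terminal time depends on the parameters, which is consistent with what finite differencing shows.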
Here’s the code for a toy example I’m working with:
```julia
using DifferentialEquations, ForwardDiff
const ∇ = ForwardDiff.gradient

# Stop the integration as soon as the condition crosses zero
function terminate_affect!(integrator)
    terminate!(integrator)
end

function terminate_condition(u, t, integrator)
    min(u[1], u[2])
end

terminate_cb = ContinuousCallback(terminate_condition, terminate_affect!)

function damped!(du, u, p, t) # Damped harmonic oscillator
    du[1] = u[2]
    du[2] = -p[1] * u[1] - p[2] * u[2]
end

u0 = [5.0, 10.0]    # starting point
tspan = (0.0, 10.0) # time span

# Loss: the time at which the callback terminated the integration
function loss(p)
    temp_prob = ODEProblem(damped!, u0, tspan, p)
    temp_solution = solve(temp_prob, Tsit5(), p = p, callback = terminate_cb)
    return temp_solution.t[end]
end

∇(loss, [2.0, 5.0])
```
This returns `[0.00,0.00]`, whereas a simple finite differencing method returns a nonzero gradient:
```julia
h = 1e-8 # finite-difference step
# Forward differences: (loss(p + h) - loss(p)) / h
du1 = (loss([2.0 + h, 5.0]) - loss([2.0, 5.0])) / h
du2 = (loss([2.0, 5.0 + h]) - loss([2.0, 5.0])) / h
[du1, du2]
```
(When I try to use `Zygote.gradient`, I get an error message:

```
Zygote.gradient(loss,[2.,5.])
ERROR: BoundsError: attempt to access () at index 
```

with a super long stack trace. But I suppose that’s a separate topic.)
So again, how can I use an automatic differentiation system to calculate the derivative of a function of the termination time of a terminal callback, with respect to parameters that affect that termination time?
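For reference, here is a minimal sketch of one workaround I can imagine, not necessarily the recommended approach: find the termination time with ordinary `Float64` parameters, then differentiate the implicit relation "condition at the termination time equals zero" by hand, using ForwardDiff only for the state sensitivities at that fixed time. The helper names `g`, `event_time`, `state_at`, and `event_time_gradient` are hypothetical, and the sketch assumes the event actually fires inside the time span and that the condition crosses zero with nonzero slope.

```julia
using DifferentialEquations, ForwardDiff, LinearAlgebra

# Damped harmonic oscillator, as in the toy example.
function damped!(du, u, p, t)
    du[1] = u[2]
    du[2] = -p[1] * u[1] - p[2] * u[2]
end

# Terminal condition g(u); the callback fires when g crosses zero.
g(u) = min(u[1], u[2])

cb = ContinuousCallback((u, t, integrator) -> g(u), terminate!)

u0 = [5.0, 10.0]
tspan = (0.0, 10.0)

# Termination time T(p), computed with plain Float64 parameters.
function event_time(p)
    prob = ODEProblem(damped!, u0, tspan, p)
    sol = solve(prob, Tsit5(); callback = cb, abstol = 1e-10, reltol = 1e-10)
    return sol.t[end]
end

# State at a fixed time T as a function of p (no callback involved),
# so ForwardDiff can differentiate it in the usual way.
function state_at(p, T)
    prob = ODEProblem(damped!, eltype(p).(u0), (0.0, T), p)
    sol = solve(prob, Tsit5(); save_everystep = false,
                abstol = 1e-10, reltol = 1e-10)
    return sol.u[end]
end

# dT/dp = -(∇g ⋅ ∂u/∂p) / (∇g ⋅ du/dt), everything evaluated at t = T.
function event_time_gradient(p)
    T = event_time(p)
    uT = state_at(p, T)
    dudp = ForwardDiff.jacobian(q -> state_at(q, T), p) # ∂u/∂p at fixed T
    dgdu = ForwardDiff.gradient(g, uT)                  # ∇g at u(T)
    duT = similar(uT)
    damped!(duT, uT, p, T)                              # du/dt at the event
    return -(dudp' * dgdu) / dot(dgdu, duT)
end

event_time_gradient([2.0, 5.0])
```

If the event never fires, `event_time` just returns the end of `tspan` and the formula above no longer means anything, so this is only meant to illustrate the kind of derivative I would like the AD system to produce directly.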