I’d like to program a differential equation that runs until it reaches a particular state (e.g. one of the state variables is zero), and then train some of the parameters that define the equation so that the system reaches the terminal state in minimum time.
My intuition is that I could use `terminate!` in a callback function to stop the integration when the solution reaches the desired state, and then define a loss as `solution.t[end]` (where `solution` is the solution to the differential equation). But when I try to do that, the gradient comes back as zero, even though I can see with finite differencing that the gradient is not zero. And of course I can’t train parameters when the gradient is zero.
How can I train parameters based on `T` (or `f(T)` for some function `f`), where `T` is the time required for the system defined by a differential equation to reach some state?
Here’s the code for a toy example I’m working with:

```julia
using DifferentialEquations, ForwardDiff
const ∇ = ForwardDiff.gradient

# Stop the integration when the condition crosses zero
function terminate_affect!(integrator)
    terminate!(integrator)
end

# Root-find on whichever state variable hits zero first
function terminate_condition(u, t, integrator)
    min(u[1], u[2])
end

terminate_cb = ContinuousCallback(terminate_condition, terminate_affect!)

function damped!(du, u, p, t)
    # Damped harmonic oscillator
    du[1] = u[2]
    du[2] = -p[1]*u[2] - p[2]*u[1]
end

u0 = [5., 10.]     # starting point
tspan = (0., 10.)  # time span

function loss(p)
    temp_prob = ODEProblem(damped!, u0, tspan, p)
    temp_solution = solve(temp_prob, Tsit5(), p = p, callback = terminate_cb)
    return temp_solution.t[end]
end

∇(loss, [2., 5.])
```
This returns `[0.0, 0.0]`, whereas a simple finite-differencing method returns a nonzero gradient:
```julia
du1 = (loss([2., 5.]) - loss([2.00000001, 5.])) / 0.00000001
du2 = (loss([2., 5.]) - loss([2., 5.00000001])) / 0.00000001
[du1, du2]
```

which returns `[0.0252..., 0.0363...]`.
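As an aside, the difference quotient as written, `(loss(p) - loss(p + h)) / h`, actually approximates the *negative* of the derivative. A central difference avoids the sign confusion and is more accurate; here is a minimal sketch (`central_grad` is just a helper I wrote for this post, not from any package):

```julia
# Central-difference gradient check (hypothetical helper, not from a package).
# Approximates ∂f/∂pᵢ as (f(p + h*eᵢ) - f(p - h*eᵢ)) / 2h.
function central_grad(f, p; h = 1e-6)
    g = similar(p)
    for i in eachindex(p)
        pp = copy(p); pp[i] += h
        pm = copy(p); pm[i] -= h
        g[i] = (f(pp) - f(pm)) / (2h)
    end
    return g
end

# Sanity check on a function with a known gradient: ∇(p₁² + 3p₂) = [2p₁, 3]
central_grad(p -> p[1]^2 + 3p[2], [2., 5.])  # ≈ [4.0, 3.0]
```

Applied to the `loss` above as `central_grad(loss, [2., 5.])`, this agrees (up to the sign convention) with the forward-difference numbers.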
(When I try to use `Zygote.gradient`, I get an error:

```julia
Zygote.gradient(loss, [2., 5.])
ERROR: BoundsError: attempt to access ()
  at index [0]
```

with a super long stack trace. But I suppose that’s a separate topic.)
So again: how can I use an automatic differentiation system to calculate the derivative of a function of the time at which a terminal callback fires, with respect to parameters that affect that time?
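For context on why I expect a nonzero gradient (this is my own derivation, so I may have it wrong): writing the event condition as `g(u(T; p)) = 0` with `g(u) = min(u[1], u[2])`, the implicit function theorem gives a well-defined sensitivity of the stopping time,

```
g(u(T; p)) = 0
⇒ dT/dp = - (∂g/∂u ⋅ ∂u/∂p) / (∂g/∂u ⋅ du/dt),   evaluated at t = T
```

so it looks to me like any AD path that treats `T` as a constant, instead of differentiating through the callback's root-find, would return exactly the zero gradient I'm seeing.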