Issue with Callbacks When Solving Differential Equations Having Dual Number States

When I solve the same ODE with and without a callback that stops the simulation early, I get different partials at the same time, and they are not even close. I believe that adding the callback introduces dual-valued times and that, at the event time, the partials of time become nonzero, which causes a jump in the partials of the states. As an example, below is a non-dimensional mass-spring system on the vertical axis. Compare s1 through s6 to see the jump. Is this expected, or is there a bug?

using OrdinaryDiffEq, ForwardDiff, LinearAlgebra
#                     val    d/dy0  d/dvy0 d/dgy
y  = ForwardDiff.Dual(0.6,   1.0,   0.0,   0.0  )
vy = ForwardDiff.Dual(0.0,   0.0,   1.0,   0.0  )
gy = ForwardDiff.Dual(0.005, 0.0,   0.0,   1.0  )
u0 = [y;vy]
tspan = (0.0, 2.0)
function st!(du, u, p, t)
    @inbounds begin
        du[1] = u[2]
        du[2] = (1-u[1]) - p
    end
    return nothing
end
cnd(u, t, integ) = u[1] - 1
CC = ContinuousCallback(cnd, terminate!, affect_neg! = nothing, abstol=0, reltol=0)

prb = ODEProblem(st!, u0, tspan, gy)
sl1 = solve(prb, Feagin14(), abstol=1e-18, reltol=1e-18)
sl2 = solve(prb, Feagin14(), abstol=1e-18, reltol=1e-18, callback=CC)

s1 = sl1(sl2.t[end].value)
s2 = sl2[end]

s3 = sl1(sl2.t[end-1].value)
s4 = sl2[end-1]

s5 = sl1(sl2.t[end-2].value)
s6 = sl2[end-2]

s7 = sl1(sl2.t[end-3].value)
s8 = sl2[end-3]

s9 = sl1(sl2.t[end-4].value)
s10= sl2[end-4]
julia> s1 = sl1(sl2.t[end].value)
2-element Array{ForwardDiff.Dual{Nothing,Float64,3},1}:
 Dual{Nothing}(0.9999999997450032,-0.012658227202540293,0.9999198739688338,-1.0126582272025406)
 Dual{Nothing}(0.3949683502176894,-0.9999198739688338,-0.012658227202540293,-0.9999198739688332)

julia> s2 = sl2[end]
2-element Array{ForwardDiff.Dual{Nothing,Float64,3},1}:
 Dual{Nothing}(1.0,-2.382985700754451e-16,3.982129662242418e-15,2.3278389821175142e-15)
 Dual{Nothing}(0.39496832262015696,-1.0000799869548667,-4.797279698585894e-6,-1.0127344323578877)

julia> s3 = sl1(sl2.t[end-1].value)
2-element Array{ForwardDiff.Dual{Nothing,Float64,3},1}:
 Dual{Nothing}(0.9999999997450032,-0.012658227202540293,0.9999198739688338,-1.0126582272025406)
 Dual{Nothing}(0.3949683502176894,-0.9999198739688338,-0.012658227202540293,-0.9999198739688332)

julia> s4 = sl2[end-1]
2-element Array{ForwardDiff.Dual{Nothing,Float64,3},1}:
 Dual{Nothing}(1.0,-2.382985700754451e-16,3.982129662242418e-15,2.3278389821175142e-15)
 Dual{Nothing}(0.39496832262015696,-1.0000799869548667,-4.797279698585894e-6,-1.0127344323578877)

julia> s5 = sl1(sl2.t[end-2].value)
2-element Array{ForwardDiff.Dual{Nothing,Float64,3},1}:
 Dual{Nothing}(0.9748648857268446,0.05097497284343077,0.9986998325015783,-0.9490250271565693)
 Dual{Nothing}(0.3944864338381234,-0.9986998325015783,0.05097497284343077,-0.9986998325015778)

julia> s6 = sl2[end-2]
2-element Array{ForwardDiff.Dual{Nothing,Float64,3},1}:
 Dual{Nothing}(0.9748648842437063,0.05097497659821093,0.998699930790432,-0.9490250234017891)
 Dual{Nothing}(0.3944864726622205,-0.998699930790432,0.05097497659821093,-0.998699930790432)

From the documentation:

Unless otherwise specified, the OrdinaryDiffEq algorithms all come with a 3rd order Hermite polynomial interpolation. The algorithms denoted as having a “free” interpolation means that no extra steps are required for the interpolation. For the non-free higher order interpolating functions, the extra steps are computed lazily (i.e. not during the solve).

Feagin14 - Feagin’s 14th-order Runge-Kutta method.

Thus, what you were using was a 14th order method, but it only has a 3rd order interpolation. That method is very accurate at the step points, but it is a poor choice if you want to use the continuous (interpolated) solution. That is one of the reasons why it’s not the recommended method, and instead for most situations Vern9 is recommended:

Vern9 - Verner’s “Most Efficient” 9/8 Runge-Kutta method. (lazy 9th order interpolant)

which is 9th order with a 9th order interpolation. Try this to see it in action:

using OrdinaryDiffEq, ForwardDiff, LinearAlgebra
#                     val    d/dy0  d/dvy0 d/dgy
y  = ForwardDiff.Dual(0.6,   1.0,   0.0,   0.0  )
vy = ForwardDiff.Dual(0.0,   0.0,   1.0,   0.0  )
gy = ForwardDiff.Dual(0.005, 0.0,   0.0,   1.0  )
u0 = [y;vy]
tspan = (0.0, 2.0)
function st!(du, u, p, t)
    @inbounds begin
        du[1] = u[2]
        du[2] = (1-u[1]) - p
    end
    return nothing
end
cnd(u, t, integ) = u[1] - 1
CC = ContinuousCallback(cnd, terminate!, nothing, abstol=0, reltol=0)

prb = ODEProblem(st!, u0, tspan, gy)
sl1 = solve(prb, Vern9(), abstol=1e-12, reltol=1e-12)
sl2 = solve(prb, Vern9(), abstol=1e-12, reltol=1e-12, callback=CC)

s1 = sl1(sl2.t[end].value)
s2 = sl2[end]

s3 = sl1(sl2.t[end-1].value)
s4 = sl2[end-1]

s5 = sl1(sl2.t[end-2].value)
s6 = sl2[end-2]

s7 = sl1(sl2.t[end-3].value)
s8 = sl2[end-3]

s9 = sl1(sl2.t[end-4].value)
s10= sl2[end-4]

and the values are close to exact. Hopefully that helps!
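
For context, a 3rd order Hermite interpolant is built only from the solution values and derivatives at the two ends of a step. Below is a sketch of the standard cubic Hermite form. The notation (u_n, f_n, h, θ) is introduced here for illustration and is not taken from the OrdinaryDiffEq source:

```latex
u(t_n + \theta h) \approx (1-\theta)\,u_n + \theta\,u_{n+1}
  + \theta(\theta-1)\bigl[(1-2\theta)(u_{n+1}-u_n)
  + (\theta-1)\,h f_n + \theta\,h f_{n+1}\bigr],
\qquad 0 \le \theta \le 1
```

Its error between step points scales like O(h^4) no matter how accurate u_n and u_{n+1} are, so a high order stepper that takes large steps can still give a comparatively poor continuous solution.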


The problem is in the solution with the callback (sl2). The interpolated values from sl1 agree with sl2 at t[end-2] through t[end-4], and at any earlier time step, but the jump happens in sl2 itself at the event time, where no interpolation is involved. sl2.t[end-1] (just before affect! is applied) has nonzero partials, and the jump in the partials of the state values occurs in sl2, not in the interpolated values from sl1. I also tried different solvers (Tsit5, DP8, and Vern9) with different tolerances, and all of them exhibit the jump in the state partials and the time partials for the solution with the callback. Here are the results from Vern9 as well:

julia> s1 = sl1(sl2.t[end].value)
2-element Array{ForwardDiff.Dual{Nothing,Float64,3},1}:
 Dual{Nothing}(0.999999999999996,-0.01265822784809148,0.9999198814243908,-1.0126582278480916)
 Dual{Nothing}(0.3949683531626346,-0.9999198814243908,-0.01265822784809148,-0.9999198814243908)

julia> s2 = sl2[end]
2-element Array{ForwardDiff.Dual{Nothing,Float64,3},1}:
 Dual{Nothing}(1.0000000000000033,-4.04561141888263e-15,-1.8138508223614051e-13,3.6474487916570596e-13)
 Dual{Nothing}(0.39496835316263174,-1.0000801249951228,-1.579705560015433e-13,-1.0127393670834959)

julia> s3 = sl1(sl2.t[end-1].value)
2-element Array{ForwardDiff.Dual{Nothing,Float64,3},1}:
 Dual{Nothing}(0.999999999999996,-0.01265822784809148,0.9999198814243908,-1.0126582278480916)
 Dual{Nothing}(0.3949683531626346,-0.9999198814243908,-0.01265822784809148,-0.9999198814243908)

julia> s4 = sl2[end-1]
2-element Array{ForwardDiff.Dual{Nothing,Float64,3},1}:
 Dual{Nothing}(1.0000000000000033,-4.04561141888263e-15,-1.8138508223614051e-13,3.6474487916570596e-13)
 Dual{Nothing}(0.39496835316263174,-1.0000801249951228,-1.579705560015433e-13,-1.0127393670834959)

julia> s5 = sl1(sl2.t[end-2].value)
2-element Array{ForwardDiff.Dual{Nothing,Float64,3},1}:
 Dual{Nothing}(0.9369820218306023,0.14688095739087992,0.9891541762313583,-0.8531190426091201)
 Dual{Nothing}(0.39071589961138675,-0.9891541762313583,0.14688095739087992,-0.9891541762313583)

julia> s6 = sl2[end-2]
2-element Array{ForwardDiff.Dual{Nothing,Float64,3},1}:
 Dual{Nothing}(0.9369820218306022,0.14688095739088022,0.9891541762313582,-0.8531190426091197)
 Dual{Nothing}(0.3907158996113866,-0.9891541762313582,0.14688095739088022,-0.989154176231358)

Note that sl2.t has nonzero partials at the event time (t[end-1] and t[end], i.e. before and after affect!), unlike the rest of the time steps, whereas sl1.t holds plain floating-point numbers without partials.

julia> sl2.t[end]
Dual{Nothing}(1.5834548927069032,0.03204871414821,-2.5316455696202356,2.563897131857077)

julia> sl2.t[end-1]
Dual{Nothing}(1.5834548927069032,0.03204871414821,-2.5316455696202356,2.563897131857077)

julia> sl2.t[end-2]
Dual{Nothing}(1.4233820399330215,0.0,0.0,0.0)

julia> sl2.t[end-3]
Dual{Nothing}(1.2370664663763726,0.0,0.0,0.0)

Also, is there a way to use plain Float64 time stepping when there is a callback, as happens without one?

I made a gist to show my point: the partials of the states for solvers with a callback have erroneous values at the event location. To check this, I compared numerical differentiation results against the solution partials of the dual-number ODE. @ChrisRackauckas Care to look?

Inspecting more closely, I conclude that the difference was not caused by interpolation accuracy. Adding the callback also accounts for the boundary constraint of the event, which gives rise to different partials at the event location than in the unconstrained ODE without the callback. I wasn’t expecting this; that’s impressive :clap: :clap: :clap:.

Yes, sorry for not responding. Just got busy and digging into this kind of example is always fairly detailed. But yes, the reason why time is transformed into dual numbers is because you have to differentiate w.r.t. the boundary term whenever there is a continuous callback, since the solution to an ODE is an integral and so by integration by parts you get that extra term showing up at every event location. I plan to dig in still and make sure all of the values are correct, and probably add this to the set of tests we run. It’s just a very detailed piece of work :).
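
To spell that out, here is a sketch of the sensitivity calculation at the event, using notation introduced only for this illustration: θ is any one of the differentiated quantities (y0, vy0, gy), t*(θ) is the event time, and ∂/∂θ denotes a partial taken at fixed time. The event condition y(t*(θ); θ) = 1 holds for every θ, so differentiating it gives the time partial, and the chain rule then gives the velocity partial at the event:

```latex
\begin{aligned}
0 &= \frac{d}{d\theta}\, y\bigl(t^*(\theta);\theta\bigr)
   = \frac{\partial y}{\partial\theta} + v_y\,\frac{dt^*}{d\theta}
 \quad\Longrightarrow\quad
 \frac{dt^*}{d\theta} = -\frac{1}{v_y}\,\frac{\partial y}{\partial\theta}, \\[4pt]
\frac{d v_y}{d\theta}\bigg|_{t^*}
  &= \frac{\partial v_y}{\partial\theta} + \dot v_y\,\frac{dt^*}{d\theta},
 \qquad \dot v_y = (1 - y) - g_y = -g_y \ \text{ at } y = 1 .
\end{aligned}
```

Plugging in the Vern9 numbers posted above for θ = y0: dt*/dy0 = -(-0.0126582)/0.3949684 ≈ 0.0320487, which matches the first partial of sl2.t[end], and dvy/dy0 = -0.9999199 + (-0.005)(0.0320487) ≈ -1.0000801, which matches the first partial of s2[2]. The partials of y itself collapse to roughly zero because the event pins y to 1 for every θ.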


Indeed the values for the derivatives are correct. It’s actually quite difficult to make finite differencing work here, since termination causes a weird discontinuity that can throw it off. Here’s the check:

using OrdinaryDiffEq, ForwardDiff, LinearAlgebra, Test
#                     val    d/dy0  d/dvy0 d/dgy
y  = ForwardDiff.Dual(0.6,   1.0,   0.0,   0.0  )
vy = ForwardDiff.Dual(0.0,   0.0,   1.0,   0.0  )
gy = ForwardDiff.Dual(0.005, 0.0,   0.0,   1.0  )
u0 = [y;vy]
tspan = (0.0, 2.0)
function st!(du, u, p, t)
    @inbounds begin
        du[1] = u[2]
        du[2] = (1-u[1]) - p
    end
    return nothing
end
cnd(u, t, integ) = u[1] - 1
CC = ContinuousCallback(cnd, terminate!, nothing, abstol=0, reltol=0)

prb = ODEProblem(st!, u0, tspan, gy)
sl1 = solve(prb, Vern9(), abstol=1e-12, reltol=1e-12)
sl2 = solve(prb, Vern9(), abstol=1e-12, reltol=1e-12, callback=CC)

s1 = sl1(sl2.t[end].value)
s2 = sl2[end]
@test s1 ≈ s2

s3 = sl1(sl2.t[end-1].value)
s4 = sl2[end-1]
@test s3 ≈ s4

s5 = sl1(sl2.t[end-2].value)
s6 = sl2[end-2]
@test s5 ≈ s6

s7 = sl1(sl2.t[end-3].value)
s8 = sl2[end-3]
@test s7 ≈ s8

s9 = sl1(sl2.t[end-4].value)
s10= sl2[end-4]
@test s9 ≈ s10

x = [0.6,0.0,0.005]
using FiniteDiff
function get_endminusidx_cb(x;idx=0)
    y  = x[1]
    vy = x[2]
    gy = x[3]
    u0 = [y;vy]
    tspan = (0.0, 2.0)
    function st!(du, u, p, t)
        @inbounds begin
            du[1] = u[2]
            du[2] = (1-u[1]) - p
        end
        return nothing
    end
    cnd(u, t, integ) = u[1] - 1
    CC = ContinuousCallback(cnd, terminate!, nothing, abstol=0, reltol=0)

    prb = ODEProblem(st!, u0, tspan, gy)
    sol = solve(prb, Vern9(), abstol=1e-12, reltol=1e-12, callback = CC, saveat= ForwardDiff.value.(sl2.t))
    sol[end-idx]
end
function get_endminusidx(x;idx=0)
    y  = x[1]
    vy = x[2]
    gy = x[3]
    u0 = [y;vy]
    tspan = (0.0, 2.0)
    function st!(du, u, p, t)
        @inbounds begin
            du[1] = u[2]
            du[2] = (1-u[1]) - p
        end
        return nothing
    end
    cnd(u, t, integ) = u[1] - 1
    CC = ContinuousCallback(cnd, terminate!, nothing, abstol=0, reltol=0)

    prb = ODEProblem(st!, u0, tspan, gy)
    sol = solve(prb, Vern9(), abstol=1e-12, reltol=1e-12, saveat= ForwardDiff.value.(sl2.t))
    sol[end-idx]
end
# The first two need to use the callback, since that makes finite difference
# know that u[1] does not change at the end
enddiffs = FiniteDiff.finite_difference_jacobian(get_endminusidx_cb,x)
@test enddiffs ≈ reduce(hcat,Array.(ForwardDiff.partials.(s2)))' atol=1e-7

enddiffs = FiniteDiff.finite_difference_jacobian(x->get_endminusidx_cb(x,idx=1),x)
@test enddiffs ≈ reduce(hcat,Array.(ForwardDiff.partials.(s4)))' atol=1e-7

# These now use the no-callback version since that seems to be more stable
# for finite differencing
enddiffs = FiniteDiff.finite_difference_jacobian(x->get_endminusidx(x,idx=2),x)
@test enddiffs ≈ reduce(hcat,Array.(ForwardDiff.partials.(s6)))' atol=1e-5

enddiffs = FiniteDiff.finite_difference_jacobian(x->get_endminusidx(x,idx=3),x)
@test enddiffs ≈ reduce(hcat,Array.(ForwardDiff.partials.(s8)))' atol=1e-5

enddiffs = FiniteDiff.finite_difference_jacobian(x->get_endminusidx(x,idx=4),x)
@test enddiffs ≈ reduce(hcat,Array.(ForwardDiff.partials.(s10)))' atol=1e-5

This will be added as a test case.
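
Because this oscillator is linear, the fixed-time partials and the event time can also be cross-checked analytically, without finite differencing. A sketch, where y_0, v_0 and g are shorthand introduced here for the initial position, the initial velocity and the parameter gy:

```latex
\begin{aligned}
y(t) &= 1 - g + (y_0 - 1 + g)\cos t + v_0 \sin t, \\
v(t) &= -(y_0 - 1 + g)\sin t + v_0 \cos t, \\
\frac{\partial y}{\partial y_0} &= \cos t, \qquad
\frac{\partial y}{\partial v_0} = \sin t, \qquad
\frac{\partial y}{\partial g} = \cos t - 1, \\
\cos t^* &= \frac{g}{y_0 - 1 + g} \qquad (v_0 = 0,\ y(t^*) = 1).
\end{aligned}
```

With y0 = 0.6 and g = 0.005 this gives cos t* = -1/79 ≈ -0.0126582 and t* ≈ 1.5834549, in agreement with sl2.t[end].value. The fixed-time partials of y at t* are then approximately (-0.0126582, 0.9999199, -1.0126582), which agrees with the partials of s1[1] from the solution without the callback.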
