Joining forwards- and backwards-in-time ODE solutions

As in this question, I’m hoping to combine two ODE solutions, but I really need some of the nice features of the solutions, mostly dense output. In some ways my problem may be easier, because I really have the same ODEProblem and just have to solve it both forwards and backwards in time.

So my question is primarily: How do I combine the two to get dense output? (Secondarily: are there any other features that might come in handy that I should try to retain?)

Schematically, I have something like this:

using DifferentialEquations, DiffEqBase

problem_forwards = ODEProblem(RHS!, uᵢ, (0, T))
solution_forwards = solve(problem_forwards, <...>)

problem_backwards = remake(problem_forwards; tspan=(0, -T))
solution_backwards = solve(problem_backwards, <...>)
solution_backwards = solution_backwards[end:-1:2]

(In that last line, I’ve reversed the backwards solution and dropped its first element, because it’s identical to the first element of solution_forwards.)
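The joining step can be sketched on plain vectors, assuming tb and tf stand in for the time grids of the backwards and forwards solves (join_grids and the numbers below are made up for illustration). Both solves start at t = 0, so the backwards copy of that point is dropped to keep the joined grid strictly increasing:

```julia
# Join a backwards grid (0 → -T) and a forwards grid (0 → T) into one
# increasing grid, dropping the duplicated starting point.
function join_grids(tb::AbstractVector, tf::AbstractVector)
    @assert first(tb) == first(tf) "both solves should start at the same time"
    # tb[end:-1:2] reverses the backwards grid and skips its first element (t = 0)
    return [tb[end:-1:2]; tf]
end

tb = [0.0, -0.5, -1.2, -2.0]   # made-up backwards time points
tf = [0.0, 0.4, 1.1, 2.0]      # made-up forwards time points
t = join_grids(tb, tf)         # [-2.0, -1.2, -0.5, 0.0, 0.4, 1.1, 2.0]
```

The same index range would apply to the saved states u; this only joins the saved points, not the dense-output data.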

My naive attempt to combine them looks like this:

alg = solution_forwards.alg
t = [solution_backwards.t; solution_forwards.t]
u = [solution_backwards.u; solution_forwards.u]
k = [solution_backwards.k; solution_forwards.k]
retcode = solution_forwards.retcode  # Could be something more clever; maybe the worse of the two retcodes?

problem = remake(solution_forwards.prob, tspan=(t[1], t[end]))

sol = DiffEqBase.build_solution(
    problem, alg, t, u,
    dense=true, k=k, retcode=retcode
)
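On the retcode comment in the snippet above, one option is to keep the forwards retcode only when both solves succeeded, and otherwise report whichever one failed. A minimal sketch: combine_retcodes and its success keyword are hypothetical, and the default :Success assumes a DiffEq version whose retcodes are Symbols (newer versions use an enum, so pass whatever success value your version returns):

```julia
# Combine two return codes into one "worst-case" code.
# `success` is the value your DiffEq version uses for a successful solve
# (assumption: the Symbol :Success, as in older versions).
function combine_retcodes(forwards, backwards; success = :Success)
    if forwards != success
        return forwards        # forwards failed: report it
    elseif backwards != success
        return backwards       # only backwards failed: report that instead
    else
        return success         # both succeeded
    end
end
```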

But evidently it’s not enough to just copy u, t, and k: sol.interp is a different type from solution_*.interp and gives significantly different results, presumably because it falls back to the default LinearInterpolation(t, u). I can’t figure out how to construct a new interp to pass to build_solution. It felt a bit hacky, but I tried

interp = OrdinaryDiffEq.InterpolationData(
    solution_forwards.interp.f,
    u,
    t,
    k,
    solution_forwards.interp.dense,
    solution_forwards.interp.cache
)

Unfortunately, this gives incorrect results for times in the solution_backwards regime. (Also, for at least some choices of time stepper, I would apparently need to construct a CompositeInterpolationData.)
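For comparison with the LinearInterpolation fallback mentioned above: a piecewise-linear interpolant uses only the saved t and u, which is why it cannot reproduce a solver's higher-order dense output between steps. A rough sketch of what such a fallback does (linear_interp is a made-up name, not DiffEqBase's implementation):

```julia
# Piecewise-linear interpolation of saved states `u` at time `τ`,
# assuming `t` is strictly increasing and t[1] <= τ <= t[end].
function linear_interp(t::AbstractVector, u::AbstractVector, τ::Real)
    i = searchsortedlast(t, τ)          # index of the last t[i] <= τ
    i == length(t) && return u[end]     # τ == t[end]
    θ = (τ - t[i]) / (t[i+1] - t[i])    # fraction of the way through step i
    return (1 - θ) .* u[i] .+ θ .* u[i+1]
end

t = [0.0, 1.0, 2.0]
u = [[0.0], [1.0], [4.0]]   # samples of u(t) = t^2
linear_interp(t, u, 1.5)    # [2.5], whereas the true value is 2.25
```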

I’m not sure how much this matters, but it also seems that, at least for some time steppers, alg is not the same for forwards and backwards integration. Specifically, if I use AutoVern9(Rodas5()), there’s a tiny difference between the two alg objects that may just be some unique tag.

Is there some way to do this joining more correctly and/or elegantly?

Second attempt. After some input from Chris Rackauckas, I realized I also have to reverse k:

using OrdinaryDiffEq, DiffEqBase

function rhs!(du,u,p,t)
    du[1] = sin(u[2])
    du[2] = cos(u[1])
    du[3] = u[3]/100.0
end

uᵢ = [1.0;0.0;0.1]

problem_forwards = ODEProblem(rhs!, uᵢ, (0.0, 100.0))
solf = solve(problem_forwards, Tsit5())

problem_backwards = ODEProblem(rhs!, uᵢ, (0.0, -100.0))
solb = solve(problem_backwards, Tsit5())

alg = solf.alg
t = [reverse(solb.t); solf.t]
u = [reverse(solb.u); solf.u]
#k = [reverse(solb.k); solf.k]  # Both this one and the line below fail
k = [map(reverse, reverse(solb.k)); solf.k]

retcode = solf.retcode  # Could be something more clever; maybe the worse retcode?
interp = OrdinaryDiffEq.InterpolationData(solf.interp.f, u, t, k, solf.interp.dense, solf.interp.cache)
problem = ODEProblem(rhs!, uᵢ, (-100.0, 100.0))
sol = DiffEqBase.build_solution(problem, alg, t, u, dense=solf.dense, k=k, retcode=retcode, interp=interp)

This still doesn’t work for the first forwards time step, or any but the last backwards time step:

julia> sol(solb.t).u .≈ solb(solb.t).u
109-element BitVector:
 0
 0
 0
 0
 0
 ⋮
 0
 0
 0
 0
 0
 1
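One thing worth checking: in this second attempt, [reverse(solb.t); solf.t] contains t = 0.0 twice (both solves start there), unlike the first attempt, which dropped it. Whether that is what breaks the first forwards step is only a guess, but a zero-length step is at least suspicious for any interpolation lookup. A quick sanity check, with made-up time points standing in for solb.t and solf.t (strictly_increasing is a hypothetical helper):

```julia
# True if the joined time grid is strictly increasing; a repeated
# time point (zero-length step) makes it false.
strictly_increasing(t) = all(diff(t) .> 0)

tb = [0.0, -0.5, -1.0]   # stand-in for solb.t
tf = [0.0, 0.5, 1.0]     # stand-in for solf.t

strictly_increasing([reverse(tb); tf])            # false: 0.0 appears twice
strictly_increasing([reverse(tb)[1:end-1]; tf])   # true: duplicate dropped
```

If the duplicate is dropped from t, the matching entries of u and k would have to be dropped too.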

You want just reverse, not map(reverse, reverse(...)). And then I think you need to flip the sign, so map(-, reverse(solb.k)) is probably it.
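The difference between the three variants is easiest to see on a toy k with the same nesting as the solver's (one entry per step, one vector per stage); the values here are made up:

```julia
# Toy stand-in for sol.k: 2 steps, 2 stages per step, 1 state component.
k = [[[1.0], [2.0]],    # step 1
     [[3.0], [4.0]]]    # step 2

reverse(k)                # steps reversed, stage order inside each step kept
map(reverse, reverse(k))  # additionally reverses the stages inside each step
map(-, reverse(k))        # steps reversed and every stage value negated
```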

Nope. Same result. I feel like I must be making some simple stupid mistake when I’m putting things together after that.
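Since .≈ only gives a yes/no per point, it can help to look at the size of the mismatch when comparing attempts. A small helper, assuming both arguments are vectors of state vectors like sol(solb.t).u (max_mismatch is a made-up name):

```julia
using LinearAlgebra: norm

# Largest Euclidean distance between corresponding state vectors.
max_mismatch(a, b) = maximum(norm.(a .- b))
```

For example, max_mismatch(sol(solb.t).u, solb(solb.t).u) shows whether a change shrinks the error even when it doesn't fix it.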

Actually, thinking about it more, you can only flip k like that if the ODE method and its interpolation are symmetric, so I don’t think you can naively do this. It might be best to define a function that chooses which interpolation to call depending on the time point.
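A minimal sketch of that idea, assuming each piece is callable as f(t) (as ODESolution objects are) and that both solves start at the same time t0. JoinedSolution is a hypothetical name, not a DiffEq type:

```julia
# Wrap two dense solutions and pick the right one for each time.
# `forwards` covers t >= t0 and `backwards` covers t <= t0.
struct JoinedSolution{F,B,T<:Real}
    forwards::F
    backwards::B
    t0::T   # shared starting time of both solves
end

# Scalar time: delegate to whichever solution covers t.
(s::JoinedSolution)(t::Real) = t >= s.t0 ? s.forwards(t) : s.backwards(t)

# Vector of times: evaluate point by point (returns a plain Vector).
(s::JoinedSolution)(ts::AbstractVector{<:Real}) = [s(t) for t in ts]
```

With the second attempt’s solutions this would be sol = JoinedSolution(solf, solb, 0.0), after which sol(-50.0) and sol(50.0) each use that solver’s own dense output. The trade-off is that the result is not an ODESolution, so things like sol.t or plot recipes would have to come from solf and solb separately.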

Makes sense. Probably easier anyway! Thanks for all your help. :slight_smile: