RK4 stages interpolation

Hello,

I have an ODE that I solve with RK4 using a time step dt. I now want to interpolate the computed trajectory (solution) at intermediate points so that the effective time step becomes dt/2. I know I can use the information already available from the RK4 steps to build a third-order interpolant for the intermediate points. However, I also need interpolated RK4 stages (to be used later in a discrete adjoint calculation). Is there a natural way to do that?
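The third-order interpolation mentioned above can be formed from the four stages of the step alone, with no extra function evaluations. A minimal sketch in Python (chosen only so it runs without packages), assuming a scalar ODE and using one commonly used third-order continuous extension of classical RK4. Its weights are b1(θ) = θ - 3θ²/2 + 2θ³/3, b2(θ) = b3(θ) = θ² - 2θ³/3 and b4(θ) = -θ²/2 + 2θ³/3, which reduce to the usual 1/6, 1/3, 1/3, 1/6 at θ = 1. The function names are illustrative, not from any library.

```python
import math


def rk4_step(f, t, y, h):
    """One classical RK4 step; returns the new state and the four stage derivatives."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    y_new = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return y_new, (k1, k2, k3, k4)


def rk4_dense(y, h, stages, theta):
    """Third-order continuous extension of RK4, evaluated at t_n + theta*h (0 <= theta <= 1)."""
    k1, k2, k3, k4 = stages
    b1 = theta - 1.5 * theta**2 + (2 / 3) * theta**3
    b23 = theta**2 - (2 / 3) * theta**3  # weight shared by k2 and k3
    b4 = -0.5 * theta**2 + (2 / 3) * theta**3
    return y + h * (b1 * k1 + b23 * (k2 + k3) + b4 * k4)


def f(t, y):
    return y  # test problem y' = y, exact solution exp(t)


# Usage: one step of size dt from y(0) = 1, interpolated at dt/2.
dt = 0.1
y1, stages = rk4_step(f, 0.0, 1.0, dt)
y_half = rk4_dense(1.0, dt, stages, 0.5)
print(y_half, math.exp(dt / 2))
```

Evaluating at θ = 0.5 gives the dt/2 point, and at θ = 1 the formula reproduces y_{n+1} up to rounding. Note that this interpolates the state only; it says nothing about stage values at the intermediate point.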

DifferentialEquations.jl will handle high-order interpolation for you; see the Solution Handling page of the DifferentialEquations.jl documentation.

Thanks for your prompt reply. I know that the solution interface provides high-order interpolation. However, it only interpolates the solution itself, not the RK stages. The core of my question is this: given the RK4 stages {q1, q2, q3, q4} at the n-th time step, and likewise at step n+1, is there a way to interpolate them to also obtain four stages at n+0.5?
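Part of the difficulty is that the stages depend on the step size h, not only on the trajectory. The Python sketch below contrasts two hypothetical ways of producing "stages at n+0.5" for the test problem y' = y: (a) averaging the stages stored at t_n and t_{n+1}, and (b) recomputing the stages of a fresh step of size dt/2 starting from the state at t_n + dt/2. For clarity, the exact solution stands in for the interpolated state; in practice it would come from an interpolant. All names are illustrative.

```python
import math


def rk4_stages(f, t, y, h):
    """Stage derivatives k1..k4 of one classical RK4 step of size h from (t, y)."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return (k1, k2, k3, k4)


def rk4_update(y, h, stages):
    """Combine the stages into the RK4 state update."""
    k1, k2, k3, k4 = stages
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def f(t, y):
    return y  # test problem y' = y


dt, y0 = 0.1, 1.0
stages_n = rk4_stages(f, 0.0, y0, dt)
y1 = rk4_update(y0, dt, stages_n)
stages_n1 = rk4_stages(f, dt, y1, dt)

# (a) Naive: average the stages stored at t_n and t_{n+1} (both computed with h = dt).
naive = tuple((a + b) / 2 for a, b in zip(stages_n, stages_n1))

# (b) Recompute: stages of a new step with h = dt/2 from the state at t_n + dt/2.
# The exact solution is used as a stand-in for an interpolated state.
y_half = math.exp(dt / 2)
recomputed = rk4_stages(f, dt / 2, y_half, dt / 2)

for i, (a, b) in enumerate(zip(naive, recomputed), start=1):
    print(f"k{i}: averaged={a:.6f}  recomputed={b:.6f}")
```

k1 is close in both cases, since each approximates f at the midpoint, but k2 to k4 differ noticeably because they are evaluated at offsets h/2 and h with h = dt in (a) and h = dt/2 in (b). Which set is "right" depends on which discrete scheme the adjoint is supposed to differentiate.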

I’m confused about what problem you are trying to solve. Why is interpolating the solution not sufficient for your application?

Does this have to do with our recent ICML paper?

If so, read Section 3.2 more closely:

Notice that … cannot be constructed directly from the z(t_j) trajectory of the ODE’s solution. More precisely, the k_i terms are not defined by the continuous ODE but instead by the chosen steps of the solver method. Continuous adjoint methods for neural ODEs (Chen et al., 2018; Zhuang et al., 2021) only define derivatives in terms of the ODE quantities. This is required in order to exploit properties such as allowing different steps in reverse and reversibility for reduced memory, and in constructing solvers requiring fewer NFEs (Kidger et al., 2020). Indeed, computing the adjoint of each stage variable k_i can be done, but is known as discrete sensitivity analysis and is known to be equivalent to automatic differentiation of the solver (Zhang & Sandu, 2014). Thus to calculate the derivative of the solution simultaneously with the derivatives of the solver states, we used direct automatic differentiation of the differential equation solvers for performing the experiments (Innes, 2018).
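To make "computing the adjoint of each stage variable k_i" concrete, here is a minimal hand-written reverse sweep through one RK4 step, which is what reverse-mode AD of the solver would do mechanically. It is a Python sketch assuming an autonomous scalar ODE y' = f(y) with a known derivative f'(y); the function names are hypothetical. Given dL/dy_{n+1}, it returns dL/dy_n together with the adjoints dL/dk_i of the four stages.

```python
import math


def rk4_step_with_stages(f, y, h):
    """Forward RK4 step for y' = f(y); also returns the stage inputs Y_i (k_i = f(Y_i))."""
    Y1 = y
    k1 = f(Y1)
    Y2 = y + h / 2 * k1
    k2 = f(Y2)
    Y3 = y + h / 2 * k2
    k3 = f(Y3)
    Y4 = y + h * k3
    k4 = f(Y4)
    y_new = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return y_new, (Y1, Y2, Y3, Y4)


def rk4_step_adjoint(dfdy, h, stage_inputs, y_bar_new):
    """Discrete adjoint of one RK4 step: returns dL/dy_n and the stage adjoints dL/dk_i."""
    Y1, Y2, Y3, Y4 = stage_inputs
    # Seed from y_{n+1} = y_n + h/6 (k1 + 2 k2 + 2 k3 + k4).
    y_bar = y_bar_new
    k1_bar, k2_bar = h / 6 * y_bar_new, h / 3 * y_bar_new
    k3_bar, k4_bar = h / 3 * y_bar_new, h / 6 * y_bar_new
    # k4 = f(Y4), Y4 = y_n + h k3
    Y4_bar = k4_bar * dfdy(Y4)
    y_bar += Y4_bar
    k3_bar += h * Y4_bar
    # k3 = f(Y3), Y3 = y_n + h/2 k2
    Y3_bar = k3_bar * dfdy(Y3)
    y_bar += Y3_bar
    k2_bar += h / 2 * Y3_bar
    # k2 = f(Y2), Y2 = y_n + h/2 k1
    Y2_bar = k2_bar * dfdy(Y2)
    y_bar += Y2_bar
    k1_bar += h / 2 * Y2_bar
    # k1 = f(Y1), Y1 = y_n
    y_bar += k1_bar * dfdy(Y1)
    return y_bar, (k1_bar, k2_bar, k3_bar, k4_bar)


# Usage: y' = sin(y), one step from y_n = 0.5 with h = 0.2, loss L = y_{n+1}.
h, y_n = 0.2, 0.5
y_next, stage_inputs = rk4_step_with_stages(math.sin, y_n, h)
grad, stage_adjoints = rk4_step_adjoint(math.cos, h, stage_inputs, 1.0)
eps = 1e-6
fd = (rk4_step_with_stages(math.sin, y_n + eps, h)[0]
      - rk4_step_with_stages(math.sin, y_n - eps, h)[0]) / (2 * eps)
print(grad, fd)
```

For a linear problem y' = λy the sweep returns exactly the RK4 stability polynomial R(hλ) = 1 + hλ + (hλ)²/2 + (hλ)³/6 + (hλ)⁴/24, i.e. the derivative of the discrete map rather than of the continuous ODE flow, which is the distinction the quoted passage draws.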

In other words, if the stages of the RK method are used in the loss function itself, then you need something at least equivalent to direct differentiation of the solver, and in Julia you might as well differentiate the solver directly. Take a look at the code supplied with the paper to see how that was done. That is currently the best approach, though sensealg=ADPassThrough() is something you could use as well (setting aside the Zygote/Diffractor issues with mutation).
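As a language-neutral illustration of "direct differentiation of the solver", the Python sketch below pushes a minimal forward-mode dual number through a hand-written RK4 loop. Because the derivative is carried through every arithmetic operation, stage values that appear in the loss are differentiated exactly as the solver computed them. The Dual class, the test problem y' = -p*y and the loss that sums all stages are hypothetical choices for illustration; this is not the code from the paper, which uses Julia AD tooling.

```python
class Dual:
    """Minimal forward-mode dual number val + der*eps, with eps**2 = 0."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def _lift(x):
        return x if isinstance(x, Dual) else Dual(x)

    def __add__(self, other):
        o = Dual._lift(other)
        return Dual(self.val + o.val, self.der + o.der)

    __radd__ = __add__

    def __mul__(self, other):
        o = Dual._lift(other)
        return Dual(self.val * o.val, self.der * o.val + self.val * o.der)

    __rmul__ = __mul__  # multiplication is commutative

    def __neg__(self):
        return Dual(-self.val, -self.der)


def rk4_stage_loss(p, y0=1.0, dt=0.1, n_steps=10):
    """RK4 on y' = -p*y; returns a hypothetical loss using every stage plus the final state."""
    def f(y):
        return -p * y

    y, loss = y0, 0.0
    for _ in range(n_steps):
        k1 = f(y)
        k2 = f(y + dt / 2 * k1)
        k3 = f(y + dt / 2 * k2)
        k4 = f(y + dt * k3)
        loss = loss + (k1 + k2 + k3 + k4)  # stages enter the loss directly
        y = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return loss + y


p0 = 0.7
grad_ad = rk4_stage_loss(Dual(p0, 1.0)).der  # seed dp/dp = 1
eps = 1e-6
grad_fd = (rk4_stage_loss(p0 + eps) - rk4_stage_loss(p0 - eps)) / (2 * eps)
print(grad_ad, grad_fd)
```

The same function runs on plain floats and on duals, which is the property that lets AD libraries differentiate an unmodified solver loop. A continuous adjoint, by contrast, would only give derivatives of ODE quantities such as y(t), not of the k_i.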
