Does this have to do with our recent ICML paper?
If so, read Section 3.2 more closely:
Notice that … cannot be constructed directly from the $z(t_j)$ trajectory of the ODE’s solution. More precisely, the $k_i$ terms are not defined by the continuous ODE but instead by the chosen steps of the solver method. Continuous adjoint methods for neural ODEs (Chen et al., 2018; Zhuang et al., 2021) only define derivatives in terms of the ODE quantities. This is required in order to exploit properties such as allowing different steps in reverse, using reversibility to reduce memory, and constructing solvers that require fewer NFEs (Kidger et al., 2020). Computing the adjoint of each stage variable $k_i$ can be done, but this is known as discrete sensitivity analysis and is equivalent to automatic differentiation of the solver (Zhang & Sandu, 2014). Thus, to calculate the derivative of the solution simultaneously with the derivatives of the solver states, we used direct automatic differentiation of the differential equation solvers for performing the experiments (Innes, 2018).
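To make the "defined by the solver's steps" point concrete, here is a small sketch in plain Python (not the paper's code). It uses a hand-written fixed-step RK4 on the toy ODE dz/dt = -z, which is my own choice for illustration. Integrating over [0, 0.1] with one step or with two steps gives essentially the same z(0.1), but the stage values $k_i$ differ, because they belong to the discretization rather than to the continuous solution.

```python
import math


def f(t, z):
    """Right-hand side of the toy ODE dz/dt = -z (chosen only for illustration)."""
    return -z


def rk4_solve(f, z0, t0, t1, n_steps):
    """Classic fixed-step RK4.

    Returns the final state and a list with the stage tuple (k1, k2, k3, k4)
    computed at every step.
    """
    h = (t1 - t0) / n_steps
    t, z = t0, z0
    stages = []
    for _ in range(n_steps):
        k1 = f(t, z)
        k2 = f(t + h / 2, z + h / 2 * k1)
        k3 = f(t + h / 2, z + h / 2 * k2)
        k4 = f(t + h, z + h * k3)
        stages.append((k1, k2, k3, k4))
        z = z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + h
    return z, stages


# Same ODE, same interval, two different step choices.
z_coarse, stages_coarse = rk4_solve(f, 1.0, 0.0, 0.1, n_steps=1)
z_fine, stages_fine = rk4_solve(f, 1.0, 0.0, 0.1, n_steps=2)

# Both endpoints match exp(-0.1) to within about 1e-7 ...
print(z_coarse, z_fine, math.exp(-0.1))
# ... but the first-step k2 differs (about -0.95 vs about -0.975).
print(stages_coarse[0][1], stages_fine[0][1])
```

So a loss that touches the $k_i$ is a function of the solver's step choices, which is information a continuous adjoint never sees.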
In other words, if the stages of the RK method are used in the loss function itself, then whatever you do has to be at least equivalent to direct differentiation of the solver, and in Julia that means you might as well just differentiate the solver directly. Take a look at the code supplied with the paper to see how that was done. That’s currently the best approach, though `sensealg=ADPassThrough()` is something you could use as well (setting aside the Zygote/Diffractor issues with mutation).
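For intuition about what "direct differentiation of the solver" buys you, here is a minimal sketch of the idea outside Julia. It assumes plain Python with a tiny hand-rolled forward-mode dual number type, a hypothetical parameterized ODE dz/dt = -p z, and a hypothetical loss that adds a penalty on each step's $k_2$. None of these names (`Dual`, `rhs`, `rk4_with_stages`, `loss`) come from the paper's code. Because the dual numbers flow through the solver's own arithmetic, the derivative automatically includes how the stages depend on p, and it matches a finite-difference check on the same discretized loss.

```python
class Dual:
    """Minimal forward-mode dual number: val + der*eps, with eps**2 == 0."""

    def __init__(self, val, der=0.0):
        self.val = val
        self.der = der

    @staticmethod
    def lift(x):
        # Treat plain floats/ints as constants (zero derivative).
        return x if isinstance(x, Dual) else Dual(x, 0.0)

    def __add__(self, other):
        o = Dual.lift(other)
        return Dual(self.val + o.val, self.der + o.der)

    __radd__ = __add__

    def __mul__(self, other):
        o = Dual.lift(other)
        # Product rule.
        return Dual(self.val * o.val, self.der * o.val + self.val * o.der)

    __rmul__ = __mul__

    def __neg__(self):
        return Dual(-self.val, -self.der)


def rhs(z, p):
    """Hypothetical parameterized ODE: dz/dt = -p * z."""
    return -p * z


def rk4_with_stages(p, z0=1.0, t1=1.0, n_steps=10):
    """Fixed-step RK4 on [0, t1]; returns the final state and every step's stages.

    Written generically, so it runs on plain floats or on Dual numbers.
    """
    h = t1 / n_steps
    z = z0
    stages = []
    for _ in range(n_steps):
        k1 = rhs(z, p)
        k2 = rhs(z + h / 2 * k1, p)
        k3 = rhs(z + h / 2 * k2, p)
        k4 = rhs(z + h * k3, p)
        stages.append((k1, k2, k3, k4))
        z = z + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return z, stages


def loss(p):
    """Loss that uses the solver's stage values, not just the trajectory."""
    z_final, stages = rk4_with_stages(p)
    penalty = 0.0
    for _k1, k2, _k3, _k4 in stages:
        penalty = penalty + k2 * k2  # hypothetical stage-dependent term
    return z_final * z_final + penalty


p0 = 0.7
# Direct (forward-mode) differentiation of the solver: seed dp/dp = 1.
grad_ad = loss(Dual(p0, 1.0)).der

# Central finite-difference check on the same discretized loss.
eps = 1e-6
grad_fd = (loss(p0 + eps) - loss(p0 - eps)) / (2 * eps)

print(grad_ad, grad_fd)  # the two should agree closely
```

Forward mode is used here only because it is the easiest thing to write by hand. Reverse-mode AD of the same loop would give the same gradient, and per the quoted section that is equivalent to discrete sensitivity analysis over the stages.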