I am curious about the best approach for implementing custom ODE solvers in the DifferentialEquations.jl ecosystem.
As far as I know, there are mostly two options:
Contributing to OrdinaryDiffEq.jl:
When the solver fits into the OrdinaryDiffEq.jl framework, this is probably the best approach, since it requires less work while resulting in good performance and lots of features. Check out the SciML dev docs for the “adding new algorithms” guide. But this is only possible when things are not too custom; counterexamples are probabilistic ODE solvers (which is where I am coming from) and time-parallel methods (e.g. Parareal.jl).
Building on DiffEqBase.jl:
With this approach, one has much more freedom and control, while remaining compatible with the DifferentialEquations.jl ecosystem and getting some convenient functionality from DiffEqBase.jl. But one also needs to implement much more functionality, such as the integrator interface, step-size control, etc.
Trying to get the best of both worlds, I explored a third option in ProbNumDiffEq.jl:
Building on OrdinaryDiffEq.jl:
This can be done by implementing only a few types and overloading some functions (check out my MyOrdinaryDiffEqSolver.jl repo for a minimal example), and as a result you get an integrator interface, step-size control, and other nice functionality such as efficient Jacobian calculations. It works well overall, but it does feel “hacky”, as if I am pressing a method into a framework that is not quite the right fit for it. Also, doing this relies on much more than the public API, so each update needs to be tested carefully.
So my questions are:
When is which of these approaches appropriate?
Are there other advantages or disadvantages to these approaches that were not discussed yet?
Are there better approaches for implementing “non-standard” ODE solvers?
Note that StochasticDiffEq.jl and DelayDiffEq.jl do this as well. I’d write it like this:
If you just want the problem type and nothing else, great, just build from SciMLBase.jl. I wouldn’t recommend this one for ODEs since DiffEqBase.jl has the callback handling, which is something that you probably wouldn’t want to have to copy. Pro: very little startup time. Con: no infrastructure. Good luck.
If you want all of the basic building blocks for the differential equation solver APIs (like the error handling functions, callbacks, etc.) then DiffEqBase.jl is what to build off of. Pro: enough infrastructure that you can easily guarantee that you follow the interface. Con: things that are hard and require big dependencies (like handling sparse Jacobians) are not here.
If you want to extend some of the numerical details of the solvers, say use the same Jacobian handling and all of that, then extending off of OrdinaryDiffEq.jl could work. But note that if you’re doing that, you’re likely extending non-public APIs, so you should let us know about this so that there’s more of a conversation about how things internally change (i.e. letting you know when we break something that’s being used downstream). We should probably add a downstream test there as well. I try to support everyone I can here, but without a downstream test it’s hard to keep people informed. Pro: more features come for free, like sparse differentiation. Con: you’d be building off of non-public internals. Things might break.
If you want your new solver to be maintained by the SciML community, add it to OrdinaryDiffEq.jl. This does require that it’s not too custom. “Too custom” is a hard line to draw: I think parareal could be done as a perform_step!, but it might be hard. Pro: you can be guaranteed that it will not break because it will all be part of the mainline OrdinaryDiffEq.jl development and tests. Con: it needs to fit the structure of OrdinaryDiffEq.jl, which covers “most” solvers, but there are interesting cases that are outside of this.
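To make the DiffEqBase.jl option above a bit more concrete, here is a minimal sketch of a fixed-step explicit Euler solver hooked in by overloading DiffEqBase.__solve. The type name MyEuler is hypothetical, the sketch only handles out-of-place problems, and it ignores callbacks, saving options, and return codes; the exact extension point may differ between DiffEqBase.jl versions, so treat it as an illustration rather than the recommended implementation.

```julia
using DiffEqBase  # re-exports SciMLBase (ODEProblem, solve, ...)

# Hypothetical algorithm type, for illustration only.
struct MyEuler <: DiffEqBase.AbstractODEAlgorithm end

# Assumption: `solve(prob, alg; kwargs...)` ends up dispatching to `__solve`.
# Only out-of-place problems `f(u, p, t)` are handled in this sketch.
function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem, alg::MyEuler;
                            dt, kwargs...)
    t, tf = prob.tspan
    u = prob.u0
    ts = [t]
    us = [u]
    while t < tf
        tnew = min(t + dt, tf)               # never step past the end of tspan
        u = u + (tnew - t) * prob.f(u, prob.p, t)
        t = tnew
        push!(ts, t)
        push!(us, u)
    end
    # Wrap the raw time series in a standard solution object.
    return DiffEqBase.build_solution(prob, alg, ts, us)
end

# Usage: u' = -u, u(0) = 1 on [0, 1]
prob = ODEProblem((u, p, t) -> -u, 1.0, (0.0, 1.0))
sol = solve(prob, MyEuler(); dt = 0.01)
```

Everything else (step-size control, an integrator interface, Jacobian handling) would have to be written by hand on this route, which is the trade-off described above.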
As above, I think extending OrdinaryDiffEq.jl isn’t bad, but it needs to be done carefully. When we have everything in the SciML ecosystem, we have bots that tell us “this change will break DelayDiffEq”, so we make sure the corresponding PR is ready before we release an internal change to OrdinaryDiffEq.jl, and the DDE solvers just seamlessly update. This all follows semver: things that are not public API can break at any time, and the public API of OrdinaryDiffEq is solve, init, and the algorithms.
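For reference, code that stays within that public surface looks roughly like the following. This is only a small sketch with an arbitrarily chosen test problem; step! and solve! belong to the documented integrator interface rather than to OrdinaryDiffEq.jl internals.

```julia
using OrdinaryDiffEq

# Scalar test problem u' = -u, u(0) = 1 on t in [0, 1].
f(u, p, t) = -u
prob = ODEProblem(f, 1.0, (0.0, 1.0))

# `solve` with an algorithm: the one-shot entry point.
sol = solve(prob, Tsit5())

# `init` returns an integrator that can be advanced manually.
integrator = init(prob, Tsit5())
step!(integrator)    # take a single adaptive step
solve!(integrator)   # run the remaining steps to the end of tspan
```

Anything beyond this, such as cache types or perform_step! methods, falls under the internals that can change without a breaking release.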
I don’t think that will change anytime soon just because some of these internals are pretty closely tied to how the solvers are written. These internals probably have 2-3 more years before every sparse GPU etc. case is handled well. But we as an org would be happy to take in some of these other repos and help maintain them to do more custom and weird things. It just needs real integration (downstream testing, co-development, and continued maintenance) instead of being treated as a one-and-done thing if it’s using the internals of another package.
Maybe in the future some of those bits will become a builder package or move to DiffEqBase. FiniteDiff.jl was an OrdinaryDiffEq.jl internal for many years, same with LinearSolve.jl, etc., so over time tooling gets spun out into its own packages. But some pieces are still in too much flux to spin out.
Thanks a lot for the helpful reply! The explanation for when to build on SciMLBase.jl vs DiffEqBase.jl vs contributing to OrdinaryDiffEq.jl seems in line with the expectations I had. But I was not aware of StochasticDiffEq.jl and DelayDiffEq.jl also building on top of OrdinaryDiffEq.jl! I will definitely have a look to see how things are done there, and I might adjust ProbNumDiffEq.jl accordingly (e.g. I saw that both of them have their own Integrator type whereas I just used the one from OrdinaryDiffEq.jl directly).
Overall, it seems that the approach I took so far is indeed reasonable, so I would continue to build ProbNumDiffEq.jl on OrdinaryDiffEq.jl and extend non-public APIs, even if that might sometimes lead to some minor issues. Since I started specifying upper bounds for the OrdinaryDiffEq compat, I have not had many unexpected issues, because I check the diffs of each release and run my tests to see if something broke.
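For anyone wondering what such an upper bound looks like, here is a sketch of a compat entry in the downstream package's Project.toml. The version number is a placeholder and not the bound ProbNumDiffEq.jl actually uses; the tilde specifier is one way to restrict updates to patch releases of a given minor version.

```toml
# Project.toml of the downstream package (version number is a placeholder)
[compat]
OrdinaryDiffEq = "~6.30"   # allows [6.30.0, 6.31.0) only
```

With a bound like this, a new minor release of OrdinaryDiffEq.jl is only picked up after the bound is raised by hand, which gives a natural point to read the diff and rerun the tests.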
I would also be happy to communicate a bit more about exactly which parts of OrdinaryDiffEq I use, so that I could be a bit more in the loop for some changes and provide better feedback. What would be the best way of doing that? Maybe a first step could be a section in the ProbNumDiffEq.jl documentation that lists exactly which parts of OrdinaryDiffEq.jl it builds on? If there is a possibility for adding downstream tests that would of course also be great, but as I mentioned most unexpected issues are already prevented just by upper bounding the compat.