I am working on code that uses DifferentialEquations.jl (more precisely, just OrdinaryDiffEqLowOrderRK; thanks for splitting that out). I have done some work to parallelize the computation of the derivative across multiple nodes with MPI, and that seems to work relatively well performance-wise.
The most computationally expensive parts of the derivative were already parallelized with `Base.Threads.@threads`, and we saw that the computation of the derivative scaled reasonably well with the number of cores.
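To make the threading setup concrete, here is a minimal sketch of an in-place derivative threaded with `Base.Threads.@threads`, together with a serial reference and a crude timing comparison. The functions `rhs_threaded!` and `rhs_serial!` and the per-element work inside them are hypothetical placeholders, not the actual derivative.

```julia
using Base.Threads

# Hypothetical placeholder for an expensive derivative, written in the in-place
# form f!(du, u, p, t); the loop over state elements is split across threads.
function rhs_threaded!(du, u, p, t)
    @threads for i in eachindex(du, u)
        du[i] = -p * u[i] + sin(u[i])
    end
    return nothing
end

# Serial version of the same computation, used as a reference.
function rhs_serial!(du, u, p, t)
    for i in eachindex(du, u)
        du[i] = -p * u[i] + sin(u[i])
    end
    return nothing
end

u = rand(2^22)
du_thr = similar(u)
du_ser = similar(u)

# First calls trigger compilation, so they are excluded from the timings.
rhs_threaded!(du_thr, u, 2.0, 0.0)
rhs_serial!(du_ser, u, 2.0, 0.0)

t_thr = @elapsed for _ in 1:10
    rhs_threaded!(du_thr, u, 2.0, 0.0)
end
t_ser = @elapsed for _ in 1:10
    rhs_serial!(du_ser, u, 2.0, 0.0)
end

println("threads = ", nthreads(), ", speedup ≈ ", round(t_ser / t_thr; digits = 2))
```

Running it with `julia --threads=1` and then with more threads (e.g. `julia --threads=8`) gives a quick scaling check. A function with this `f!(du, u, p, t)` signature can be passed directly to `ODEProblem` as an in-place right-hand side.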
The problem is that, beyond a certain level of parallelism, the time spent outside the derivative computation becomes significant. With the state stored in a StructArray, 23% of the total time is spent outside the derivative computation on one node, and 35% on two nodes. The absolute time spent there stays more or less constant, as expected if that part is not parallelized with MPI (see this comment).
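As a rough consistency check, assume the time outside the derivative, $T_\text{out}$, stays constant while the derivative time $T_\text{rhs}$ scales ideally as $1/N$ with the number of nodes $N$. Writing $f_1 = T_\text{out} / (T_\text{out} + T_\text{rhs})$ for the single-node fraction, the expected fraction outside the derivative on $N$ nodes is

$$
f_N = \frac{T_\text{out}}{T_\text{out} + T_\text{rhs}/N} = \frac{f_1}{f_1 + (1 - f_1)/N},
\qquad
f_2 = \frac{0.23}{0.23 + 0.77/2} = \frac{0.23}{0.615} \approx 0.37,
$$

which is close to the observed 35%.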
Varying the number of threads, on the other hand, suggests that the part outside the derivative computation does not benefit from multiple threads. Or perhaps it is threaded, but its operations are so memory-bound that using more CPUs does not improve performance at all. The time spent in this part also appears to scale linearly with the size of the state object (see the second part of this comment).
Given these assumptions, I thought that representing the solver state with a distributed array type such as a PencilArray (as mentioned here) should do the trick: the solver on each rank would then work on a smaller local array, while the results stay correct because the integration steps are performed consistently across MPI ranks.
In short, I expected PencilArrays to give us scalability, even if threads might not.
But this does not seem to be the case.
With 4 nodes, for example, 30% of the time is spent outside the derivative computation, compared with only ~15% on a single node (see this comment).
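Assuming, as an idealization, that the derivative itself scales perfectly with the number of nodes $N$, these two fractions can be turned into an estimate of how much the part outside the derivative sped up. With $T_\text{out}^{(N)} = \frac{f_N}{1 - f_N} \cdot \frac{T_\text{rhs}^{(1)}}{N}$, $f_1 = 0.15$ and $f_4 = 0.30$, the implied speedup of that part is

$$
S_\text{out} = \frac{T_\text{out}^{(1)}}{T_\text{out}^{(N)}} = N \, \frac{f_1}{1 - f_1} \, \frac{1 - f_N}{f_N}
= 4 \cdot \frac{0.15}{0.85} \cdot \frac{0.70}{0.30} \approx 1.65,
$$

compared with an ideal value of 4. If that part did not scale at all, the fraction would instead be $0.15 / (0.15 + 0.85/4) \approx 0.41$, so the observed 30% points to partial, but far from linear, scaling.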
I admit I have not yet profiled the code properly. Not for lack of trying: the profiler output was rather hard to read, so I ended up relying on manual instrumentation with TimerOutputs.jl, finding such low-tech methods more robust for the time being. I imagine I will have to step up my game soon.
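A minimal sketch of that kind of manual instrumentation, measuring the fraction of solver time spent outside the derivative with TimerOutputs.jl. The toy derivative `rhs!`, the problem size, and the fixed-step `RK4` settings are hypothetical stand-ins for the real code, and the sketch assumes that `using OrdinaryDiffEqLowOrderRK` makes `ODEProblem` and `solve` available.

```julia
using OrdinaryDiffEqLowOrderRK  # provides RK4; assumed to re-export ODEProblem and solve
using TimerOutputs
using Base.Threads

const to = TimerOutput()

# Hypothetical toy derivative du = -p * u, threaded over the state and
# wrapped in a timer section so its time can be separated from the rest.
function rhs!(du, u, p, t)
    @timeit to "rhs" begin
        @threads for i in eachindex(du, u)
            du[i] = -p * u[i]
        end
    end
    return nothing
end

u0 = rand(2^20)
prob = ODEProblem(rhs!, u0, (0.0, 1.0), 2.0)

# Warm-up run so that compilation time is not attributed to either part.
solve(prob, RK4(); dt = 0.01, adaptive = false, save_everystep = false)
reset_timer!(to)

sol = @timeit to "solve" solve(prob, RK4(); dt = 0.01, adaptive = false,
                               save_everystep = false)

t_total = TimerOutputs.time(to["solve"])         # ns spent in solve
t_rhs = TimerOutputs.time(to["solve"]["rhs"])    # ns spent in the derivative (nested)
frac_outside = 1 - t_rhs / t_total
println("fraction of time outside the derivative: ", round(frac_outside; digits = 3))
print_timer(to)
```

Because the `rhs` section is nested inside the `solve` section, everything not attributed to `rhs` (stage combinations, state copies and other solver bookkeeping) shows up as the remainder of `solve`.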
Is anybody aware of reasons why better scaling should not be expected, or does anyone have suggestions for alternative approaches?