DifferentialEquations.jl+MPI.jl+PencilArrays.jl: Lack of scaling observed

I am working on a code that uses DifferentialEquations (more precisely, just OrdinaryDiffEqLowOrderRK - thanks for splitting that out), and I have done some work to parallelize the computation of the derivative across multiple nodes with MPI, which seems to work relatively well performance-wise.
The most computationally expensive parts of the derivative were already parallelized using Base.Threads.@threads, and we saw that the computation of the derivative was scaling reasonably well with the number of cores.
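For context, the structure is roughly like the sketch below (the `local_term` function is just a toy placeholder, not our actual physics):

```julia
using OrdinaryDiffEqLowOrderRK

# Toy stand-in for the per-point physics (placeholder, not our actual model).
local_term(x, p, t) = -p * x^3

# The expensive loop in the derivative is split across threads; this is the
# part that scales reasonably well with the core count.
function rhs!(du, u, p, t)
    Threads.@threads for i in eachindex(u)
        @inbounds du[i] = local_term(u[i], p, t)
    end
    return nothing
end

prob = ODEProblem(rhs!, rand(4096), (0.0, 1.0), 1.0)
sol = solve(prob, RK4())
```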

The problem is that, beyond a certain level of parallelism, the time spent outside of the derivative computation becomes significant. With the state stored in a StructArray, 23% of the time is spent outside of the derivative computation on one node, and on 2 nodes that grows to 35% of the total - its absolute cost stays more or less constant, as expected if that part is not parallelized with MPI (see this comment).

A study on the number of threads, instead, hints that the part outside the derivative computation does not take advantage of multiple threads - or perhaps it is threaded, but the operations there are so memory-bound that using multiple CPUs does not improve performance at all. It also looks like the cost of this part of the code scales linearly with the size of the state object (see the second part of this comment).

Given these observations, I thought that representing the state of the solver with a distributed array type like a PencilArray (as mentioned here) should do the trick: each rank's solver would work on a smaller, local array, while the results stay correct because the integration steps are performed consistently across MPI ranks.
In short, I expected PencilArrays to give us scalability even where threads do not.
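A minimal sketch of the idea, assuming I remember the PencilArrays.jl constructors correctly (the RHS is again a placeholder; in the real code the halo exchanges/transposes happen inside it):

```julia
using MPI
using PencilArrays
using OrdinaryDiffEqLowOrderRK

MPI.Init()
comm = MPI.COMM_WORLD

# Decompose a global 3D grid across the MPI ranks; each rank owns one "pencil".
pen = Pencil((64, 64, 64), comm)
u0 = PencilArray{Float64}(undef, pen)
fill!(u0, 1.0)

# Placeholder RHS acting only on the locally owned data.
rhs!(du, u, p, t) = (du .= -0.1 .* u; nothing)

prob = ODEProblem(rhs!, u0, (0.0, 1.0))
sol = solve(prob, RK4())
```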

But this does not seem to be the case.
Using 4 nodes, for example, about 30% of the time is spent outside of the derivative computation, while on a single node it is only ~15% (see this comment).

I admit I have not yet profiled the code properly - not because I have not tried, but because the results were relatively hard to read - and I ended up relying on manual instrumentation with TimerOutputs.jl instead, finding these low-tech methods more robust for the time being. I imagine I will have to step up my game soon, though.
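The instrumentation is roughly of this form (the RHS below is a toy placeholder; in the real code the timed region wraps our actual derivative function):

```julia
using TimerOutputs, OrdinaryDiffEqLowOrderRK

const to = TimerOutput()

# Toy RHS standing in for the real derivative computation.
rhs!(du, u, p, t) = (du .= -p .* u; nothing)

# Timing the RHS separately splits "derivative" time from
# "everything else the solver does" inside solve.
timed_rhs!(du, u, p, t) = @timeit to "derivative" rhs!(du, u, p, t)

prob = ODEProblem(timed_rhs!, ones(10_000), (0.0, 1.0), 0.1)
@timeit to "solve" solve(prob, RK4())
print_timer(to)
```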

Is anybody aware of reasons why better scaling is not to be expected, or suggestions for alternative approaches?


It’s hard to say without more information.

I would recommend using a profiler that lets you visualize the parallelism. I like Nsight Systems, see Nsight Systems tips · NVTX.jl for some tips.
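For CPU-only code, something along these lines should already give named ranges in the Nsight Systems timeline (the RHS here is just a placeholder for your derivative function):

```julia
using NVTX

# Functions annotated like this show up as named ranges in the Nsight
# Systems timeline, which makes it easy to see what runs inside vs.
# outside the derivative computation.
NVTX.@annotate function rhs!(du, u, p, t)
    du .= -p .* u
    return nothing
end

# Individual regions can also be wrapped explicitly, e.g.
# NVTX.@range "solver internals" begin
#     ...
# end
```

The script then gets launched under `nsys profile julia ...` (roughly) to record the timeline.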

There are threading choices for the solver parts. But before looking at that, can you share a flamegraph?
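If it helps, something like the following (with your actual problem and solver substituted in; the one below is just a toy) should produce a profile dump that is easy to share:

```julia
using Profile, PProf
using OrdinaryDiffEqLowOrderRK

# Toy problem just for illustration; substitute the real prob / solver.
prob = ODEProblem((du, u, p, t) -> (du .= -u; nothing), ones(1000), (0.0, 1.0))

solve(prob, RK4())             # warm-up, so compilation is not profiled
Profile.clear()
@profile solve(prob, RK4())
pprof(; web = false, out = "profile.pb.gz")  # writes a dump that can be attached here
```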

Here is a dump for PProf:

Note: the code was run without MPI, on a node with 76 physical cores.

With TimerOutputs.jl, I can see that the computation of the derivative (inside of which Threads.@threads is used most of the time) takes 89.1% of the total time of the solve call (100% = 158 seconds). The remaining ~11%, I suppose, is spent doing whatever else the solver needs to do.

Here is a picture of the whole flamegraph:


Some considerations/questions:

  • I recognize the yellowish bars and whatever is under them as part of our code
  • What is on top of them is the whole task machinery, right?
  • There is a very thin “flame” on the far left; if one expands it, one can find the “solve” call there
  • It is my understanding that the length of the bars is proportional to the number of samples collected across all threads found executing that call path. So that “pthread_cond_wait” could be happening partly while the derivative is being calculated and partly while the solver does its own work. Is this a correct interpretation?

To make things even simpler, I will soon post a flame graph from a single-threaded run, in case it is useful.
