DifferentialEquations.jl+MPI.jl+PencilArrays.jl: Lack of scaling observed

I am working on a code that uses DifferentialEquations (more precisely, just OrdinaryDiffEqLowOrderRK - thanks for splitting that out). I have done some work to parallelize the computation of the derivative across multiple nodes with MPI, and that seems to work relatively well performance-wise.
The most computationally expensive parts of the derivative were already parallelized using Base.Threads.@threads, and we saw that the computation of the derivative was scaling reasonably well with the number of cores.

The problem is that, at a certain level of parallelism, the time spent outside of the derivative computation becomes significant. With the state stored in a StructArray, 23% of the time on one node is spent outside of the derivative computation, and on 2 nodes that fraction becomes 35% of the total - the absolute time staying more or less constant, as expected if that part is not parallelized with MPI (see this comment).

A study varying the number of threads instead hints that the part outside the derivative computation does not take advantage of multiple threads. Or perhaps it is threaded, but the operations there are so memory-bound that using multiple CPUs does not improve performance at all. It also looks like the time spent in this part of the code scales linearly with the size of the state object (see the second part of this comment).

Given these assumptions, I thought that representing the state of the solver with a distributed array type like a PencilArray (as mentioned here) should do the trick: each rank's solver works on a smaller local array, while the results stay correct because the integration steps are performed consistently across MPI ranks.
So, I thought, using PencilArrays should finally give us scalability, even if threads might not.
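
To make the idea concrete, here is a minimal sketch of the kind of setup I have in mind (rhs! and its body are placeholders - the real problem is built inside SolveFRG - and I'm assuming the sub-library re-exports the usual ODEProblem/solve interface):

using MPI, PencilArrays, OrdinaryDiffEqLowOrderRK

MPI.Init()
comm = MPI.COMM_WORLD

# Decompose the global 40^3 state across the MPI ranks of `comm`:
# each rank's PencilArray only holds its local part.
pen = Pencil((40, 40, 40), comm)
u0  = PencilArray{Float64}(undef, pen)
fill!(u0, 0.0)

function rhs!(du, u, p, t)
    # ... MPI-parallel derivative acting on the local data ...
    du .= -u   # placeholder
    return nothing
end

# The solver's own broadcasts then only touch ~1/nranks of the data on each rank.
prob = ODEProblem(rhs!, u0, (0.0, 1.0))
sol  = solve(prob, DP5())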

But this does not seem to be the case.
Using 4 nodes, for example, 30% of the time is spent outside of the derivative computation, while on a single node it is only ~15% (see this comment).

I admit I have not yet profiled the code properly - not because I have not tried, but because the results were relatively hard to read, so I eventually started relying on manual instrumentation with TimerOutputs.jl, finding that low-tech approach more robust for the time being. I imagine I will have to step up my game soon, though.
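
For completeness, the kind of manual instrumentation I mean is simply this (a sketch; the section name and compute_derivative! are made up):

using TimerOutputs

function rhs!(du, u, p, t)
    # everything inside this section counts as "derivative computation"
    @timeit "derivative" compute_derivative!(du, u, p)
    return nothing
end

reset_timer!()      # reset the default global timer
# ... run solve(...) ...
print_timer()       # time per labelled section; the remainder is the "solver overhead"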

Is anybody aware of reasons why better scaling is not to be expected, or suggestions for alternative approaches?


It’s hard to say without more information.

I would recommend using a profiler that lets you visualize the parallelism. I like Nsight Systems, see Nsight Systems tips · NVTX.jl for some tips.
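
For example, ranges can be annotated with NVTX.jl so that they show up on the Nsight Systems timeline; a rough sketch (function names are placeholders):

using NVTX

function rhs!(du, u, p, t)
    NVTX.@range "derivative" begin
        compute_derivative!(du, u, p)   # the expensive, threaded part
    end
    NVTX.@range "rest of rhs" begin
        postprocess!(du, p)             # placeholder for the cheap bookkeeping
    end
    return nothing
end

# then run e.g.: nsys profile --trace=nvtx julia --threads=76 script.jl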

There are threading choices for the solver parts. But before looking at that, can you share a flamegraph?

Here is a dump for PProf:

I am able to visualize it using the pprof executable bundled with pprof_jll version 1.0.1 (PProf.jl version 3.1.3).

Note: the code was run without MPI, on a node with 76 physical cores.

With TimerOutputs.jl, I can see that the computation of the derivative (inside of which Threads.@spawn is used, most of the time) takes 89.1% of the total time of the solve call (100% = 158 seconds). So the remaining ~11%, I suppose, is spent doing whatever else the solver needs to do.

Here is a picture of the whole flamegraph:


Some considerations/questions:

  • I recognize the yellowish bars and whatever is under them as part of our code
  • What is on top of them is the whole task machinery, right?
  • There is a very thin “flame” on the far left; if one expands it, one finds the “solve” call there
  • It is my understanding that the length of the bars is proportional to the number of samples collected across all threads found executing that call path. So that “pthread_cond_wait” could be happening partially while the derivative is being calculated, and partially while the solver does its thing. Is this a correct interpretation?

To make things even simpler, I will soon post a flamegraph from a single-threaded run, in case it is useful.

Edited:

  • fixed name of linked file
  • added version of PProf that can be used to open linked file
  • while @threads is used in most places in the code, in the most computationally expensive places @spawn is actually used.

Single-threaded version, no MPI. Here is the PProf dump: PMFRG-N40-1thread.profile.pb.gz - Google Drive

With TimerOutputs.jl, I can see that the computation of the derivative (inside of which Threads.@threads is used, most of the time) takes 99.8% of the total time of the solve call (100% = 8892 seconds). In this case, the “solver overhead” is completely negligible.

Picture of the flamegraph:

The considerations and questions from the previous post apply here as well.

I agree I need to profile/trace the execution in parallel - I will do it in the coming days.

Yes, so it looks like your solve time is basically zero and all of the time is in addX!, so the lack of scaling is outside of the solve, and you need to figure out why addX! doesn’t scale.

But that’s in direct contrast to what the time measurements say. AddX! is just one part of the derivative computation, and the times measured with @timeit seem to scale reasonably - although not perfectly - with the number of threads.

Did you accidentally measure compile time in one of the measurements?

Is this computation asynchronous, and when you’re measuring did you force it to sync appropriately in the timing? Otherwise it would only measure the time to spawn the kernel, not to execute it completely.

I run a smaller example before profiling and measuring times to make sure compilation has already happened (see code below), then reset the timer and initialize the profile.

The computation of the derivative uses, down the call tree, just Threads.@spawn to parallelize the most expensive loops, in which AddX! is called.

The times are measured with, for example

@timeit_debug "getXBubble!" getXBubble!(...)

and the most important @sync / Threads.@spawn block is inside that getXBubble! function.
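
Schematically (partition_work and addX_chunk! are made-up names standing in for the real work-splitting and kernels):

using Base.Threads: @spawn

function getXBubble!(X, state, par)
    @sync for chunk in partition_work(par)
        @spawn addX_chunk!(X, state, chunk)   # calls AddX! internally
    end
    # @sync only returns once every spawned task has finished, so a timer around
    # the getXBubble! call measures the full computation, not just the spawning.
    return X
end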

High-level view of the code, removing the uninteresting parts:

# Warm up
Par = Params(...) # a toy problem which differs from the real one only in size

_ = SolveFRG(
    Par,
    MultiThreaded(), # not using MPI
    [...]
    method=DP5(),  # integration method
    CheckPointSteps=3,
);

# Warmup and compilation done, timing real problem now
Par = Params(...) # A bigger problem used for the benchmark

reset_timer!() # Resetting the timer from TimerOutputs.jl

Profile.init(n = 10^8, delay = 0.005) # to be able to save all the samples

@profile _ = SolveFRG(
    Par,
    MultiThreaded(), # still no mpi
    [...]
    method=DP5(),
    CheckPointSteps=3,
);

print_timer() # from TimerOutputs.jl 

pprof(web=false,out="profile.pb.gz") # save to file

Note: part of what I said before was incorrect. The most expensive part of the computation is not parallelized with @threads, but with @spawn.

In your profiling / timing, did you time after the sync, to ensure the spawned threads had finished? If I had to guess at the difference: your timing only measures the time to spawn the threads in addX!, while in the solve, addX! spawns the threads, waits for them to finish, and only then do the few solver things happen. That would explain why all of the time shows up in addX! when profiling the solver, since profiling the solver naturally captures the full compute time - it has to sync before the next broadcast.
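
In other words, the pitfall being described is roughly this (expensive_kernel! is a hypothetical stand-in):

using Base.Threads: @spawn

# Only measures the time needed to *launch* the task:
t_wrong = @elapsed (task = @spawn expensive_kernel!())

# Measures the full computation, because @sync waits for the spawned task:
t_right = @elapsed @sync begin
    @spawn expensive_kernel!()
end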

Sorry for being slow: I do not use @async anywhere. The Threads.@spawn block is tightly enclosed in a @sync block, and @timeit is around that.

Trying to rephrase what I understand from your suggestion: should I not trust what @timeit tells me?

There are for sure some inefficiencies in the way the computation of the derivative is performed - I do not see perfect scaling there with the number of threads. Also, the horrendous amount of time spent idling could not possibly be explained only by solve being single-threaded.

Whatever solve does, it is done on a state vector of (N=40)^3 double-precision values. That is a tiny amount of data (~512 KB); I imagine the solve mostly does daxpy-like operations, and I wonder whether daxpy on a few hundred KB can be multithreaded efficiently across so many cores at all (it should be possible, though?).
Computations in the derivative instead scale with N^4, so it is quite a bit easier to saturate single-core bottlenecks with it, perhaps?
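
One way to check that directly is a micro-benchmark at the solver's working-set size; a sketch (40^3 Float64 ≈ 512 KB, nothing problem-specific here):

using BenchmarkTools
using Base.Threads: @threads

N = 40^3
x, y = rand(N), rand(N)
a = 2.0

axpy_serial!(y, a, x) = (y .+= a .* x; y)

function axpy_threaded!(y, a, x)
    @threads for i in eachindex(x, y)
        @inbounds y[i] += a * x[i]
    end
    return y
end

@btime axpy_serial!($y, $a, $x)
@btime axpy_threaded!($y, $a, $x)   # compare runs with different JULIA_NUM_THREADS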

DP5(thread = OrdinaryDiffEq.PolyesterThreads()) should do it, or DP5(thread = OrdinaryDiffEq.True()) depending on the type of threading you want. If you’re already using Base threading instead of Polyester then that should be sufficient.

But yes, the solver is just doing a few O(n) broadcasts.
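
For reference, the call pattern is just the following (with a stand-in problem; the real one is built inside SolveFRG):

using OrdinaryDiffEq

prob = ODEProblem((du, u, p, t) -> (du .= -u), ones(10), (0.0, 1.0))   # stand-in problem

solve(prob, DP5(thread = OrdinaryDiffEq.True()))               # threaded solver broadcasts
solve(prob, DP5(thread = OrdinaryDiffEq.PolyesterThreads()))   # Polyester-based variant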

Actually, DP5(thread=OrdinaryDiffEq.True()) is something we tried in the past, but at the time it did not make a difference. Today it did.
When we tried this the first time we were using an ArrayPartition to store the state, while now we have a PencilArray.

Can someone confirm or deny, as a cross-check, that for ArrayPartitions the broadcasts are not multithreaded? (And that they are for PencilArrays?)

Broadcasts in PencilArrays are generally not multithreaded, but I’d assume that FastBroadcast.jl is able to parallelise them by converting them into loops.

Note that PencilArrays.jl itself does nothing for you in terms of multithreading, as it was created with the idea of having one “thread” per MPI process. Hybrid MPI-threads parallelism should work, it’s just that you need to take care of the threads part yourself.
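
If one wanted to test that assumption, the FastBroadcast form would look roughly like this (whether @.. specialises well on PencilArrays is exactly what would need checking):

using FastBroadcast

function axpy_fb!(du, u, a)
    # @.. lowers the broadcast to an explicit loop; thread=true asks it to thread that loop
    @.. thread=true du = du + a * u
    return du
end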

In the past I spent some time trying to understand whether multithreading would work on PencilArrays through FastBroadcast, BUT what we really wanted to achieve with PencilArrays was to have the solver logic scale with the number of MPI ranks. The possibly naive idea was that, even if multithreading did not help with the broadcasts in the solver, throwing more nodes at the problem would still speed up that part - inefficiently, but roughly by a factor equal to the number of nodes (since each rank's piece of the integrator state is smaller by that factor).

The recent experiments suggest the following:

  • Using DP5 with threading seems to almost completely remove the “solver overhead” (by which I mean the time measured with TimerOutputs inside the solve call but outside of the derivative calculation). Further optimizations will have to target the derivative computation, which is fine because that is our code.
  • For larger problems we use VCABM instead, and for that solver the argument thread=OrdinaryDiffEq.True() does not work. Moreover, it also seems that using PencilArrays to reduce the size of the problem for each MPI rank does not reduce the “solver overhead”, which stays constant with the number of nodes we use. I am puzzled by this, possibly only because I am utterly ignorant of what VCABM does exactly, except that (from the documentation):

[…] the VCABM method can be a good choice for high accuracy when the system of equations is very large (>1,000 ODEs?), the function calculation is very expensive […]

In light of this discussion - which was very useful anyway, since at least now we can use DP5 efficiently with threads and on multiple processes - I think I need to rephrase/modify the original question:

  • Are there obvious reasons why, with VCABM, we do not see scaling in the “solve overhead” with multiple nodes when we use PencilArrays? I’d assume that even without using threading we would see a speedup just by the fact that the state in the integrator is smaller by a factor of nranks, but we see no speedup whatsoever (while there is a speedup in the derivative calculation).

I will try to produce a flamegraph for the VCABM case (everything I’ve posted before concerns DP5) and post it here, and try to trace the execution (I first need to learn how to use Nsight and set it up).

It would be pretty trivial to add. The change to the stepper is just to pass thread to the macro

Add it to the alg

and the cache

(though honestly, just using integrator.alg.thread in the perform_step is fine)

and boom, it’s done. For VCABM you’d need to do the same in the corresponding place.
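
For concreteness, a very rough sketch of the shape of that change - not actual OrdinaryDiffEq source, and the supertype/defaults here are guesses:

# Give the algorithm a `thread` field, like the threaded explicit RK methods have:
struct VCABM{Thread}   # (the real VCABM subtypes an OrdinaryDiffEq algorithm abstract type)
    thread::Thread
end
VCABM(; thread = OrdinaryDiffEq.False()) = VCABM(thread)

# Then, in perform_step! (and wherever the cache does large broadcasts), switch to the
# threaded FastBroadcast form, e.g.:
#   @.. thread = integrator.alg.thread u = uprev + dt * update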


Chris, can this have something to do with Hermite interpolation of the solution?

I think DP5 has a free interpolant, but I am not sure about VCABM.
The code makes use of saving callbacks, so if the interpolation overhead is single-threaded, this could be related to the issue.
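
For reference, the saving-callback pattern in question looks roughly like this (assuming DiffEqCallbacks.SavingCallback; the observable and problem are stand-ins):

using OrdinaryDiffEq, DiffEqCallbacks

prob = ODEProblem((du, u, p, t) -> (du .= -u), ones(10), (0.0, 1.0))   # stand-in problem

saved = SavedValues(Float64, Float64)
cb = SavingCallback((u, t, integrator) -> sum(abs2, u), saved;
                    saveat = 0.0:0.1:1.0)   # as far as I understand, values at times inside a step come from the interpolant

sol = solve(prob, VCABM(); callback = cb)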

What part is the overhead? Can you show a flame graph?