I’m using DifferentialEquations.jl to solve a very large (~10^9 equations) system of coupled ODEs. I’m using MPI because the memory required for the full system won’t fit on a single workstation. The problem maps pretty naturally onto the integrator interface, and everything works well until the system starts to get stiff.
In general the MPI workers don’t automatically share the same dt, so they tend to get out of step with each other. That much can be handled by manually setting the dt outside of the step!() function. The real problem shows up when some but not all of the MPI workers take a successful step: the workers that failed the error check circle back to retry with a smaller step, while the workers that accepted the step move on with the code. Once the workers are out of step with each other, the matched sends and receives inside the RHS function no longer pair up and the code just hangs forever.
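For concreteness, the dt syncing I mean looks something like this, a sketch using get_proposed_dt/set_proposed_dt! from the integrator interface (in the full example below I mutate integrator.dt and integrator.dtpropose directly, which I believe amounts to the same thing):

dt_global = MPI.Allreduce(get_proposed_dt(integrator), min, comm) # every rank agrees on the smallest proposed step
set_proposed_dt!(integrator, dt_global)
step!(integrator)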
I’ve put together a fairly minimal working example to demonstrate the problem: two coupled oscillators, with one MPI process per oscillator. I wanted the test problem to be small enough to run on any computer with at least two cores. The serial version runs for 1000 steps with no problem. The MPI version goes about 700 steps and then hangs. From the terminal output I can see one worker reporting that it is done with the step!() function while the other is calling mpi_coupled_oscillators!() again, which suggests that it is retrying the step instead of exiting step!().
So I think the problem is that the MPI workers are tracking error/time steps/step success independently, but I’m not sure where the best place to change that is. Any and all ideas are greatly appreciated!
using DifferentialEquations
p = (1,1e8,1e-1) # (spring constant for oscillator 1, spring constant for oscillator 2, coupling term)
# test system in serial, to make sure the problem is solvable with the chosen ODE solver
function coupled_oscillators!(du,u,p,t)
    du[1] = u[2]
    du[2] = -(p[1] + p[3]) * u[1] + p[3]*u[3]
    du[3] = u[4]
    du[4] = -(p[2] + p[3]) * u[3] + p[3]*u[1]
end
# set up the problem
u0 = [1.0; 0.0; -0.5; -0.5]
tspan = (0.0,100.0)
prob = ODEProblem(coupled_oscillators!,u0,tspan,p)
# define single step integrator
integrator = init(prob,Tsit5(), save_everystep = false, abstol = 1e-8, reltol = 1e-8)
# record the results to a text file, since MPI precludes running interactively as far as I can tell
# creates a text file with 5 columns (time, x1, v1, x2, v2)
io_all = open("all.txt","w")
# run the problem for 1000 steps and record values after every step
for i = 1:1000
    step!(integrator) # take one step, then jump out to record data
    println(io_all, string(integrator.t)*" "*
                    string(integrator.u[1])*" "*
                    string(integrator.u[2])*" "*
                    string(integrator.u[3])*" "*
                    string(integrator.u[4]))
end
close(io_all) # flush and close the output file
####### mpi version
using MPI
MPI.Init()
using DifferentialEquations
const comm = MPI.COMM_WORLD
const mpi_size = MPI.Comm_size(comm)
const mpi_rank = MPI.Comm_rank(comm)
function mpi_coupled_oscillators!(du,u,p,t)
    println("worker "*string(mpi_rank)*" t is "*string(t)) # report whenever a worker enters the RHS function; diverging t values show the ranks losing sync
    # have the two workers exchange state; with two ranks the partner is just the other rank
    other = 1 - mpi_rank
    u_other = similar(u) # buffer for the other worker's state
    rreq = MPI.Irecv!(u_other, other, other, comm) # messages are tagged by sender rank
    sreq = MPI.Isend(u, other, mpi_rank, comm)
    MPI.Waitall!([rreq, sreq]) # the nonblocking calls must complete before u_other is read
    # update the du terms; MPI version of the serial problem setup
    if mpi_rank == 0
        du[1] = u[2]
        du[2] = -(p[1] + p[3]) * u[1] + p[3]*u_other[1]
    else
        du[1] = u[2]
        du[2] = -(p[2] + p[3]) * u[1] + p[3]*u_other[1]
    end
end
# set up the problem with the same parameters and initial conditions as the serial version
p = (1,1e8,1e-1)
tspan = (0.0,100.0)
if mpi_rank == 0
    mpi_u0 = [1.0; 0.0]
else
    mpi_u0 = [-0.5; -0.5]
end
prob = ODEProblem(mpi_coupled_oscillators!,mpi_u0,tspan,p)
integrator = init(prob,Tsit5(), save_everystep = false, abstol = 1e-8, reltol = 1e-8) # parallel version of single step integrator
# for the MPI output files, each rank writes 3 columns (t, x, v) for its own oscillator
io_out = open("part_"*string(mpi_rank)*".txt","w") # create/truncate the file; it is reopened in append mode after every step
close(io_out)
# try to run for the same number of steps as the serial version
for i = 1:1000
    # force the solver to take the smaller of the two proposed steps, to keep the ranks aligned with each other
    integrator.dt = MPI.Allreduce(integrator.dt, min, comm)
    integrator.dtpropose = MPI.Allreduce(integrator.dtpropose, min, comm)
    step!(integrator) # take one step, then jump out to record data
    println("Step "*string(i)*" worker "*string(mpi_rank)*" ready") # each rank reports when it has successfully taken a step and is ready to move on
    io_out = open("part_"*string(mpi_rank)*".txt","a")
    println(io_out, string(integrator.t)*" "*
                    string(integrator.u[1])*" "*
                    string(integrator.u[2]))
    close(io_out)
end
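One idea I’ve been wondering about but haven’t tried to implement: if every rank computed the same error norm, they should all make the same accept/reject decision. The internalnorm(u,t) common solver option looks like the right hook for that, so maybe something like the sketch below, assuming internalnorm gets called the same number of times on every rank (if it isn’t, the Allreduce calls themselves would deadlock). Is this the right place to intervene?

# sketch: an MPI-aware error norm, so all ranks agree on step acceptance
# each rank contributes its local sum of squares, and the Allreduce makes the
# resulting RMS norm identical on every rank
function mpi_norm(u, t)
    global_sumabs2 = MPI.Allreduce(sum(abs2, u), +, comm)
    global_len = MPI.Allreduce(length(u), +, comm)
    return sqrt(global_sumabs2 / global_len)
end
integrator = init(prob, Tsit5(), save_everystep = false,
                  abstol = 1e-8, reltol = 1e-8,
                  internalnorm = mpi_norm)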