Having trouble parallelizing a system of ODEs

Allow me to preface this by saying I have limited programming experience beyond the relatively short time I’ve been using Julia, so if you see I’ve made some kind of coding faux pas, please point it out!

I have a non-linear system of ODEs that I’m trying to simulate. The initial conditions are pre-determined from a chosen Poincaré section, so I thought this would be an ideal scenario for parallelization. However, I’m encountering problems which are probably due to my inexperience, but which I can’t seem to figure out from the documentation. There are multiple callbacks, and the means of generating the initial conditions are a little involved, so I wrapped everything before the definition of the ensemble problem in an @everywhere block:

using Distributed
addprocs(4)
@everywhere begin
using DifferentialEquations
function eom(du, u, p, t)
    μ, k = p
    x, y, xd, yd, Ev = u
    r1(u,p) = √((x+μ)^2+y^2)
    r2(u,p) = √((x-1+μ)^2+y^2)
    du[1] = dx = xd
    du[2] = dy = yd
    du[3] = dxd = (2*yd + x - k/r1(u,p)^2 * (xd-y) - ((1-μ)/r1(u,p)^3 * (x+1-μ) +
            μ/r2(u,p)^3 * (x-1+μ)))
    du[4] = dyd = (-2*xd + y - k/r1(u,p)^2 * (yd+x) - ((1-μ)/r1(u,p)^3 +
            μ/r2(u,p)^3)*y)
    du[5] = -k/r1(u,p)*(xd^2-xd*y+x*yd+yd^2)
    return nothing
end

function eom_jac(J,u,p,t)
  μ, k = p
  x, y, xd, yd = u
  r1(u,p) = √((x+μ)^2+y^2)
  r2(u,p) = √((x-1+μ)^2+y^2)
  J[1,:] = [0 0 1 0 0]
  J[2,:] = [0 0 0 1 0]
  J[3,1] = (1 - (1-μ)/r1(u,p)^3 + 3(x+μ)*(x-1+μ)*(1-μ)/r1(u,p)^5 - μ/r2(u,p)^3 +
   3μ*(x+1-μ)^2 / r2(u,p)^5 + 2k*(x+μ)*(xd-y)/r1(u,p)^4)
  J[3,2] = 3(x-1+μ)*(1-μ)*y/r1(u,p)^5 + 3(x+1-μ)*μ*y/r2(u,p)^5 + 2k*y*(xd-y)/r1(u,p)^4
  J[3,3] = -k/r1(u,p)^4
  J[3,4] = 2
  J[3,5] = 0
  J[4,1] = 3(1-μ)*(x+μ)*y/r1(u,p)^5 + 3*(x-1+μ)*μ*y/r2(u,p)^5 + 2k*(x+μ)*(xd-y)/r1(u,p)^4
  J[4,2] = (1 + 3(1-μ)*y^2/r1(u,p)^5 + 3μ*y^2/r2(u,p)^5 - (1-μ)/r1(u,p)^3 -
  μ/r2(u,p)^3 + 2k*y*(xd-y)/r1(u,p)^4)
  J[4,3] = -2 -k/r1(u,p)^4
  J[4,4] = 0
  J[4,5] = 0
  J[5,1] = 4k*(x+μ)/r1(u,p)^3 * (xd^2+yd^2-xd*y+x*yd)-k/r1(u,p)^2 * yd
  J[5,2] = 4k*y/r1(u,p)^3 * (xd^2+yd^2-xd*y+x*yd) + k/r1(u,p)^2 * xd
  J[5,3] = -k/r1(u,p)^2 * (2xd-y)
  J[5,4] = -k/r1(u,p)^2 * (2yd+x)
  J[5,5] = 0

  nothing
end
#
ff = ODEFunction(eom; jac = eom_jac)


p = [0.5,0.0] #parameters [μ, k]
#Primary locations
P1 = [-p[1];0]
P2 = [1-p[1];0]
#System Radii
R1 = 1e-4
R2 = 1e-4
Rsys = 10.0
tspan = (0.0,1000.0)

## Event Handling
import LinearAlgebra.norm

function condition(out,u,t,integrator)
  out[1] = sqrt(u[1]^2+u[2]^2)-Rsys #particle escapes
  out[2] = R1-norm(u[1:2]-P1) #particle collides with the sun
  out[3] = R2-norm(u[1:2]-P2) #particle collides with the planet
end

function affect!(integrator, event_index)
  terminate!(integrator)
end

function manifold(resid,u,p,t) #when dissipation is off, system is Hamiltonian
  resid[1] = 0.5*(u[3]^2+u[4]^2) + V(u[1],u[2],p[1]) - E
  resid[2] = 0
  resid[3] = 0
  resid[4] = 0
  resid[5] = 0
end

cbv = VectorContinuousCallback(condition,affect!,3)
cbm = ManifoldProjection(manifold)
cbs = CallbackSet(cbv,cbm)

## initial conditions
#generate an array of initial conditions from a Poincare section with
#E = -1.375, rdot = 0, phidot < 0

x = range(-2.0,2.0,length = 1000)
y = range(-2.0,2.0,length = 1000)
E = -1.375
r1(x,y,μ) = √((x+μ)^2+y^2)
r2(x,y,μ) = √((x-1+μ)^2+y^2)
R(x,y) = √(x^2 + y^2)
V(x,y,μ) = -μ/r1(x,y,μ) - (1-μ)/r2(x,y,μ) - 1/2 * (x^2+y^2) #Jacobi Potential
g(x,y,μ) = √(-2*V(x,y,μ)+2*E)
xd0(x,y,μ) = y/R(x,y)*g(x,y,μ) #x initial velocity
yd0(x,y,μ) = -x/R(x,y)*g(x,y,μ) #y initial velocity
xy = [[x[i] y[j] xd0(x[i],y[j],p[1]) yd0(x[i],y[j],p[1]) E] for i=1:length(x), j=1:length(y)]



## Parallelize!

prob = ODEProblem(ff,xy[1],tspan,p) #base problem
# solver_args = (abstol=1e-8,reltol=1e-8)
# @time sol = solve(prob,Vern7(),callback=cbs;solver_args...)

function prob_func(prob,i,repeat)
  ODEProblem(prob.f,xy[i],prob.tspan,prob.p)
end

end #end @everywhere
EnsProb = EnsembleProblem(prob,prob_func=prob_func)


sim = solve(EnsProb,Vern7(),EnsembleDistributed(),trajectories=length(xy), callback = cbs, abstol = 1e-8, reltol = 1e-8)

When I format the code to solve serially (i.e., constructing a new ODEProblem for each initial condition and solving it in a loop), the code runs fine, but it’s obviously a very unsatisfying way to go about it. When I format it as an EnsembleProblem, the code either causes the computer to hang (99% CPU and memory usage) or takes an extremely long time (longer than it would take single-threaded). I’ve tried using both EnsembleThreads() and EnsembleDistributed(), and both have similar issues.
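For reference, the serial version I mean is roughly this (just a sketch, reusing the same ff, xy, tspan, p, and cbs defined above):

#Serial baseline: build and solve a fresh problem for each initial condition.
sols = Vector{Any}(undef, length(xy))
for i in eachindex(xy)
  prob_i = ODEProblem(ff, xy[i], tspan, p)
  sols[i] = solve(prob_i, Vern7(), callback = cbs, abstol = 1e-8, reltol = 1e-8)
end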

What I’ve come to ask is: is there any obvious mistake I’m making here?

Additionally, it would be nice if it were possible to write this in a way that can use DiffEqGPU, but I’m encountering many errors when exploring this, so any insight or direction there would also be appreciated.

You’re saving everything. Like everything. My guess is that you’re running out of memory. You might want to add saveat=0.1 or something like that to solve, because if you’re hitting memory that hard then parallelism won’t help much. Try reducing the allocations like that first and see how well you do.
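For example, something like this on your ensemble solve (just a sketch reusing your call from above):

# Only save the solution every 0.1 time units instead of at every step,
# which cuts the memory used per trajectory dramatically.
sim = solve(EnsProb, Vern7(), EnsembleDistributed(), trajectories = length(xy),
            callback = cbs, abstol = 1e-8, reltol = 1e-8, saveat = 0.1)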

Note that defining functions like this inside of your dynamics won’t be compatible with DiffEqGPU. We’ll be updating it to KernelAbstractions soon, so it’ll throw better errors, but essentially you need to make your function look very much like the README, i.e.:

function lorenz(du,u,p,t)
  @inbounds begin
    du[1] = p[1]*(u[2]-u[1])
    du[2] = u[1]*(p[2]-u[3]) - u[2]
    du[3] = u[1]*u[2] - p[3]*u[3]
  end
  nothing
end

so you’ll want to do things like μ = p[1]; k = p[2] and so on. It might be easiest to just modelingtoolkitize(prob) and regenerate the function, as that will perform the translation for you.
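For instance, a sketch of your eom in that style (same right-hand side, just with the inner r1/r2 closures inlined; the eom_gpu name is only for illustration, so double-check it against your original):

function eom_gpu(du, u, p, t)
  @inbounds begin
    μ = p[1]; k = p[2]
    x = u[1]; y = u[2]; xd = u[3]; yd = u[4]
    r1 = sqrt((x+μ)^2 + y^2)     # distance to the first primary
    r2 = sqrt((x-1+μ)^2 + y^2)   # distance to the second primary
    du[1] = xd
    du[2] = yd
    du[3] = 2*yd + x - k/r1^2 * (xd-y) - ((1-μ)/r1^3 * (x+1-μ) + μ/r2^3 * (x-1+μ))
    du[4] = -2*xd + y - k/r1^2 * (yd+x) - ((1-μ)/r1^3 + μ/r2^3)*y
    du[5] = -k/r1 * (xd^2 - xd*y + x*yd + yd^2)
  end
  nothing
end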

Though one other thing: how necessary is the ManifoldProjection? That’s going to allocate quite a bit here since it uses NLsolve, and there’s a chance that it is your rate-limiting step. It would be helpful if you could share a profile of the serial code to know what’s taking the most time. An overview of how to do such profiling can be found here.
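For example, something along these lines on the serial problem (just a sketch using the Profile standard library):

using Profile
solve(prob, Vern7(), callback = cbs, abstol = 1e-8, reltol = 1e-8) # run once first to compile
@profile solve(prob, Vern7(), callback = cbs, abstol = 1e-8, reltol = 1e-8)
Profile.print() # or view the result graphically with something like ProfileView.jl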
