Bad performance using DistributedArrays SPMD for a 1D diffusion problem

Hi all,

I am writing my first parallel program in Julia 1.1 using the SPMD mode of the DistributedArrays package. However, I am nowhere near the serial performance, and the parallel scaling is not good either.

The problem is a simple 1D diffusion equation with Dirichlet boundary conditions: 1 on the left and 0 on the right. The solution starts from all zeros.
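
For reference, the update applied to the interior points is the standard explicit scheme (forward Euler in time, central differences in space), which is what advance_serial! and advance! below compute, with the two boundary values held fixed:

$$
f_i^{\,n+1} = f_i^{\,n} + \frac{\alpha\,\Delta t}{\Delta x^2}\left(f_{i+1}^{\,n} - 2 f_i^{\,n} + f_{i-1}^{\,n}\right)
$$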

Source code:

using Distributed, PyPlot
addprocs(1)
@everywhere using DistributedArrays, DistributedArrays.SPMD

@everywhere struct Param
   m::Int
   dx::Float64
   dx²::Float64
   Invd²::Float64
   dt::Float64
   α::Float64
   nStep::Int
   DoParallel::Bool
end

@everywhere struct Grid
   x::StepRangeLen
end

"Serial version"
function advance_serial!(f,param::Param)
   m, dx², Invd², dt, α = param.m, param.dx², param.Invd², param.dt, param.α

   fnew = copy(f)

   for i in 2:m
      fnew[i] = f[i] + dt*α*(f[i+1] + f[i-1] - 2.0*f[i])*Invd²
   end

   f .= fnew
end


@everywhere begin
   "Parallel version"
   function advance!(f,param::Param)
      pids = sort(vec(procs(f)))

      m, dx², Invd², dt, α = param.m, param.dx², param.Invd², param.dt, param.α

      idx_ = localindices(f)[1]

      fnew = copy(f[:L])   # working copy of this process's local part of the DArray

      if nworkers() ≥ 2
         if 2 < myid() < nworkers()+1   # interior chunks: update every local point
            for (il,i) in enumerate(idx_)
               fnew[il] = f[i] + dt*α*(f[i+1] + f[i-1] - 2.0*f[i])*Invd²
            end
         elseif myid() == 2   # leftmost chunk: keep the Dirichlet value at the left boundary
            fnew[1] = f[idx_[1]]
            for (il,i) in enumerate(idx_[1]+1:idx_[end])
               fnew[il+1] = f[i] + dt*α*(f[i+1] + f[i-1] - 2.0*f[i])*Invd²
            end
         elseif myid() == nworkers()+1   # rightmost chunk: keep the Dirichlet value at the right boundary
            fnew[end] = f[idx_[end]]
            for (il,i) in enumerate(idx_[1]:idx_[end]-1)
               fnew[il] = f[i] + dt*α*(f[i+1] + f[i-1] - 2.0*f[i])*Invd²
            end
         end
      else # only 1 worker
         fnew[1] = f[idx_[1]]
         fnew[end] = f[idx_[end]]
         for i in idx_[2]:idx_[end]-1
            fnew[i] = f[i] + dt*α*(f[i+1] + f[i-1] - 2.0*f[i])*Invd²
         end
      end

      f_local = f[:L]
      f_local .= fnew   # write the updated values back into the DArray's local part

      barrier(;pids=pids)
   end
end

function setBCs!(f,f₀,param::Param)

end

function plotSol(grid::Grid,f,param::Param)
   x = grid.x
   figure()
   plot(x,f,linestyle="-",marker="o",label="FD")
   #xlim(0,40)
   xlabel("x")
   ylabel("Temp")
end

function setParameters(DoParallel::Bool=true)
   m    = 1000
   dx   = 1.0
   dx²  = dx^2
   Invd²= 1.0/dx²
   dt   = 0.20
   α    = 1.0
   nStep= 2000

   param = Param(m,dx,dx²,Invd²,dt,α,nStep,DoParallel)

   x = 0.0:dx:m*dx
   grid = Grid(x)

   return param,grid
end

function init(param::Param)
   m = param.m

   if param.DoParallel
      f₀ = DArray((m+1,1), workers(), [nworkers(),1]) do I
         fill(0.0,(map(length,I)...,))
      end
      fetch(@spawnat sort(workers())[1] localpart(f₀)[1] = 1.0)   # left Dirichlet value, set on the first worker
   else
      f₀ = zeros(m+1)
      f₀[1] = 1.0
   end

   return f₀
end

"Main function for the parallel version using distributed arrays."
function main_parallel()

   param, grid = setParameters()
   f = init(param)

   iStep, time = 0, 0.0
   # Main loop
   runtime = @elapsed while iStep < param.nStep
      spmd(advance!, f, param, pids=workers())
      #setBCs!(f,param)
      time  += param.dt
      iStep += 1
   end

   println("Time elapsed parallel= $(runtime)s")

   plotSol(grid,f,param)

end

"Main function for the serial version."
function main_serial()

   param, grid = setParameters(false)
   f = init(param)

   iStep, time = 0, 0.0
   # Main loop
   runtime = @elapsed while iStep < param.nStep
      advance_serial!(f, param)
      #setBCs!(f,param)
      time  += param.dt
      iStep += 1
   end

   println("Time elapsed serial= $(runtime)s")

   plotSol(grid,f,param)

end

################
main_parallel() # warm up

main_parallel()
main_serial()

On my laptop the timings are:

Time elapsed parallel= 32.924770178s
Time elapsed parallel= 29.179283269s
Time elapsed serial= 0.091432367s

Note that the parallel timings above use only one worker process. My first guess for the bad performance is that calling spmd once per time step causes significant overhead. Any tips or suggestions on how to optimize the code?
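
One idea I have not benchmarked yet is to call spmd only once and move the time loop inside the SPMD function, so each worker loops over its own part and only synchronizes at the barrier inside advance!. A rough, untested sketch (the helper would be defined alongside advance!, and the spmd call replaces the while loop in main_parallel):

@everywhere function run!(f, param::Param)
   for _ in 1:param.nStep
      advance!(f, param)   # advance! already ends with a barrier, keeping the workers in lockstep
   end
end

# one spmd launch for the whole run instead of one per time step
spmd(run!, f, param, pids=workers())

Would that be expected to help, or is the overhead dominated by something else?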

Distributed.jl is known to be pretty slow for workloads with significant communication, and you’re also doing a lot of redundant work in the inner loop (every iteration recomputes the constant factor dt*α*Invd², and every call copies the local part into a fresh array). For a 1D diffusion problem, you shouldn’t need to run on more than one machine, and SharedArrays are much more performant than DArrays in that case. If you need to go bigger, MPI.jl is a reliable fallback.
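
On the redundant-work point, even the serial kernel can be tightened by hoisting that constant and reusing a preallocated buffer instead of copying f every step. A quick, untested sketch with the same numerics:

function advance_serial!(f, fnew, param::Param)
   c = param.dt*param.α*param.Invd²   # constant factor, computed once per call
   @inbounds for i in 2:param.m
      fnew[i] = f[i] + c*(f[i+1] + f[i-1] - 2.0*f[i])
   end
   # the boundary values f[1] and f[end] never change, so only the interior is copied back
   @inbounds for i in 2:param.m
      f[i] = fnew[i]
   end
   return f
end

Here fnew would be allocated once outside the time loop and passed in on every call.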


Following your advice, I tried SharedArrays first. It is definitely much better than DistributedArrays, but still nowhere near the serial performance. Perhaps this is just a bad problem for parallelization.

Source code using SharedArrays (with two variants of the main loop):

using Distributed, PyPlot
addprocs(1)
@everywhere using SharedArrays

@everywhere struct Param
   m::Int
   dx::Float64
   dx²::Float64
   Invd²::Float64
   dt::Float64
   α::Float64
   nStep::Int
   DoParallel::Bool
end

@everywhere struct Grid
   x::StepRangeLen
end

"Serial version"
function advance_serial!(f,param::Param)
   m, dx², Invd², dt, α = param.m, param.dx², param.Invd², param.dt, param.α

   fnew = copy(f)

   for i in 2:m
      fnew[i] = f[i] + dt*α*(f[i+1] + f[i-1] - 2.0*f[i])*Invd²
   end

   f .= fnew
end


#Split the range in the second dimension.
@everywhere function myrange(q::SharedArray)
   idx = indexpids(q)
   if idx == 0 # This worker is not assigned a piece
      return 1:0, 1:0
   end
   nchunks = length(procs(q))
   splits = [round(Int, s) for s in range(0, stop=size(q,2), length=nchunks+1)]
   1:size(q,1), splits[idx]+1:splits[idx+1]
end

@everywhere begin
   "Parallel version"
   function advance!(f,param::Param)
      m, dx², Invd², dt, α = param.m, param.dx², param.Invd², param.dt, param.α

      irange = localindices(f)
      fnew = copy(f[irange])

      if nworkers() ≥ 2
         if all(myid() .!= [2,nprocs()])
            for (il,i) in enumerate(irange)
               fnew[il] = f[i] + dt*α*(f[i+1] + f[i-1] - 2.0*f[i])*Invd²
            end
         elseif myid() == 2
            for (il,i) in enumerate(irange[2:end])
               fnew[il+1] = f[i] + dt*α*(f[i+1] + f[i-1] - 2.0*f[i])*Invd²
            end
         elseif myid() == nprocs()
            for (il,i) in enumerate(irange[1:end-1])
               fnew[il] = f[i] + dt*α*(f[i+1] + f[i-1] - 2.0*f[i])*Invd²
            end
         end
      else
         for i in irange[2:end-1]
            fnew[i] = f[i] + dt*α*(f[i+1] + f[i-1] - 2.0*f[i])*Invd²
         end
      end

      f[irange] .= fnew
   end
end

function setBCs!(f,f₀,param::Param)

end

function plotSol(grid::Grid,f,param::Param)
   x = grid.x
   figure()
   #fp = sdata(f)
   plot(x,f,linestyle="-",marker="o",label="FD")
   #xlim(0,40)
   xlabel("x")
   ylabel("Temp")
end

function setParameters(DoParallel::Bool=true)
   m    = 1000
   dx   = 1.0
   dx²  = dx^2
   Invd²= 1.0/dx²
   dt   = 0.20
   α    = 1.0
   nStep= 2000

   param = Param(m,dx,dx²,Invd²,dt,α,nStep,DoParallel)

   x = 0.0:dx:m*dx
   grid = Grid(x)

   return param,grid
end

function init(param::Param)
   m = param.m

   if param.DoParallel
      f₀ = SharedArray{Float64,1}(m+1)
      f₀[1] = 1.0
   else
      f₀ = zeros(m+1)
      f₀[1] = 1.0
   end

   return f₀
end

"Main function for the parallel version using shared arrays."
function main_parallel()

   param, grid = setParameters()
   f = init(param)

   iStep, time = 0, 0.0
   # Method 1: calling functions in parallel
   #=
   runtime = @elapsed while iStep < param.nStep
      @sync begin
         for p in procs(f)
            @async begin
               remotecall_wait(advance!, p, f, param)
            end
         end
      end
      time  += param.dt
      iStep += 1
   end
   =#

   # Method 2: explicit parallel loop
   fnew = SharedArray{Float64,1}(param.m+1)
   fnew .= f

   m, dx², Invd², dt, α = param.m, param.dx², param.Invd², param.dt, param.α

   runtime = @elapsed while iStep < param.nStep
      @sync @distributed for i in 2:length(f)-1
         fnew[i] = f[i] + dt*α*(f[i+1] + f[i-1] - 2.0*f[i])*Invd²
      end
      f .= fnew
      time  += param.dt
      iStep += 1
   end

   println("Time elapsed parallel= $(runtime)s")

   plotSol(grid,f,param)

end

"Main function for the serial version."
function main_serial()

   param, grid = setParameters(false)
   f = init(param)

   iStep, time = 0, 0.0
   # Main loop
   runtime = @elapsed while iStep < param.nStep
      advance_serial!(f, param)
      #setBCs!(f,param)
      time  += param.dt
      iStep += 1
   end

   println("Time elapsed serial= $(runtime)s")

   plotSol(grid,f,param)

end

################
main_parallel() # warm up

main_parallel()
main_serial()

For comparison, I used the same input parameters, and again only 1 worker for the parallel version:

Time elapsed parallel= 1.569882835s
Time elapsed parallel= 0.847066288s
Time elapsed serial= 0.016118049s

Adding more workers may even decrease performance, presumably because each time step only does about a thousand floating-point updates, so the fixed cost of launching the @distributed loop every step outweighs the useful work.
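
One way to check that would be to time the same number of launches of an empty @sync @distributed loop; if that alone accounts for most of the parallel runtime above, the per-step dispatch overhead is the bottleneck. Untested sketch (absolute numbers will vary by machine):

using Distributed
# assumes workers have already been added with addprocs

t = @elapsed for _ in 1:2000
   @sync @distributed for i in 1:1
      nothing
   end
end
println("2000 empty @distributed launches: $(t)s")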