Multithreading for nested loops

Hi,

I am wondering how to apply multithreading to nested for loops.

For the function nestedloops below, I guess threads will only be launched for the outermost k loop, JULIA_NUM_THREADS of them:

function nestedloops(nx, ny, nz)

   state = ones(nx,ny,nz)

   Threads.@threads for k = 1:nz
      for j = 1:ny
         for i = 1:nx
            state[i,j,k] *= sin(i*j*k)
         end
      end
   end

   return
end

If I add more @threads, like the following:

function nestedloops2(nx, ny, nz)

   state = ones(nx,ny,nz)

   Threads.@threads for k = 1:nz
      Threads.@threads for j = 1:ny
         Threads.@threads for i = 1:nx
            state[i,j,k] *= sin(i*j*k)
         end
      end
   end

   return
end

Will it launch JULIA_NUM_THREADS^3 threads in total? At least I can see an obvious decrease in performance and a significant amount of additional memory allocations.

If I write the nested loops in a more compact way:

function nestedloops3(nx, ny, nz)

   state = ones(nx,ny,nz)

   Threads.@threads for k = 1:nz, j = 1:ny, i = 1:nx
      state[i,j,k] *= sin(i*j*k)
   end

   return
end

This returns an error:

ERROR: LoadError: syntax: invalid assignment location "k = 1:nz"

Is there a way, as in OpenMP, to collapse the nested loops and apply SIMD, i.e., something like #pragma omp parallel for collapse(3) simd in C?
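To make the intent concrete, here is a hand-rolled sketch of what such a collapse could look like (collapsed! is a made-up name; the two outer loops are flattened into one index with fldmod1 so @threads can split them, while the innermost loop stays contiguous for SIMD):

```julia
# Sketch: emulate OpenMP's collapse(2) by hand. @threads splits the
# flattened (k, j) index space; the i loop stays SIMD-friendly.
function collapsed!(state)
   nx, ny, nz = size(state)
   Threads.@threads for K in 1:(nz * ny)
      k, j = fldmod1(K, ny)          # recover (k, j) from the flat index
      @inbounds @simd for i in 1:nx
         state[i, j, k] *= sin(i * j * k)
      end
   end
   return state
end
```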

In 1.3+, you can nest @threads and Julia will figure out what to do.

I don’t think that’s true. You can nest tasks created with @spawn, but @threads isn’t using the new partr scheduler, AFAIU.
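A minimal sketch of what nesting with @spawn can look like (each level gets its own @sync, so the outer tasks wait for the inner ones; this is an illustration of composable tasks, not a recommendation for this grain size):

```julia
# Sketch: nested tasks via @spawn, which does go through the composable
# scheduler. Each spawned k-task runs its own @sync block over j-tasks.
function nested_spawn!(state)
   nx, ny, nz = size(state)
   @sync for k in 1:nz
      Threads.@spawn begin
         @sync for j in 1:ny
            Threads.@spawn for i in 1:nx
               @inbounds state[i, j, k] *= sin(i * j * k)
            end
         end
      end
   end
   return state
end
```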

Um, it’s a bit late for me to search for examples properly, but the language at https://julialang.org/blog/2019/07/multithreading/ says:

Some History
In version 0.5 about two years later, we released the @threads for macro with “experimental” status … however: @threads loops could not be nested: if the functions they called used @threads recursively, those inner loops would only occupy the CPU that called them. …

my understanding (and memory) is that in 1.3+ you can?



So what I asked in the main thread is still currently not possible?

Just don’t use @threads but @spawn the tasks in the loops. You can also check out the PR I linked in which @Mason proposed a @mythreads macro which should give you what you want.
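For example, something along these lines (a sketch with one task per k-slice, so each task is coarse-grained enough that the per-task overhead stays small; spawn_slices! is an illustrative name):

```julia
# Sketch: replace @threads with one @spawn per outer-loop slice and a
# single @sync that waits for all of them.
function spawn_slices!(state)
   nx, ny, nz = size(state)
   @sync for k in 1:nz
      Threads.@spawn for j in 1:ny, i in 1:nx
         @inbounds state[i, j, k] *= sin(i * j * k)
      end
   end
   return state
end
```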

@threads seems to be deterministic in starting tasks on threads:

using .Threads

const N = nthreads()
const t1 = zeros(Int, N)

julia> for j in 1:5
           @sync @threads for i in 1:N
               t1[i] = threadid()
           end
           println(t1, " - ", Set(t1))
       end
[1, 2, 3, 4, 5, 6, 7, 8] - Set([7, 4, 2, 3, 5, 8, 6, 1])
[1, 2, 3, 4, 5, 6, 7, 8] - Set([7, 4, 2, 3, 5, 8, 6, 1])
[1, 2, 3, 4, 5, 6, 7, 8] - Set([7, 4, 2, 3, 5, 8, 6, 1])
[1, 2, 3, 4, 5, 6, 7, 8] - Set([7, 4, 2, 3, 5, 8, 6, 1])
[1, 2, 3, 4, 5, 6, 7, 8] - Set([7, 4, 2, 3, 5, 8, 6, 1])

vs @spawn

const s1 = zeros(Int, N)

julia> for j in 1:5
           @sync for i in 1:N
               Threads.@spawn s1[i] = threadid()
           end
           println(s1, " - ", Set(s1))
       end
[2, 7, 3, 8, 6, 4, 5, 1] - Set([7, 4, 2, 3, 8, 5, 6, 1])
[3, 4, 8, 7, 3, 8, 7, 4] - Set([7, 4, 3, 8])
[5, 2, 2, 5, 2, 5, 4, 4] - Set([4, 2, 5])
[8, 3, 7, 3, 7, 7, 3, 3] - Set([7, 3, 8])
[4, 4, 4, 7, 6, 3, 3, 2] - Set([7, 4, 2, 3, 6])

With @spawn, not every thread always gets a task. Or am I doing something wrong here?

Maybe @spawn is better for just spawning lots of tasks to threads.

As pointed out by @lungben:

You need to add @sync for your @spawn versions, otherwise the timer just measures the time to start the tasks and does not wait until they finish.

Now I have added the @sync macros.

# export JULIA_NUM_THREADS=2

function nestedloops1(nx, ny, nz)

   state = ones(nx,ny,nz)

   for k = 1:nz
      for j = 1:ny
         for i = 1:nx
            state[i,j,k] *= sin(i*j*k)
         end
      end
   end

   #println(state[2,2,2])

   return
end


function nestedloops2(nx, ny, nz)

   state = ones(nx,ny,nz)

   @inbounds for k = 1:nz
      @inbounds for j = 1:ny
         @inbounds for i = 1:nx
            state[i,j,k] *= sin(i*j*k)
         end
      end
   end

   #println(state[2,2,2])

   return
end

function nestedloops3(nx, ny, nz)

   state = ones(nx,ny,nz)

   Threads.@threads for k = 1:nz
      for j = 1:ny
         for i = 1:nx
            state[i,j,k] *= sin(i*j*k)
         end
      end
   end

   #println(state[2,2,2])

   return

end

function nestedloops4(nx, ny, nz)

   state = ones(nx,ny,nz)

   @sync Threads.@spawn for k = 1:nz
      for j = 1:ny
         for i = 1:nx
            state[i,j,k] *= sin(i*j*k)
         end
      end
   end

   #println(state[2,2,2])

   return

end

function nestedloops5(nx, ny, nz)

   state = ones(nx,ny,nz)

   for k = 1:nz, j = 1:ny, i = 1:nx
      state[i,j,k] *= sin(i*j*k)
   end

   #println(state[2,2,2])

   return

end

function nestedloops6(nx, ny, nz)

   state = ones(nx,ny,nz)

   Threads.@threads for k = 1:nz
      Threads.@threads for j = 1:ny
         Threads.@threads for i = 1:nx
            state[i,j,k] *= sin(i*j*k)
         end
      end
   end

   #println(state[2,2,2])

   return

end

function nestedloops7(nx, ny, nz)

   state = ones(nx,ny,nz)

   @sync Threads.@spawn for k = 1:nz
      @sync Threads.@spawn for j = 1:ny
         @sync Threads.@spawn for i = 1:nx
            state[i,j,k] *= sin(i*j*k)
         end
      end
   end

   #println(state[2,2,2])

   return

end

function nestedloops8(nx, ny, nz)

   state = ones(nx,ny,nz)

   @sync Threads.@spawn for k = 1:nz, j = 1:ny, i = 1:nx
      state[i,j,k] *= sin(i*j*k)
   end

   #println(state[2,2,2])

   return

end



##
nx, ny, nk = 200, 200, 200
nestedloops1(nx, ny, nk)
nestedloops2(nx, ny, nk)
nestedloops3(nx, ny, nk)
nestedloops4(nx, ny, nk)
nestedloops5(nx, ny, nk)
nestedloops6(nx, ny, nk)
nestedloops7(nx, ny, nk)
nestedloops8(nx, ny, nk)

println("Number of threads = ",Threads.nthreads())
println("base line:")
@time nestedloops1(nx, ny, nk)
println("explicit @inbounds:")
@time nestedloops2(nx, ny, nk)
println("@threads on the outer loop:")
@time nestedloops3(nx, ny, nk)
println("@spawn on the outer loop:")
@time nestedloops4(nx, ny, nk)
println("nested loop:")
@time nestedloops5(nx, ny, nk)
println("@threads on the triple loops:")
@time nestedloops6(nx, ny, nk)
println("@spawn on the triple loops:")
@time nestedloops7(nx, ny, nk)
println("@spawn on the nested loops:")
@time nestedloops8(nx, ny, nk)

which gives

Number of threads = 2
base line:
  0.209827 seconds (6 allocations: 61.035 MiB)
explicit @inbounds:
  0.215929 seconds (6 allocations: 61.035 MiB, 4.34% gc time)
@threads on the outer loop:
  0.131187 seconds (24 allocations: 61.037 MiB)
@spawn on the outer loop:
  0.236467 seconds (17 allocations: 61.036 MiB, 4.39% gc time)
nested loop:
  0.212745 seconds (6 allocations: 61.035 MiB)
@threads on the triple loops:
  0.313132 seconds (221.97 k allocations: 77.209 MiB, 3.14% gc time)
@spawn on the triple loops:
  0.352818 seconds (444.76 k allocations: 95.426 MiB)
@spawn on the nested loops:
  0.246418 seconds (17 allocations: 61.036 MiB, 5.74% gc time)

The only speedup I see from this result is the @threads on the outermost loop, which is kind of unexpected to me…

You need to add @sync for your @spawn versions, otherwise the timer just measures the time to start the tasks and does not wait until they finish.

Nested @threads works in Julia 1.3 and later in the sense that it doesn’t crash (before, it sometimes did, IIRC). But parallelization only happens at the outermost for loop. So, if you want load balancing across multiple levels of for loops, I don’t think @threads is a good option.

FWIW, I think we need better tooling for threading in Julia. @spawn is too much of a foot-gun for high-level programming and @threads is too limited for the nested case. FYI, I’m packaging up a high-level threading API as ThreadsX.jl, which includes ThreadsX.foreach that supports (possibly nested) parallel loops. I requested registration a few days ago, so hopefully it will be registered soon.


But my original timings are for the whole function scope, not just the loops. Does it still matter?

I am puzzled by the result that after I added @sync there’s literally no speedup. Also, if I say

function nestedloops8(nx, ny, nz)

   state = ones(nx,ny,nz)

   @sync for k = 1:nz, j = 1:ny, i = 1:nx
      Threads.@spawn state[i,j,k] *= sin(i*j*k)
   end

   #println(state[2,2,2])

   return

end

It will get stuck.

That clarifies some of my doubts. However, why is my @spawn version not working as expected?

Your nestedloops8 seems to be using @sync and @spawn correctly (though not very efficiently). It works for me with small inputs like nestedloops8(2, 3, 4). But, when ns are big, I think it’ll spawn too many tasks. Since @spawn has some overhead, you need to “chunk” the iteration space into reasonably large sub-regions for computations like this.
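As an illustration of that chunking idea (splitting the flat index space into roughly nthreads() blocks and spawning one task per block instead of one per element; chunked! is a made-up name, not an API):

```julia
# Sketch: one task per chunk of the index space, so the @spawn overhead
# is paid nthreads() times instead of nx*ny*nz times.
function chunked!(state; ntasks = Threads.nthreads())
   inds = vec(CartesianIndices(state))
   chunksize = cld(length(inds), ntasks)
   @sync for chunk in Iterators.partition(inds, chunksize)
      Threads.@spawn for I in chunk
         i, j, k = Tuple(I)
         @inbounds state[I] *= sin(i * j * k)
      end
   end
   return state
end
```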
