Is it possible to short-circuit a parallel for loop?

Consider the following code that mimics short-circuiting any():

function bad_any(n::Integer)
  a = rand(Bool, n)
  for i in a
    i && return true 
  end
  return false
end
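For reference, the serial version does not need a hand-written loop. Base's `any(p, itr)` is documented to short-circuit, returning `true` as soon as `p` returns `true` for some element, so `bad_any(n)` could equally be written as `any(rand(Bool, n))`. The counter `ncalls` and the predicate `gt2` below are hypothetical names, there only to make the early exit visible.

```julia
# Count how many times the predicate runs, to show that Base's `any`
# stops at the first element for which the predicate returns `true`.
const ncalls = Ref(0)
gt2(x) = (ncalls[] += 1; x > 2)   # hypothetical predicate with a side effect

v = [1, 3, 0, 5]
result = any(gt2, v)              # 3 > 2 on the second element, so it stops there
println(result, " after ", ncalls[], " calls")   # true after 2 calls
```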

Is it possible to parallelize the for loop inside the function so that it will short-circuit without having to go through n iterations?

Here’s code which does it, though I’d bet good money there’s a much, much less verbose way.

I use slow_check to slow down boolean checking so timing differences are more obvious. I have 4 Julia threads on my computer.

# threadsafe way of slowing down a comparison
function slow_check(x, thresh, sleeptime)
	Libc.systemsleep(sleeptime)
	return x > thresh
end

function parallel_breaking_any(arr, thresh, sleeptime)
	inds = collect(eachindex(arr))
	n = length(inds)
	nthreads = Threads.nthreads()
	any_channel = Channel{Bool}(nthreads)
	thread_todos = Vector{Vector{eltype(inds)}}()
	for i in 1:nthreads
		# assign work to each thread
		ind_start = 1 + ceil(Int, (i-1) / nthreads * n)
		ind_end = ceil(Int, i / nthreads * n)
		push!(thread_todos, inds[ind_start:ind_end])
	end
	Threads.@threads for i in 1:nthreads
		while !isempty(thread_todos[i])
			# check if we should break because another thread finished
			if isready(any_channel)
				break
			end
			# otherwise should check an index
			ind_to_check = pop!(thread_todos[i])
			if slow_check(arr[ind_to_check], thresh, sleeptime)
				put!(any_channel, true)
				break
			end
		end
	end
	return isready(any_channel)
end

function simple_any(arr, thresh, sleeptime)
	for i in eachindex(arr)
		if slow_check(arr[i], thresh, sleeptime)
			return true
		end
	end
	return false
end

thresh = .95
sleeptime = .1
none_greater = .5 * thresh * rand(20)
last_greater = copy(none_greater); last_greater[end] = 1.0
first_greater = copy(none_greater); first_greater[1] = 1.0

using BenchmarkTools
println("Simple, none greater, should take 2 seconds")
println(@belapsed simple_any(none_greater, thresh, sleeptime))
println("Simple, first greater, should take .1 seconds")
println(@belapsed simple_any(first_greater, thresh, sleeptime))
println("Simple, last greater, should take 2 seconds")
println(@belapsed simple_any(last_greater, thresh, sleeptime))
println("Parallel, none greater, should take .5 seconds")
println(@belapsed parallel_breaking_any(none_greater, thresh, sleeptime))
println("Parallel, first greater, should take .5 seconds (pop! removes from end of Vectors)")
println(@belapsed parallel_breaking_any(first_greater, thresh, sleeptime))
println("Parallel, last greater, should take .1 seconds (pop! removes from end of Vectors)")
println(@belapsed parallel_breaking_any(last_greater, thresh, sleeptime))

Output:

julia> include("C:\\Users\\ejfie\\Desktop\\any.jl")
Simple, none greater, should take 2 seconds
2.003889675
Simple, first greater, should take .1 seconds
0.10001727
Simple, last greater, should take 2 seconds
2.003114248
Parallel, none greater, should take .5 seconds
0.501039888
Parallel, first greater, should take .5 seconds (pop! removes from end of Vectors)
0.509741427
Parallel, last greater, should take .1 seconds (pop! removes from end of Vectors)
0.100068566
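A somewhat shorter threaded variant replaces the channel and the per-thread index vectors with a single `Threads.Atomic{Bool}` flag that every task reads before checking its next element. This is a sketch under the assumption of Julia 1.3 or later (for `Threads.@spawn`); `atomic_any` and its `ntasks` keyword are hypothetical names, not part of Base.

```julia
# Sketch: short-circuiting threaded `any` using one shared atomic flag.
# `atomic_any` and `ntasks` are hypothetical names, not a Base API.
function atomic_any(pred, arr; ntasks = Threads.nthreads())
    isempty(arr) && return false
    found = Threads.Atomic{Bool}(false)
    chunksize = cld(length(arr), ntasks)
    # Spawn one task per contiguous chunk of indices.
    tasks = map(Iterators.partition(eachindex(arr), chunksize)) do inds
        Threads.@spawn begin
            for i in inds
                found[] && break          # another task already found a match
                if pred(arr[i])
                    found[] = true        # atomic store, seen by the other tasks
                    break
                end
            end
        end
    end
    foreach(wait, tasks)                  # throws if any task failed
    return found[]
end
```

With the benchmark above it would be called as `atomic_any(x -> slow_check(x, thresh, sleeptime), first_greater)`. Since each task scans its chunk front to back rather than popping from the end, with 4 threads one would expect `first_greater` to be the fast case and `last_greater` the slow one, the reverse of the `pop!`-based version.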

Thanks. I was looking for a non-threaded example, though I failed to specify this. I don’t think it’s possible without some sort of atomic reference and batching strategy.
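For the multi-process (non-threaded) case, one way to get the "shared reference plus batching" idea is the `Distributed` standard library: split the array into batches, hand them out with `pmap`, and use a `RemoteChannel` as a stop flag that every worker polls before each element. This is a minimal sketch, not a tuned implementation; `batch_any`, `distributed_any`, `is_big`, and the choice of 4 batches per worker are hypothetical. Polling a `RemoteChannel` from a worker is a round trip to process 1, so this only pays off when each check is expensive, as with `slow_check`.

```julia
using Distributed

# Start two local worker processes if none exist yet (assumption: local run).
nprocs() == 1 && addprocs(2)

@everywhere begin
    using Distributed
    # Scan one batch, checking the shared stop flag before every element.
    function batch_any(pred, batch, stop)
        for x in batch
            isready(stop) && return false   # another batch already found a hit
            if pred(x)
                put!(stop, true)            # raise the flag for everyone else
                return true
            end
        end
        return false
    end
    # Hypothetical predicate; it must be defined on every process.
    is_big(x) = x > 0.95
end

function distributed_any(pred, arr; nbatches = 4 * nworkers())
    isempty(arr) && return false
    # Copy each batch so only that slice is sent to a worker.
    batches = map(collect, Iterators.partition(arr, cld(length(arr), nbatches)))
    # One buffer slot per batch, so put! never blocks even if several batches hit.
    stop = RemoteChannel(() -> Channel{Bool}(length(batches)))
    return any(pmap(b -> batch_any(pred, b, stop), batches))
end
```

Batches that `pmap` has not dispatched yet when the flag is raised are still sent out, but they return after a single `isready` check, which is where the batching keeps the wasted work small.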

This question was also asked once on Stack Overflow. My answer was a bit hard to follow, but I remember that it worked. Link: Julia @parallel for loop with return statement - Stack Overflow.
Maybe you can make heads or tails of it.
