Re the blog post: PSA: Thread-local state is no longer recommended

The pattern the blog post discourages, because of issues with threadid(), was:

using Base.Threads: nthreads, @threads, threadid

states = [some_initial_value for _ in 1:nthreads()]
@threads for x in some_data
    tid = threadid()
    old_val = states[tid]
    new_val = some_operator(old_val, f(x))
    states[tid] = new_val
end
do_something(states)

Proposed solution: enumerate the loop with the index as a program-specific task id:

using Base.Threads: @threads

states = [some_initial_value for _ in eachindex(x)]  # <-- changed
@threads for (id, x) in enumerate(some_data)  # <-- the key change is here
    old_val = states[id]
    new_val = some_operator(old_val, f(x))
    states[id] = new_val
end
do_something(states)

Isn’t that pretty much the same pattern as the one proposed in the other thread (the latter just being more explicit/verbose)?

x isn’t defined here?

Sorry, that should have been some_data, and I was not posting working code (the example I was quoting did not define some_data)… I will add below:

const some_data = [3, 4, 5]
states = [some_initial_value for _ in eachindex(some_data)]  # <-- changed
@Threads.threads for (id, x) in enumerate(some_data)  # <-- the key change is here
    old_val = states[id]
    new_val = some_operator(old_val, f(x))
    states[id] = new_val
end
do_something(states)

I use this pattern myself. As written, it doesn’t actually work though.

julia> result = zeros(10)
       Threads.@threads for (i, x) in enumerate(19:28)
           result[i] = x ÷ 2
       end
ERROR: TaskFailedException

    nested task error: MethodError: no method matching firstindex(::Base.Iterators.Enumerate{UnitRange{Int64}})

You need to do collect(enumerate(some_data)). I wish the implementation didn’t require this, but the cost of a single collect is usually insignificant next to the multithreaded loop itself.
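For reference, here is a minimal sketch of the collect workaround applied to the failing REPL example above. The names data and result are only illustrative, and the output type is changed to Int for easier checking:

```julia
using Base.Threads: @threads

# Illustrative input, same range as in the failing REPL example.
data = 19:28
result = zeros(Int, length(data))

# collect materializes the (index, value) tuples into a Vector,
# which supports the indexing that @threads relies on.
@threads for (i, x) in collect(enumerate(data))
    result[i] = x ÷ 2   # each iteration writes only its own slot
end
```

The extra allocation is one Vector of tuples with the same length as data.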

Unfortunately, the collect version is not always correct, e.g. for channels.

That is ok, but then you need the number of states to be the same as the number of elements in the some_data array.

More generally one wants a buffer of length nthreads, which is independent of the number of tasks. That is why in general the proposal is to partition the workload in chunks and then use the chunk index (instead of threadid()).
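A rough sketch of that chunking idea, assuming a simple sum reduction. Here f, some_data, and the chunk sizing are placeholders chosen for illustration, not something taken from the blog post verbatim:

```julia
using Base.Threads: @spawn, nthreads

some_data = collect(1:100)   # illustrative input
f(x) = x^2                   # hypothetical per-element work
nchunks = nthreads()         # assumption: one chunk per available thread

# Split the indices into at most nchunks contiguous ranges.
chunk_len = cld(length(some_data), nchunks)
chunks = collect(Iterators.partition(eachindex(some_data), chunk_len))

# The buffer is indexed by chunk number, not by threadid().
states = zeros(Int, length(chunks))
@sync for (c, idxs) in enumerate(chunks)
    @spawn begin
        acc = 0
        for i in idxs
            acc += f(some_data[i])
        end
        states[c] = acc      # each task writes only its own slot
    end
end
total = sum(states)
```

The buffer length depends only on the number of chunks, so it stays small no matter how many elements some_data has.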

Sometimes the collect is also very expensive in terms of allocation and compute. Is there much of a chance of this requirement being removed in the future?

I think you can use pairs there, so you don’t need the collect. You could also loop over eachindex and index the elements inside the loop.
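The eachindex variant could look like the sketch below (result is an illustrative output array; this assumes some_data is an ordinary 1-based Vector, so the same index is valid for both arrays):

```julia
using Base.Threads: @threads

some_data = [3, 4, 5]
result = zeros(Int, length(some_data))

# Iterate over the indices directly and look up each element inside the loop,
# so no intermediate collection of (index, value) tuples is allocated.
@threads for i in eachindex(some_data)
    result[i] = some_data[i] ÷ 2
end
```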
