Often, we need a global variable whose value is different for each thread. This is typically needed for storing some mathematical context (such as precision, rounding mode, RNG, etc.).
Usually, this code pattern (or a similar one) is used:
# Table with one slot per thread; it is empty until `__init__` runs.
const SOMETHING = SomeType[]

# Look up the slot that belongs to the calling thread.
function get_something()
    return SOMETHING[Threads.threadid()]
end

# Size the table to the number of threads and populate every slot.
function __init__()
    resize!(SOMETHING, Threads.nthreads())
    for slot in eachindex(SOMETHING)
        SOMETHING[slot] = make_something() # or DEFAULT_SOMETHING
    end
    return nothing
end
So I thought about a struct that encapsulates this code pattern and lets one refer to the global variable as if it were a thread-specific `Ref`.
"""
    ThreadRef{T} <: Ref{T}

A `Ref`-like handle that keeps one value of type `T` per thread.

# Fields
- `data`: the per-thread slots; created empty by the constructor.
- `init_function`: zero-argument function used to produce the value of a new slot,
  or `nothing` when new slots are left unset.
"""
struct ThreadRef{T} <: Ref{T}
    data::Vector{T}
    init_function::Union{Function,Nothing}

    # The inner constructor is published under a private global name; it always starts
    # with an empty slot vector and takes the element type as an explicit argument.
    global function _make_thread_ref(
        init_function::Union{Function,Nothing}, ::Type{T}
    ) where {T}
        slots = Vector{T}()
        return new{T}(slots, init_function)
    end
end
"""
    resize_thread_ref(ref::ThreadRef)

Grow `ref.data` so that it has a slot for every thread id the session can hand out, and
fill each newly added slot with `ref.init_function()` when an init function was given.
Without an init function the new slots are left unset.

The table is never shrunk, so existing per-thread values are kept. Returns `nothing`.

NOTE(review): the `resize!` below is not synchronized. If several threads can reach this
function at the same time, consider calling it once up front (e.g. from `__init__`).
"""
function resize_thread_ref(ref::ThreadRef{T}) where {T}
    # `Threads.nthreads()` only counts the default thread pool. On Julia >= 1.9 thread ids
    # of other pools (e.g. interactive threads) can be larger; `maxthreadid()` covers them.
    nslots = isdefined(Threads, :maxthreadid) ? Threads.maxthreadid() : Threads.nthreads()
    len = length(ref.data)
    nslots > len || return nothing

    resize!(ref.data, nslots)
    init = ref.init_function
    if !isnothing(init)
        for i in (len + 1):nslots
            # `i` lies inside the range that `resize!` just created.
            @inbounds ref.data[i] = init()
        end
    end
    return nothing
end
"""
    ThreadRef{T}(value)

Create a `ThreadRef{T}` whose slots are initialised from `value`.

When `T` is a mutable type every slot receives its own `copy(value)`; otherwise `value`
itself is stored in every slot.
"""
function ThreadRef{T}(value) where {T}
    # `T.mutable` is an internal `DataType` field and fails for a `T` that is not a
    # `DataType` (e.g. a `Union`). Use the public `ismutabletype` query where it exists
    # (Julia >= 1.7) and keep the field access only as a fallback for older releases.
    needs_copy = isdefined(Base, :ismutabletype) ? Base.ismutabletype(T) : T.mutable
    return _make_thread_ref(() -> (needs_copy ? copy(value) : value), T)
end
"""
    ThreadRef(value)

Create a `ThreadRef` whose element type is `typeof(value)`; forwards to
`ThreadRef{T}(value)`.
"""
function ThreadRef(value::T) where {T}
    return ThreadRef{T}(value)
end
"""
    ThreadRef{T}(; init_function = nothing)

Create a `ThreadRef{T}` with no initial value.

# Keywords
- `init_function`: zero-argument function used to fill newly created slots, or `nothing`
  (the default) to leave them unset.
"""
function ThreadRef{T}(; init_function = nothing) where {T}
    return _make_thread_ref(init_function, T)
end
"""
    isassigned(ref::ThreadRef)

Return `true` when the slot of the calling thread exists and holds a value.
This never grows the slot table.
"""
function Base.isassigned(ref::ThreadRef)
    slots = ref.data
    tid = Threads.threadid()
    # No slot has been created for this thread yet.
    if tid > length(slots)
        return false
    end
    return isassigned(slots, tid)
end
"""
    getindex(ref::ThreadRef)

Return the value stored in the slot of the calling thread (`ref[]`).

If the slot table is too short for the current thread id it is grown first through
`resize_thread_ref`. That step sits under `@boundscheck`, so a caller using
`@inbounds ref[]` may skip it and is then responsible for having sized the table.

# Throws
- `BoundsError` if the table still has no slot for this thread id after growing.
"""
function Base.getindex(ref::ThreadRef)
    id = Threads.threadid()
    @boundscheck if id > length(ref.data)
        resize_thread_ref(ref)
        # The read below is unchecked: refuse to go past the end of the table if growing
        # did not make room for this thread id.
        id <= length(ref.data) || throw(BoundsError(ref.data, id))
    end
    return @inbounds ref.data[id]
end
"""
    setindex!(ref::ThreadRef, value)

Store `value` in the slot of the calling thread (`ref[] = value`) and return `value`.

If the slot table is too short for the current thread id it is grown first through
`resize_thread_ref`. That step sits under `@boundscheck`, so a caller using
`@inbounds ref[] = value` may skip it and is then responsible for having sized the table.

# Throws
- `BoundsError` if the table still has no slot for this thread id after growing.
"""
function Base.setindex!(ref::ThreadRef, value)
    id = Threads.threadid()
    @boundscheck if id > length(ref.data)
        resize_thread_ref(ref)
        # The store below is unchecked: refuse to write past the end of the table if
        # growing did not make room for this thread id.
        id <= length(ref.data) || throw(BoundsError(ref.data, id))
    end
    # Reuse `id` so the check above and the store refer to the same slot.
    @inbounds ref.data[id] = value
    return value
end
"""
    pointer_from_objref(ref::ThreadRef)

Return a pointer to the slot of the calling thread inside `ref.data`.

The slot table is grown first when it does not yet cover the current thread id, because
`pointer(array, i)` performs no bounds check of its own.

# Throws
- `BoundsError` if the table still has no slot for this thread id after growing.
"""
function Base.pointer_from_objref(ref::ThreadRef)
    id = Threads.threadid()
    if id > length(ref.data)
        resize_thread_ref(ref)
        id <= length(ref.data) || throw(BoundsError(ref.data, id))
    end
    return pointer(ref.data, id)
end
"""
    unsafe_convert(P, ref::ThreadRef{T})

Return `pointer_from_objref(ref)` converted to the requested pointer type `P`, which is
either `Ptr{T}` or `Ptr{Cvoid}`.
"""
function Base.unsafe_convert(
    P::Union{Type{Ptr{T}},Type{Ptr{Cvoid}}}, ref::ThreadRef{T}
) where {T}
    slot_ptr = pointer_from_objref(ref)
    return convert(P, slot_ptr)::P
end
# Converting a `ThreadRef{T}` to its own type returns the same object.
function Base.convert(::Type{ThreadRef{T}}, ref::ThreadRef{T}) where {T}
    return ref
end

# Any other value is wrapped in a new `ThreadRef{T}` initialised from it.
function Base.convert(::Type{ThreadRef{T}}, x) where {T}
    return ThreadRef{T}(x)
end
I noticed that, with this struct, doing the resize in `__init__` is optional — you would only do it if you want to speed up your code by using `@inbounds`.
Example:
julia> MYVAR = ThreadRef(Int[])
ThreadRef{Vector{Int64}}(Vector{Int64}[], var"#5#6"{Vector{Int64}, Vector{Int64}}(Int64[]))
julia> MYVAR[]
Int64[]
julia> Threads.@threads for i in 1:20
push!(MYVAR[], rand(Int))
@info Threads.threadid(), i, MYVAR[]
end
[ Info: (6, 11, [8723786613677663540])
[ Info: (3, 5, [-3170150571523688199])
[ Info: (6, 12, [8723786613677663540, -1854103652421286688])
[ Info: (3, 6, [-3170150571523688199, 411164385622613090])
[ Info: (4, 7, [1706644823846658429])
[ Info: (4, 8, [1706644823846658429, 8244448464870095285])
[ Info: (5, 9, [-4992811659773667411])
[ Info: (12, 20, [919013625160668116])
[ Info: (1, 1, [2169316288210359347])
[ Info: (5, 10, [-4992811659773667411, 3817662974353008845])
[ Info: (1, 2, [2169316288210359347, -4733662603361506600])
[ Info: (10, 18, [-3017544938063281322])
[ Info: (7, 13, [-4207430156587551747])
[ Info: (7, 14, [-4207430156587551747, 5405408733298037114])
[ Info: (8, 15, [-1438606420548203930])
[ Info: (8, 16, [-1438606420548203930, -7793123522309815574])
[ Info: (9, 17, [5871745011087148451])
[ Info: (11, 19, [-2976418037209735115])
[ Info: (2, 3, [2545187215397803953])
[ Info: (2, 4, [2545187215397803953, -1955610229054269401])
julia> MYVAR1 = ThreadRef(10)
ThreadRef{Int64}(Int64[], var"#11#12"{Int64, Int64}(10))
julia> MYVAR1[]
10
julia> MYVAR1[] = 20
20
julia> unsafe_load(pointer_from_objref(MYVAR1))
20
A similar structure, `TaskRef`, could be made for tasks using `task_local_storage`.
Thoughts? Comments?
[EDIT: corrected a bug in `isassigned`]