Structure that creates thread-specific global variables

We often need a global variable whose value is different for each thread. This typically comes up when storing some mathematical context (such as precision, rounding mode, RNG, etc.).

Usually, this code pattern (or similar) is used:

# Global container holding one value per thread; it starts empty and is sized by `__init__`.
const SOMETHING = SomeType[]

# Fetch the entry that belongs to the calling thread.
function get_something()
    return SOMETHING[Threads.threadid()]
end

# Give the container one slot per thread and fill every slot with its own value.
function __init__()
    resize!(SOMETHING, Threads.nthreads())
    for slot in eachindex(SOMETHING)
        SOMETHING[slot] = make_something() # or DEFAULT_SOMETHING
    end
    return nothing
end

So I thought about a struct that encapsulates that code pattern and allows one to refer to the global variable as if it were a thread-specific Ref.

# Ref-like container that keeps a separate value of type `T` for each thread.
struct ThreadRef{T} <: Ref{T}
    # Per-thread storage; the accessors index it with `Threads.threadid()`.
    data::Vector{T}
    # Zero-argument callable used to fill newly created slots, or `nothing` for no
    # initialization.
    init_function::Union{Function,Nothing}

    # Construction helper. It is declared `global` so that it is visible outside the
    # struct body; it starts every ThreadRef with empty storage.
    global function _make_thread_ref(
        init_function::Union{Function,Nothing}, ::Type{T}
    ) where {T}
        storage = Vector{T}()
        return new{T}(storage, init_function)
    end
end

# Grow the per-thread storage of `ref` so that every thread has a slot.
#
# The storage is sized to cover `Threads.nthreads()` AND the calling thread's own
# `Threads.threadid()`: on Julia versions with several thread pools the id of a thread
# can be larger than `nthreads()`, and the accessors index the storage with `@inbounds`
# right after calling this function. The storage is never shrunk, so values already
# stored for other threads are kept.
#
# Every newly created slot is filled with `ref.init_function()` when an initializer was
# supplied; otherwise new slots are left uninitialized.
#
# NOTE: no lock is taken here. Call this once up front (e.g. from `__init__`) if several
# threads may otherwise hit their first access at the same time.
#
# Returns `nothing`.
function resize_thread_ref(ref::ThreadRef{T}) where {T}
    len = length(ref.data)
    target = max(len, Threads.nthreads(), Threads.threadid())
    # Nothing to do when every required slot already exists.
    target > len || return nothing
    resize!(ref.data, target)
    init = ref.init_function
    if init !== nothing
        for i in (len + 1):target
            # Safe: `i` lies within the range just allocated by `resize!`.
            @inbounds ref.data[i] = init()
        end
    end
    return nothing
end

# Build a `ThreadRef{T}` whose slots are initialized from `value`.
#
# The initializer runs once per slot when the storage is grown. A mutable `value` is
# `copy`-ed for each slot so that threads do not end up sharing one object; an immutable
# `value` is stored as is.
function ThreadRef{T}(value) where {T}
    # `ismutable` is public API and inspects the value itself, so it also works when `T`
    # is abstract or a `Union` (reading the internal `T.mutable` field does not, and that
    # field is not available on every Julia version).
    return _make_thread_ref(() -> (ismutable(value) ? copy(value) : value), T)
end

# Convenience constructor: the element type is taken from the type of `value`.
function ThreadRef(value::T) where {T}
    return ThreadRef{T}(value)
end

# Keyword constructor: `init_function` (default `nothing`) is handed unchanged to
# `_make_thread_ref` together with the element type `T`.
function ThreadRef{T}(; init_function = nothing) where {T}
    return _make_thread_ref(init_function, T)
end

# Report whether the calling thread already has an assigned value in `ref`.
# A thread whose slot does not exist yet counts as unassigned.
function Base.isassigned(ref::ThreadRef)
    slot = Threads.threadid()
    if slot > length(ref.data)
        return false
    end
    return isassigned(ref.data, slot)
end

# `ref[]`: return the value stored for the calling thread.
# Inside the `@boundscheck` block the storage is grown through `resize_thread_ref` when
# the thread's slot does not exist yet; the read itself is done without bounds checking.
function Base.getindex(ref::ThreadRef)
    slot = Threads.threadid()
    @boundscheck begin
        if slot > length(ref.data)
            resize_thread_ref(ref)
        end
    end
    current = @inbounds(ref.data[slot])
    return current
end

# `ref[] = value`: store `value` in the calling thread's slot and return `value`.
# Inside the `@boundscheck` block the storage is grown through `resize_thread_ref` when
# the thread's slot does not exist yet.
function Base.setindex!(ref::ThreadRef, value)
    id = Threads.threadid()
    @boundscheck id <= length(ref.data) || resize_thread_ref(ref)
    # Write through the very index that was checked above. Querying `threadid()` a second
    # time (after `resize_thread_ref` may have run a user initializer) gives no guarantee
    # that the unchecked store hits the slot that was validated.
    @inbounds ref.data[id] = value
    return value
end

# Return a pointer to the calling thread's slot inside `ref.data`.
#
# `pointer(array, i)` performs no bounds check, so the slot is created first (same lazy
# growth as `getindex`/`setindex!`); otherwise a fresh ThreadRef would hand out an address
# outside its storage.
# NOTE(review): the pointer presumably stays valid only while `ref` is alive and its
# storage is not resized again - confirm against callers.
function Base.pointer_from_objref(ref::ThreadRef)
    id = Threads.threadid()
    id <= length(ref.data) || resize_thread_ref(ref)
    return pointer(ref.data, id)
end

# Pointer conversion for ThreadRef: when a `Ptr{T}` or `Ptr{Cvoid}` is requested, hand out
# the pointer produced by `pointer_from_objref(ref)`, converted to the requested type `P`
# by the return-type annotation.
function Base.unsafe_convert(
    P::Union{Type{Ptr{T}},Type{Ptr{Cvoid}}}, ref::ThreadRef{T}
)::P where {T}
    address = pointer_from_objref(ref)
    return address
end

# A ThreadRef that already has the requested element type is returned unchanged.
function Base.convert(::Type{ThreadRef{T}}, ref::ThreadRef{T}) where {T}
    return ref
end

# Any other value is wrapped in a new `ThreadRef{T}` through the `ThreadRef{T}(x)`
# constructor.
function Base.convert(::Type{ThreadRef{T}}, x) where {T}
    return ThreadRef{T}(x)
end

I noticed that, with this struct, including the resize in __init__ is optional - you could do that if you want to speed up your code using @inbounds.

Example:

julia> MYVAR = ThreadRef(Int[])
ThreadRef{Vector{Int64}}(Vector{Int64}[], var"#5#6"{Vector{Int64}, Vector{Int64}}(Int64[]))

julia> MYVAR[]
Int64[]

julia> Threads.@threads for i in 1:20
       push!(MYVAR[], rand(Int))
       @info Threads.threadid(), i, MYVAR[]
       end
[ Info: (6, 11, [8723786613677663540])
[ Info: (3, 5, [-3170150571523688199])
[ Info: (6, 12, [8723786613677663540, -1854103652421286688])
[ Info: (3, 6, [-3170150571523688199, 411164385622613090])
[ Info: (4, 7, [1706644823846658429])
[ Info: (4, 8, [1706644823846658429, 8244448464870095285])
[ Info: (5, 9, [-4992811659773667411])
[ Info: (12, 20, [919013625160668116])
[ Info: (1, 1, [2169316288210359347])
[ Info: (5, 10, [-4992811659773667411, 3817662974353008845])
[ Info: (1, 2, [2169316288210359347, -4733662603361506600])
[ Info: (10, 18, [-3017544938063281322])
[ Info: (7, 13, [-4207430156587551747])
[ Info: (7, 14, [-4207430156587551747, 5405408733298037114])
[ Info: (8, 15, [-1438606420548203930])
[ Info: (8, 16, [-1438606420548203930, -7793123522309815574])
[ Info: (9, 17, [5871745011087148451])
[ Info: (11, 19, [-2976418037209735115])
[ Info: (2, 3, [2545187215397803953])
[ Info: (2, 4, [2545187215397803953, -1955610229054269401])

julia> MYVAR1 = ThreadRef(10)
ThreadRef{Int64}(Int64[], var"#11#12"{Int64, Int64}(10))

julia> MYVAR1[]
10

julia> MYVAR1[] = 20
20

julia> unsafe_load(pointer_from_objref(MYVAR1))
20

A similar structure TaskRef could be made for tasks using task_local_storage.

Thoughts? Comments?

[EDIT: corrected bug in isassigned]

1 Like