Multi-threaded processing of a Dict

I have a function that iterates over all key-value pairs of a Dict, computes a new value from the key and the current value, and then updates the value at that key. Since each iteration only updates a value and never inserts or deletes, it seemed to me that this should be amenable to multi-threading.

Namely, I thought of splitting the dictionary into as many views as I have threads and then letting each thread work through its own view, as shown below:

# This whole DictView logic is adapted from SplittablesBase.jl
struct DictView{D}
    dict::D
    firstslot::Int
    lastslot::Int
end

DictView(xs::DictView, i::Int, j::Int) = DictView(xs.dict, i, j)

Base.IteratorEltype(::Type{<:DictView{D}}) where {D} = Base.IteratorEltype(D)
Base.IteratorSize(::Type{<:DictView}) = Base.SizeUnknown()

Base.eltype(::Type{<:DictView{D}}) where {D} = eltype(D)

function Base.length(xs::DictView)
    n = 0
    for _ in xs
        n += 1
    end
    return n
end

# Note: this relies on the implementation detail of `iterate(::Dict)`.
@inline function Base.iterate(xs::DictView, i = xs.firstslot)
    i <= xs.lastslot || return nothing
    y = iterate(xs.dict, i)
    y === nothing && return nothing
    x, j = y
    # If `j` is `xs.lastslot + 1` or smaller, it means the current element is
    # within the range of this `DictView`:
    j <= xs.lastslot + 1 && return x, j
    # Otherwise, we need to stop:
    return nothing
end

# These are not defined in SplittablesBase.jl. Maybe for a good reason?
Base.setindex!(dv::DictView, v, k) = setindex!(dv.dict, v, k)
Base.getindex(dv::DictView, k) = getindex(dv.dict, k)

function partition(d::Dict, n::Integer)
    # `d.idxfloor` is the first slot that can be non-empty (a Dict internal).
    i0 = d.idxfloor
    iend = lastindex(d.slots)
    slot_boundaries = round.(Int, range(i0 - 1, iend; length = n + 1))
    views = [DictView(d, slot_boundaries[i] + 1, slot_boundaries[i+1]) for i in 1:n]
    return views
end
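As a sanity check on the slot-range splitting, here is a sketch (my own helper names; it relies on the same Dict internals, `idxfloor` and `slots`, plus the internal `Base.isslotfilled`, so it may break on other Julia versions) verifying that the ranges together cover every filled slot exactly once:

```julia
# Sketch: recompute the slot boundaries as in `partition` and check that the
# resulting ranges cover every filled slot of the Dict exactly once.
# Relies on Dict internals (`idxfloor`, `slots`, `Base.isslotfilled`).
function slot_ranges(d::Dict, n::Integer)
    i0 = d.idxfloor
    iend = lastindex(d.slots)
    bounds = round.(Int, range(i0 - 1, iend; length = n + 1))
    return [(bounds[i] + 1, bounds[i+1]) for i in 1:n]
end

d = Dict(2i => rand() for i in 1:1_000)
ranges = slot_ranges(d, 4)
# Count filled slots in each range; the total must equal `length(d)`.
total = sum(count(i -> Base.isslotfilled(d, i), lo:hi) for (lo, hi) in ranges)
@assert total == length(d)
```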

function testfun_threaded(dict, n_threads)
    parts = partition(dict, n_threads)
    Threads.@threads for part in parts
        for (k, v) in part
            part[k] = log(k) * sin(v)
        end
    end
end

function testfun_unthreaded(dict)
    for (k, v) in dict
        dict[k] = log(k) * sin(v)
    end
end

Benchmarking this code gives only very modest speed-ups when increasing the number of parts that are processed in parallel:

dict = Dict(2i => rand() for i in 1:10_000_000)
@time testfun_threaded(dict, 1)
#  0.897201 seconds (590 allocations: 13.641 KiB)

@time testfun_threaded(dict, 10)
#  0.585802 seconds (404 allocations: 11.016 KiB)

@time testfun_unthreaded(dict)
#  0.870252 seconds

I suspect that the naive implementations of setindex! and getindex that I added here lead to some concurrency issues, even though in principle this should be safe because each thread works on its own chunk of data.

As for possible alternatives: I would much prefer not to make a copy of the data in the dict, because it makes up the majority of my program's memory requirements.

First question: how do you make the dictionary thread-safe so that you can update it in parallel?

Sounds like you need a lock so that other threads do not look up the dictionary while you “edit” it in parallel.
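A minimal sketch of such a lock-based version (`testfun_locked` is my own name; it is safe, but the single lock serializes every update, so it will not scale):

```julia
# Sketch of the lock-based approach: one ReentrantLock guards every read and
# write, so the Dict is never mutated concurrently.  Safe, but all updates
# are serialized, so don't expect a speed-up.
function testfun_locked(dict)
    lk = ReentrantLock()
    ks = collect(keys(dict))  # @threads needs an indexable collection
    Threads.@threads for k in ks
        lock(lk) do
            dict[k] = log(k) * sin(dict[k])
        end
    end
    return dict
end
```

Each update holds the lock for the whole read-compute-write cycle, which is exactly the waiting the original poster wants to avoid.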

Use ThreadSafeDicts.jl (GitHub: wherrera10/ThreadSafeDicts.jl, a thread-safe Julia Dict)?

Unless I’m missing something, I’d just keep it simple. On my laptop with an i9-10885H @ 5.30 GHz CPU, most of the timings were ~200 ms to process 10_000_000 pairs.

# Threads.nthreads() == 16
@inline foo(k, v) = log(k) * sin(v)

n = 10^7
dict = Dict(Base.OneTo(n) .=> rand(n))
pair_vec = collect(dict) # Threads.@threads doesn't like to work on iterators
@btime (Threads.@threads for (k, v) in $pair_vec
    dict[k] = foo(k, v)
end) setup=(dict = $dict)

217.092 ms (82 allocations: 8.19 KiB)

I don’t.

The hope was that due to the structure of the reads and writes each thread only works on its own part of the dictionary in a way that should not interfere with any of the other threads.

I did have a look at this, but unfortunately it won’t do, because the function I need to compute is very fast (not as simple as log(k) * sin(v), but almost). That means there would be a lot of waiting for the dict to unlock in each of the threads.

I was hoping to avoid copying or collecting the data in the dict and to just operate on the dict directly from multiple threads. But it looks like that might not easily be possible (at least not without digging deeper into the internals of Dict).

Look at

julia> d=Dict(3=>4); @less d[3]=4
function setindex!(h::Dict{K,V}, v0, key::K) where V where K
    v = v0 isa V ? v0 : convert(V, v0)::V
    index, sh = ht_keyindex2_shorthash!(h, key)

    if index > 0
        h.age += 1
        @inbounds h.keys[index] = key
        @inbounds h.vals[index] = v
    else
        @inbounds _setindex!(h, v, key, -index, sh)
    end

    return h
end

You have a race condition on h.age.

Now, you may not care about the value of h.age, and luckily Julia is not C, so your code is still correct apart from leaving age in a bad state. (Data races in Julia, as far as I understand, have LLVM semantics and produce undef; they do not have the side effect of being UB.)

Worse, all your cores now want exclusive ownership of the cache line where h.age resides, which wrecks performance.

You need to change your API: iterate over (index, key, value) triplets of the filled slots (use the internal skip_deleted function), and then set the value via h.vals[index] = newValue.

The reason for using the internal skip_deleted function as opposed to doing the check yourself is compatibility with older julia versions that predate the storing of hash-tags inside the slot byte. You probably should also test against julia 1.0, in order to iron that out, in case I misremember the history of the internal API.

Since you’re touching internals, don’t forget to document that future julia versions may silently break your code. This is probably fine for you, so no need to sweat it.
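A sketch of what that could look like (`update_slots!` is my own name; `Base.skip_deleted`, `idxfloor`, `keys`, and `vals` are Dict internals, so future Julia versions may break this):

```julia
# Sketch: update values in place by walking filled slots with the internal
# `Base.skip_deleted` and writing `h.vals[i]` directly: no re-hashing, and
# no write to `h.age`.  `lo`/`hi` bound the slot range, so each thread can
# be handed its own disjoint chunk of slots.
function update_slots!(f, h::Dict, lo::Int = h.idxfloor,
                       hi::Int = lastindex(h.slots))
    i = Base.skip_deleted(h, lo)  # first filled slot at or after `lo`
    while i != 0 && i <= hi
        @inbounds h.vals[i] = f(h.keys[i], h.vals[i])
        i = Base.skip_deleted(h, i + 1)
    end
    return h
end

d = Dict(2 => 0.5, 4 => 0.25)
update_slots!((k, v) -> log(k) * sin(v), d)
```

Splitting `lo:hi` across threads as in the `partition` approach above would then let each thread write only to slots in its own range.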


Thanks for the pointer to the skip_deleted function and the ability to access h.vals[index] directly! It looks like this should also save one hashing operation, because I no longer need to hash the key when setting the new value.

And yeah, I will set strict compat bounds on this since I am very much relying on the internals of Dict here.