I have a function that iterates over all key-value pairs of a dictionary, computes a new value from each key and value, and then updates the value at that key. Since each iteration only updates the value in place — there are no insertions or deletions — it seemed to me that this should be amenable to multi-threading.
Namely, I thought of splitting the dictionary into as many views as I have threads and then let each thread work through its view, as shown below:
# This whole DictView logic is adapted from SplittableBase.jl
"""
    DictView{D}

A non-owning view over a contiguous range of slots of the wrapped `dict`.
`firstslot` and `lastslot` are inclusive indices into the dict's internal
slot array, not keys.

NOTE(review): this relies on `Dict` internals (the slot-indexed iteration
state) — confirm against the targeted Julia version.
"""
struct DictView{D}
dict::D
firstslot::Int
lastslot::Int
end
"""
    DictView(xs::DictView, i::Int, j::Int)

Build a new view over the same underlying dict as `xs`, covering slots
`i` through `j` (inclusive).
"""
function DictView(xs::DictView, i::Int, j::Int)
    return DictView(xs.dict, i, j)
end
# Forward the element-type trait of the wrapped dict type.
Base.IteratorEltype(::Type{<:DictView{D}}) where {D} = Base.IteratorEltype(D)
# A view's element count depends on which slots happen to be occupied,
# so the length cannot be known without iterating.
Base.IteratorSize(::Type{<:DictView}) = Base.SizeUnknown()
# Elements are the same key-value pairs as the wrapped dict's.
Base.eltype(::Type{<:DictView{D}}) where {D} = eltype(D)
"""
    length(xs::DictView)

Count the elements by exhausting the iterator. `IteratorSize` is
`SizeUnknown`, so there is no cheaper way to obtain the length up front.
"""
function Base.length(xs::DictView)
    return foldl((acc, _) -> acc + 1, xs; init = 0)
end
# Note: this relies on the implementation detail of `iterate(::Dict)` —
# the iteration state is a slot index, which lets us start at
# `xs.firstslot` and cut iteration off once the state passes `xs.lastslot`.
@inline function Base.iterate(xs::DictView, i = xs.firstslot)
# State already past this view's range: done.
i <= xs.lastslot || return nothing
# Delegate to the dict's own iterator starting at slot `i`; it returns
# `nothing` when the dict's slots are exhausted.
y = iterate(xs.dict, i)
y === nothing && return nothing
x, j = y
# If `j` is `xs.lastslot + 1` or smaller, it means the current element is
# within the range of this `DictView`:
j <= xs.lastslot + 1 && return x, j
# Otherwise, we need to stop:
return nothing
end
# Not defined in SplittableBase.jl. Indexing is forwarded verbatim to the
# wrapped dict. NOTE(review): writing through a view while other views
# iterate the same dict is presumably only safe because this workload never
# inserts or deletes keys (no rehash) — confirm before reusing elsewhere.
function Base.setindex!(dv::DictView, v, k)
    return setindex!(dv.dict, v, k)
end
function Base.getindex(dv::DictView, k)
    return getindex(dv.dict, k)
end
"""
    partition(d::Dict, n::Integer) -> Vector{DictView}

Split `d` into `n` `DictView`s covering disjoint, roughly equal-sized
slot ranges, so each view can be processed independently.

Relies on `Dict` internals (`idxfloor`, `slots`) — NOTE(review): confirm
these fields exist in the targeted Julia version.
"""
function partition(d::Dict, n::Integer)
    # Renamed from `in`/`i0`: a local named `in` shadows the built-in
    # `in` function, which is confusing and fragile.
    lo = d.idxfloor          # first potentially occupied slot
    hi = lastindex(d.slots)  # last slot index
    # n + 1 evenly spaced boundaries over [lo - 1, hi]; consecutive pairs
    # delimit each view's (firstslot, lastslot) range.
    bounds = round.(Int, range(lo - 1, hi; length = n + 1))
    return [DictView(d, bounds[i] + 1, bounds[i + 1]) for i in 1:n]
end
"""
    testfun_threaded(dict, n_threads)

Replace every value `v` stored at key `k` with `log(k) * sin(v)`,
mutating `dict` in place. The work is split into `n_threads` disjoint
`DictView` partitions, each processed by its own thread.
"""
function testfun_threaded(dict, n_threads)
    chunks = partition(dict, n_threads)
    Threads.@threads for chunk in chunks
        for (key, val) in chunk
            chunk[key] = log(key) * sin(val)
        end
    end
end
"""
    testfun_unthreaded(dict)

Serial reference implementation: replace each value `v` stored at key `k`
with `log(k) * sin(v)`, mutating `dict` in place.
"""
function testfun_unthreaded(dict)
    # Updating existing keys never triggers a rehash, so iterating the
    # key set while assigning is safe.
    for key in keys(dict)
        dict[key] = log(key) * sin(dict[key])
    end
end
Benchmarking this code gives only very modest speed-ups when increasing the number of parts that are processed in parallel:
dict = Dict(2i => rand() for i in 1:10_000_000)
@time testfun_threaded(dict, 1)
# 0.897201 seconds (590 allocations: 13.641 KiB)
@time testfun_threaded(dict, 10)
# 0.585802 seconds (404 allocations: 11.016 KiB)
@time testfun_unthreaded(dict)
# 0.870252 seconds
I suspect that the naive implementations of setindex!
and getindex
that I added here lead to some concurrency issues, even though in principle this should be safe because each thread works on its own chunk of data.
As for possible alternatives, I would also much prefer not having to make a copy of the data in the dict, because that data makes up the majority of my program's memory requirements.