Below is an example of using atomics. Note, however, that updates such as `@atomic om.x += 1` are implemented with a loop in LLVM (since neither ARMv8 nor x86 has dedicated instructions for this), so on my Mac, under contention, using atomics is much slower than using a lock.
Here is the atomic version:
"""
    OnlineMean()

Running count and sum of non-negative integer observations, kept in two
`Threads.Atomic{UInt64}` counters so that many tasks can call `update!` on the
same instance without taking a lock.
"""
mutable struct OnlineMean
    count::Threads.Atomic{UInt64}
    sum::Threads.Atomic{UInt64}
    OnlineMean() = new(Threads.Atomic{UInt64}(0), Threads.Atomic{UInt64}(0))
end

"""
    mean(om::OnlineMean) -> Float64

Return `sum / count` of the observations recorded so far (`NaN` when nothing has
been recorded, since `0 / 0` is `NaN`).

Not safe to call concurrently with `update!`: the two counters are read one after
the other, so a reader can observe an update to one counter but not the other.
"""
mean(om::OnlineMean) = om.sum[] / om.count[]

"""
    update!(om::OnlineMean, x::Integer) -> OnlineMean

Atomically record observation `x` in `om` and return `om`.

This is thread-safe. Multiple tasks can increment `om.count` before `om.sum`, but
that will not corrupt the data in `om`, because both fields will eventually be
updated correctly. Throws `InexactError` if `x` cannot be represented as a
`UInt64` (e.g. it is negative); `om` is then left unchanged.
"""
function update!(om::OnlineMean, x::Integer)
    # Convert before touching either counter so a bad `x` cannot leave `count`
    # incremented without the matching `sum`.
    ux = UInt64(x)
    # Only the atomicity of each individual add matters here, not its ordering
    # relative to other operations, so monotonic (relaxed) ordering would be
    # enough. `Threads.atomic_add!` takes no ordering argument and uses its own
    # fixed ordering, which is at least that strong.
    Threads.atomic_add!(om.count, one(UInt64))
    Threads.atomic_add!(om.sum, ux)
    return om
end
"""
    foo(nchunks::Integer=10_000, nobs::Integer=1_000) -> Float64

Demo driver for the atomic `OnlineMean`. Run `nchunks` iterations of a
`Threads.@threads` loop. Each iteration feeds `nobs` random observations, drawn
uniformly from `0:100`, into one shared accumulator. Return the resulting mean,
which is about 50 for large inputs and `NaN` when no observations are made.
"""
function foo(nchunks::Integer=10_000, nobs::Integer=1_000)
    om = OnlineMean()
    obsrange = UInt(0):UInt(100)
    # This loop is thread-safe: `update!` only touches `om` through atomic adds.
    Threads.@threads for _ in 1:nchunks
        for _ in 1:nobs
            update!(om, rand(obsrange))
        end
    end
    # NOTE: Computing the mean is NOT thread-safe concurrently with `update!`,
    # since om.count and om.sum may be updated at different times. Once the
    # `Threads.@threads` loop above has returned, every update has completed and
    # the state is consistent.
    return mean(om)
end
And here is the same accumulator protected by a lock:
"""
    OnlineMean2()

Plain running count and sum of non-negative integer observations. The fields are
updated with ordinary (non-atomic) arithmetic, so an instance shared between
tasks must be protected by a lock, e.g. by wrapping it in a `Lockable`.
"""
mutable struct OnlineMean2
    count::UInt64
    sum::UInt64
    # Integer zeros that match the `UInt64` field types.
    OnlineMean2() = new(zero(UInt64), zero(UInt64))
end

"""
    mean(om::OnlineMean2) -> Float64

Return `sum / count` of the observations recorded so far (`NaN` when nothing has
been recorded, since `0 / 0` is `NaN`).
"""
mean(om::OnlineMean2) = om.sum / om.count
"""
    update!(om::Lockable{OnlineMean2}, x::Integer) -> Lockable{OnlineMean2}

Record observation `x` in the lock-protected accumulator `om`. Both fields are
incremented while `om`'s lock is held, so readers that also take the lock never
see a half-applied update. `x` is converted to `UInt64` before the lock is taken,
so an unrepresentable `x` throws `InexactError` without touching `om`.
Returns `om`.
"""
function update!(om::Lockable{OnlineMean2}, x::Integer)
    increment = UInt64(x)
    @lock om begin
        acc = om[]
        acc.count += 1
        acc.sum += increment
    end
    return om
end
"""
    foo2(nchunks::Integer=10_000, nobs::Integer=1_000) -> Float64

Demo driver for the lock-based variant. Run `nchunks` iterations of a
`Threads.@threads` loop. Each iteration feeds `nobs` random observations, drawn
uniformly from `0:100`, into one `Lockable`-wrapped `OnlineMean2`. Return the
resulting mean, which is about 50 for large inputs and `NaN` when no
observations are made.
"""
function foo2(nchunks::Integer=10_000, nobs::Integer=1_000)
    om = Lockable(OnlineMean2())
    obsrange = UInt(0):UInt(100)
    # Thread-safe: every `update!` on the Lockable holds its lock while mutating.
    Threads.@threads for _ in 1:nchunks
        for _ in 1:nobs
            update!(om, rand(obsrange))
        end
    end
    # The wrapped value may only be accessed while holding the lock.
    return @lock om mean(om[])
end
Finally, the fastest approach for this case is to have each thread update its own separate counter and then merge the counters at the end. On my computer this is 10x faster than either of the above. Of course, the speedup depends on how expensive it is to instantiate these OnlineMean objects: in this example they're very cheap, but that might not be the case for your real-world code.
"""
    update!(om::OnlineMean2, x::Integer) -> OnlineMean2

Record observation `x` in the unsynchronised accumulator `om` and return `om`.
This is not thread-safe, so only call it on an accumulator that a single task
owns. Throws `InexactError` if `x` cannot be represented as a `UInt64`; `om` is
then left unchanged.
"""
function update!(om::OnlineMean2, x::Integer)
    # Convert first, as the other `update!` methods do. Otherwise a bad `x`
    # would throw after `count` was already incremented, leaving `count` and
    # `sum` out of sync.
    ux = UInt64(x)
    om.count += 1
    om.sum += ux
    return om
end
"""
    foo3(nchunks::Integer=10_000, nobs::Integer=1_000) -> Float64

Demo driver for the lock-free "private accumulator" strategy. Each of the
`nchunks` iterations of a `Threads.@threads` loop records `nobs` random
observations, drawn uniformly from `0:100`, into its own `OnlineMean2`, so no
accumulator is ever shared between tasks. The partial results are merged
afterwards. Returns the overall mean: about 50 for large inputs, `NaN` when no
observations are made.
"""
function foo3(nchunks::Integer=10_000, nobs::Integer=1_000)
    partials = [OnlineMean2() for _ in 1:nchunks]
    obsrange = UInt(0):UInt(100)
    Threads.@threads for i in eachindex(partials)
        acc = partials[i]
        for _ in 1:nobs
            update!(acc, rand(obsrange))
        end
    end
    # Merge into a fresh accumulator rather than into `partials[1]`. This leaves
    # the partials untouched and also works when `nchunks == 0`.
    merged = OnlineMean2()
    for acc in partials
        merged.count += acc.count
        merged.sum += acc.sum
    end
    return mean(merged)
end