Hello, I just tried your sum from the PR, and I’ve generalized your ideas using @generated functions: a configurable number of accumulators (covering your Base.jl and reduce_mb_4 versions), plus tree reduction within the loop itself (covering reduce_mb_3).
Here is my attempt:
```julia
# Build a balanced reduction tree over the M variables
# prefix_N, prefix_(N+stride), ..., prefix_(N+(M-1)*stride).
macro tree_reduce(op, prefix, N, M, stride)
    prefix_str = string(prefix)
    function build_tree(indices)
        count = length(indices)
        if count == 1
            return Symbol(prefix_str, :_, indices[1])
        elseif count == 2
            return :($op($(Symbol(prefix_str, :_, indices[1])),
                         $(Symbol(prefix_str, :_, indices[2]))))
        else
            # Split in half and reduce each half recursively.
            mid = count ÷ 2
            left = build_tree(indices[1:mid])
            right = build_tree(indices[mid+1:end])
            return :($op($left, $right))
        end
    end
    # Generate strided indices: N, N+stride, N+2*stride, ..., N+(M-1)*stride
    indices = [N + i * stride for i in 0:M-1]
    return esc(build_tree(indices))
end
```
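For example, reducing the four variables a_1, a_3, a_5, a_7 (start 1, count 4, stride 2) yields a balanced tree rather than a left-to-right chain; the expansion looks roughly like this (exact printing may differ):

```julia
julia> @macroexpand @tree_reduce(+, a, 1, 4, 2)
:((a_1 + a_3) + (a_5 + a_7))
```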
```julia
using Base.Cartesian: @nexprs

@generated function reduce_abstracted(op, A, ::Val{Ntot}, ::Val{Nacc}) where {Ntot,Nacc}
    Pacc = trailing_zeros(Nacc)  # currently unused
    Ptot = trailing_zeros(Ntot)  # Ntot must be a power of two: n >> Ptot == n ÷ Ntot
    quote
        f = identity  # placeholder for a mapping function, as in mapreduce
        inds = eachindex(A)
        i1, iN = firstindex(inds), lastindex(inds)
        n = length(inds)
        # Load the first batch of Ntot elements ...
        @nexprs $Ntot N -> a_N = @inbounds A[inds[i1+(N-1)]]
        # ... and fold them into Nacc accumulators via strided tree reduction.
        @nexprs $Nacc N -> v_N = @tree_reduce(op, a, N, $(cld(Ntot, Nacc)), $Nacc)
        # Main loop: one batch of Ntot loads per iteration.
        for batch in 1:(n>>$Ptot)-1
            i = 1 + (batch << $Ptot)
            @nexprs $Ntot N -> begin
                a_N = @inbounds A[inds[i+(N-1)]]
            end
            @nexprs $Nacc N -> v_N = op(v_N, @tree_reduce(op, a, N, $(cld(Ntot, Nacc)), $Nacc))
        end
        # Combine the Nacc accumulators into a single value.
        v = @tree_reduce(op, v, 1, $Nacc, 1)
        # Scalar tail loop for the n % Ntot remaining elements.
        i = i1 + (n >> $Ptot) * $Ntot - 1
        i == iN && return v
        for i in i+1:iN
            ai = @inbounds A[inds[i]]
            v = op(v, f(ai))
        end
        return v
    end
end
```
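As a quick sanity check (my own snippet, not from the PR): the result should match sum up to floating-point rounding, since the association order differs:

```julia
using Test

a = rand(Float64, 2^15)
# Different association orders give different rounding, hence isapprox:
@test reduce_abstracted(+, a, Val(16), Val(4)) ≈ sum(a)
@test reduce_abstracted(+, view(a, 1:2:2^15), Val(8), Val(2)) ≈ sum(view(a, 1:2:2^15))
```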
The Base.sum from your PR corresponds to Nacc=8 (8 accumulators) and Ntot=8 (no internal tree reduction). Your mb_3 function corresponds to Ntot=16 and Nacc=4.
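Concretely, in terms of my function (the PR variant names in the comments are just for reference):

```julia
reduce_abstracted(+, a, Val(8),  Val(8))   # ~ the PR's Base.sum: 8 loads, 8 accumulators
reduce_abstracted(+, a, Val(16), Val(4))   # ~ reduce_mb_3: 16 loads into 4 accumulators with in-loop tree
```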
Here are my benchmarks:
```julia
using BenchmarkTools

n = 2^15
a = rand(Float64, n)
for Ntot in (4, 8, 16, 32), Nacc in (1, 2, 4, 8)
    if Nacc <= Ntot
        println("Ntot=$Ntot, Nacc=$Nacc")
        @btime reduce_abstracted(+, $a, $(Val(Ntot)), $(Val(Nacc)))                 # contiguous
        @btime reduce_abstracted(+, $(view(a, 1:2:n)), $(Val(Ntot)), $(Val(Nacc)))  # strided view
    end
end
```
Results for n = 32768 (all runs: 0 allocations, 0 bytes):

| Ntot | Nacc | `a` (contiguous) | `view(a, 1:2:n)` |
|-----:|-----:|-----------------:|-----------------:|
|    4 |    1 |         3.294 μs |         2.475 μs |
|    4 |    2 |         3.405 μs |         2.486 μs |
|    4 |    4 |         3.649 μs |         2.449 μs |
|    8 |    1 |         4.808 μs |         2.130 μs |
|    8 |    2 |         2.396 μs |         2.181 μs |
|    8 |    4 |         2.147 μs |         4.372 μs |
|    8 |    8 |         3.427 μs |         2.225 μs |
|   16 |    1 |         4.323 μs |         2.227 μs |
|   16 |    2 |         2.061 μs |         2.123 μs |
|   16 |    4 |         1.739 μs |         4.345 μs |
|   16 |    8 |         2.468 μs |         4.318 μs |
|   32 |    1 |         4.236 μs |         2.214 μs |
|   32 |    2 |         2.270 μs |         2.369 μs |
|   32 |    4 |         1.348 μs |         3.978 μs |
|   32 |    8 |         1.651 μs |         3.959 μs |
For comparison, here is the Base.sum from your PR on Float64:

```julia
julia> @btime Base.sum($a)
  3.459 μs (0 allocations: 0 bytes)
16295.798570780385

julia> @btime Base.sum($(view(a, 1:2:n)))
  4.040 μs (0 allocations: 0 bytes)
8128.148967556634
```
This differs from what I got, I guess because of the additional pairwise reduction.
Here is what the current Base.sum does on Float64 (same two calls):

```
  2.488 μs (0 allocations: 0 bytes)
16410.273204764926
  6.539 μs (0 allocations: 0 bytes)
8170.225306474282
```
And just for completeness, the results from @simd (on Float32 this time):

```julia
a = rand(Float32, 2^15)

function simdsum(a::AbstractArray{T}) where T
    s = zero(T)  # neutral element of + for T
    @simd for x in a
        s += x
    end
    return s
end
```

```julia
julia> @btime simdsum($a)
  1.058 μs (0 allocations: 0 bytes)
16345.918f0

julia> @btime simdsum($(view(a, 1:2:2^15)))
  6.549 μs (0 allocations: 0 bytes)
8140.3384f0
```
On my Dell machine, the current sum implementation underperforms for non-view floating-point arrays. I’m curious whether you observe similar trends on your system; for me the best settings are:
- Val(16), Val(2): Good all-around performance across array types
- Val(32), Val(4): Optimal performance specifically for contiguous bitstype arrays
I’ve found that @simd performs poorly on views and larger element types, but on my machine it’s unbeatable for contiguous bitstype arrays, with one exception: very small arrays benefit more from vectorized loads followed by a reduction.
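A minimal sketch of what I mean by "vectorized loads followed by a reduction" (my own illustration, fixed length 8, 1-based indexing assumed; real code would need remainder handling):

```julia
# Load a fixed-size chunk as a tuple (which the compiler can turn into one
# wide vector load) and reduce it, instead of running a scalar @simd loop.
function sum8(a::AbstractVector{T}) where {T}
    length(a) >= 8 || throw(ArgumentError("need at least 8 elements"))
    t = ntuple(i -> @inbounds(a[i]), Val(8))
    return reduce(+, t)
end
```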
For context, I started a discussion on this topic a few days ago (before discovering this thread).
I also documented my initial attempt to outperform mapreduce on contiguous bitstype arrays using @simd:
https://epilliat.github.io/coding/notebooks/mapreduce_cpu_perf.html
A key finding: alignment is critical for strong @simd performance, at least on my hardware. The current Base.sum implementation doesn’t properly initialize accumulators for 1024-element blocks, which significantly degrades performance.
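To illustrate the alignment point (a hypothetical helper of mine, not from Base): you can check how far an array's data pointer is from a vector-width boundary, and peel a few scalar iterations so the vectorized loop starts on an aligned address:

```julia
# Byte offset of the data pointer from a 32-byte (AVX2 vector) boundary;
# 0 means the first element is already vector-aligned.
misalignment(a::Array{Float64}) = Int(UInt(pointer(a)) % 32)

# Number of scalar iterations to peel before a 4-wide Float64 vector loop:
peel_count(a::Array{Float64}) = (4 - misalignment(a) ÷ sizeof(Float64)) % 4
```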