Hi, in my code I need to determine whether sum(x) <= a, where I know that the values in x are all positive, so in principle I can conclude false before all the values in x have been added together. This made me wonder whether it is possible to write a function that checks sum(x) <= a (see my sum_le) and outperforms simply evaluating sum(x) <= a (see my sum_le_base). A naive implementation does not seem to stand a chance, because it is clearly less efficiently implemented than the sum function. I don't currently need this, since I'm adding at most a few dozen numbers in my actual application, but I was wondering whether it can be done at all, or whether the penalty of checking s > 0 on every iteration is too high.
```julia
using Random, Distributions, BenchmarkTools
Random.seed!(42)

"""Determines if the sum of the elements in the list is less than or equal to a given value."""
sum_le_base(x, a) = sum(x) ≤ a

"""Quick version of sum_le_base."""
function sum_le(x, a)
    # s tracks (running sum) - a; since every x is positive, once s > 0 the full sum must exceed a
    s = -a
    for xi in x
        s += xi
        if s > 0
            return false
        end
    end
    return true
end

function RunBench(N)
    X = rand(N)
    r = 0.5
    display(sum_le(X, r*N) == sum_le_base(X, r*N))
    display(@benchmark sum_le($X, $r*$N))
    display(@benchmark sum_le_base($X, $r*$N))
end

RunBench(1000)
```
This obviously depends on the statistics of your data (i.e. how often the checked loop terminates early).
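One way to check those statistics is to record, on representative data, where the running sum first exceeds the threshold. The helper below is a hypothetical sketch (exit_index is not from the thread); it allocates a cumulative-sum array, so it is meant for offline analysis, not for the hot path.

```julia
# Hypothetical diagnostic helper: returns the index at which the running sum of x
# first exceeds a, or `nothing` if it never does (i.e. sum(x) <= a).
exit_index(x, a) = findfirst(>(a), cumsum(x))

exit_index([0.5, 0.5, 0.5], 1.0)   # running sums 0.5, 1.0, 1.5, so this returns 3
exit_index([0.1, 0.2], 1.0)        # never exceeds 1.0, so this returns nothing
```

For the benchmark in the question (X = rand(N) and a = 0.5N), the expected total is about 0.5N, so the running sum typically crosses a only near the end of the array, if at all, which leaves little room for an early exit to pay off.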
Assuming the checked loop typically runs for quite a few iterations, you could still get the best of both worlds by just checking occasionally, e.g. every 2^k-th element for some k (whose optimal choice will depend on your data statistics).
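A minimal sketch of checking only once per block of 2^k elements, assuming x is an AbstractVector of real numbers. The name sum_le_chunked and the default k = 6 are illustrative assumptions, not tuned values; the inner loop contains no comparison, so @simd is free to vectorize it.

```julia
"""
    sum_le_chunked(x, a; k = 6)

Sketch: decide `sum(x) <= a`, testing the running sum only once per block of 2^k
elements. `k` is a hypothetical tuning parameter whose best value depends on the data.
"""
function sum_le_chunked(x::AbstractVector{T}, a; k::Integer = 6) where {T<:Real}
    blk = 1 << k                      # block length 2^k
    s = zero(T)                       # accumulator has the element type (type-stable)
    i = firstindex(x)
    stop = lastindex(x)
    while i <= stop
        j = min(i + blk - 1, stop)    # last index of the current block
        @inbounds @simd for m in i:j  # branch-free inner loop
            s += x[m]
        end
        s > a && return false         # early exit, checked once per block
        i = j + 1
    end
    return s <= a                     # same final comparison as sum(x) <= a
end
```

Larger k means fewer comparisons but more wasted additions after the threshold has already been crossed.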
Note that there is no magic in the built-in sum function in terms of efficiency: you can get basically the same speed with a @simd for loop (although it will be less accurate for floating-point numbers, where sum uses pairwise summation).
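To illustrate what pairwise summation means, here is a simplified recursive sketch: split the range in half until it is small, sum small blocks with a plain @simd loop, then add the two halves. The base-case size of 128 and the name pairwise_sum are assumptions for illustration; Base's actual implementation differs in its details.

```julia
# Simplified pairwise (cascade) summation sketch; not Base's exact implementation.
function pairwise_sum(x::AbstractVector{T}, lo::Int = firstindex(x),
                      hi::Int = lastindex(x)) where {T<:Real}
    if hi - lo < 128                  # small block (also handles empty ranges)
        s = zero(T)
        @inbounds @simd for i in lo:hi
            s += x[i]
        end
        return s
    end
    mid = (lo + hi) >>> 1             # split point
    return pairwise_sum(x, lo, mid) + pairwise_sum(x, mid + 1, hi)
end
```

Because rounding errors accumulate over O(log n) levels of additions instead of O(n) sequential ones, the result is usually more accurate than a single running sum, at the cost of a little extra overhead.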
Is @simd allowed here? I read the docs the other day and understood that @simd assumes the iterations can be executed in an overlapping or reordered way, which would be a problem here, right? Or did I misunderstand the docs?
I can't come anywhere close to the performance of sum by using @simd on a for loop, but I can see how there may be some optimization under the hood that makes sum work so quickly!
```julia
using Random, Distributions, BenchmarkTools
Random.seed!(42)

"""Determines if the sum of the elements in the list is less than or equal to a given value."""
sum_le_base(x, a) = sum(x) ≤ a

"""Quick version of sum_le_base."""
function sum_le(x, a)
    s = -a
    for xi in x
        s += xi
        if s > 0
            return false
        end
    end
    return true
end

"""Quick version of sum_le_base."""
function sum_le_direct(x, a)
    s = 0
    @inbounds @simd for i in eachindex(x)
        s += x[i]
    end
    return s <= a
end

function RunBench(N)
    X = rand(N)
    r = 0.5
    display(sum_le(X, r*N) == sum_le_base(X, r*N) == sum_le_direct(X, r*N))
    display(@benchmark sum_le($X, $r*$N))
    display(@benchmark sum_le_base($X, $r*$N))
    display(@benchmark sum_le_direct($X, $r*$N))
end

RunBench(1000)
```
This is type-unstable: your s starts as an integer (0) but becomes a floating-point value after the first addition (for floating-point arrays). (See the performance tips.)
If you fix this, it is as fast as sum (actually slightly faster because it avoids the slight overhead of pairwise summation):
```julia
julia> function mysum(a)
           s = zero(eltype(a))
           @simd for i in eachindex(a)
               s += a[i]
           end
           return s
       end
mysum (generic function with 1 method)

julia> a = rand(1000); @btime sum($a); @btime mysum($a);
  97.765 ns (0 allocations: 0 bytes)
  93.361 ns (0 allocations: 0 bytes)
```
(The @inbounds is inferred here, though it wouldn't hurt to add it.)
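Applying the same fix to sum_le_direct from the question gives the sketch below; sum_le_direct_stable is a hypothetical name, and the body follows mysum above with the final comparison added. Running @code_warntype on the original sum_le_direct in the REPL should flag s as a Union of Int and Float64, while this version should not.

```julia
"""Type-stable variant of sum_le_direct (hypothetical name)."""
function sum_le_direct_stable(x, a)
    s = zero(eltype(x))          # e.g. 0.0 for a Vector{Float64}, so s keeps a single type
    @simd for i in eachindex(x)  # bounds checks can be elided when iterating eachindex(x)
        s += x[i]
    end
    return s <= a
end
```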
@simd probably won't help if you have the s > a check in every loop iteration, but it will help if you do the loop in chunks.
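A sketch of the chunked approach that reuses Base sum on each chunk, assuming x is an AbstractVector. The name sum_le_blocks and the default blocksize = 256 are untuned assumptions; Iterators.partition splits the index range into consecutive blocks, and view avoids copying.

```julia
# Hypothetical sketch: sum fixed-size blocks with Base `sum`, test after each block.
function sum_le_blocks(x::AbstractVector, a; blocksize::Integer = 256)
    s = zero(eltype(x))                  # type-stable accumulator
    for block in Iterators.partition(eachindex(x), blocksize)
        s += sum(view(x, block))         # fast (SIMD, pairwise) sum of one block
        s > a && return false            # one comparison per block, not per element
    end
    return s <= a                        # same final comparison as sum(x) <= a
end
```

Within each block this keeps Base's pairwise accuracy; across blocks the partial sums are accumulated sequentially.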