Quick check that sum of positive values is at most threshold

Hi, in my code I need to determine whether sum(x) <= a, where I know that the values in x are all positive, so in theory it is possible to conclude false before all of the values in x have been added together. This led me to wonder whether one can write an early-exit check of sum(x) <= a (see my sum_le below) that outperforms simply computing the sum and comparing (see my sum_le_base). A naive implementation does not seem to stand a chance, presumably because the handwritten loop is compiled less efficiently than the built-in sum function. I currently don’t need this, as I’m adding at most a few dozen numbers together in my actual application, but I was wondering whether it can be done at all, or whether the penalty from checking s > 0 in every iteration is simply too high.

using Random, Distributions, BenchmarkTools
Random.seed!(42)

"""Determines if the sum of the elements in the list is less than or equal to a given value."""
sum_le_base(x, a) = sum(x) ≤ a

"""Quick version of sum_le_base."""
function sum_le(x, a)
    s = -a 
    for xi in x
        s += xi
        if s > 0
            return false
        end
    end
    return true
end

function RunBench(N)
    X = rand(N)
    r = 0.5
    display(sum_le(X, r*N) == sum_le_base(X, r*N))
    display(@benchmark sum_le($X, $r*$N))
    display(@benchmark sum_le_base($X, $r*$N))
end
RunBench(1000)

This gives

true
BenchmarkTools.Trial: 10000 samples with 77 evaluations.
 Range (min … max):  831.701 ns …  1.951 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     833.338 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   853.481 ns ± 63.417 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  832 ns        Histogram: log(frequency) by time      1.15 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
BenchmarkTools.Trial: 10000 samples with 957 evaluations.
 Range (min … max):  89.821 ns … 239.856 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     90.343 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   93.016 ns ±   7.106 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  89.8 ns       Histogram: log(frequency) by time       122 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

This obviously depends on the statistics of your data (i.e. how often the checked loop terminates early).
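
For instance (reusing X and sum_le from your post; this is only an illustration), compare a threshold that the running sum crosses almost immediately with one it never crosses:

# Sum of 1000 uniform(0,1) values is around 500, so the first threshold
# is exceeded after a handful of elements, while the second never is
# and the whole vector gets scanned.
X = rand(1000)
@btime sum_le($X, 1.0)      # early exit after a few elements
@btime sum_le($X, 1000.0)   # no early exit, full scan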

Assuming the checked loop typically runs for quite a few iterations, you could still get the best of both worlds by just checking occasionally, e.g. every 2^k-th element for some k (whose optimal choice will depend on your data statistics).
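
An untested sketch of that idea (the name sum_le_every and the default k = 5 are just choices I made for illustration), checking only at every 2^k-th element and doing one final comparison at the end:

"""Like sum_le, but only performs the early-exit check at every 2^k-th element."""
function sum_le_every(x, a, k::Int = 5)
    s = -a                      # track sum(x) - a, so the check is s > 0
    mask = (1 << k) - 1         # i is a multiple of 2^k iff i & mask == 0
    for (i, xi) in enumerate(x)
        s += xi
        if (i & mask) == 0 && s > 0
            return false        # the values are positive, so the sum can only grow
        end
    end
    return s <= 0               # covers the elements after the last periodic check
end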

Note that there is no magic in the built-in sum function in terms of efficiency: you can get basically the same speed with a @simd for loop (although it will be less accurate for floating-point numbers, where sum uses pairwise summation).


I see, thank you for your elaborate response!

Is @simd allowed here? I checked the docs the other day and got the impression that @simd asserts that loop iterations may be reordered and overlapped, which would be a problem for the early-exit version here, right? Or did I misunderstand the docs?

I can’t come anywhere close to the performance of sum by using @simd on a for loop, but I can see how there may be some optimization under the hood that makes sum work so quickly!

using Random, Distributions, BenchmarkTools
Random.seed!(42)

"""Determines if the sum of the elements in the list is less than or equal to a given value."""
sum_le_base(x, a) = sum(x) ≤ a

"""Quick version of sum_le_base."""
function sum_le(x, a)
    s = -a 
    for xi in x
        s += xi
        if s > 0
            return false
        end
    end
    return true
end

"""Quick version of sum_le_base."""
function sum_le_direct(x, a)
    s = 0
    @inbounds @simd for i in eachindex(x)
        s += x[i]
    end
    return s <= a
end

function RunBench(N)
    X = rand(N)
    r = 0.5
    display(sum_le(X, r*N) == sum_le_base(X, r*N) == sum_le_direct(X, r*N))
    display(@benchmark sum_le($X, $r*$N))
    display(@benchmark sum_le_base($X, $r*$N))
    display(@benchmark sum_le_direct($X, $r*$N))
end
RunBench(1000)

gives

true
BenchmarkTools.Trial: 10000 samples with 77 evaluations.
 Range (min … max):  831.701 ns …  1.478 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     833.325 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   853.810 ns ± 61.965 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  832 ns        Histogram: log(frequency) by time      1.13 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
BenchmarkTools.Trial: 10000 samples with 957 evaluations.
 Range (min … max):  89.603 ns … 445.794 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     90.125 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   93.783 ns ±  10.628 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  89.6 ns       Histogram: log(frequency) by time       134 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.
BenchmarkTools.Trial: 10000 samples with 127 evaluations.
 Range (min … max):  738.843 ns …  1.263 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     739.827 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   758.865 ns ± 46.426 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  739 ns        Histogram: log(frequency) by time       963 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

This is type-unstable. Your s starts as an integer but ends up as a floating-point value (for floating-point arrays). (See the performance tips.)

If you fix this, it is as fast as sum (actually slightly faster because it avoids the slight overhead of pairwise summation):

julia> function mysum(a)
           s = zero(eltype(a))
           @simd for i in eachindex(a)
               s += a[i]
           end
           return s
       end
mysum (generic function with 1 method)

julia> a = rand(1000); @btime sum($a); @btime mysum($a);
  97.765 ns (0 allocations: 0 bytes)
  93.361 ns (0 allocations: 0 bytes)

(The @inbounds is inferred here, though it wouldn’t hurt to add it.)

It probably won’t help if you do the early-exit check (s > 0) in every loop iteration, but it will help if you do the loop in chunks.
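
A rough, untested sketch of such a chunked version (the name sum_le_chunks and the chunk size of 64 are arbitrary choices on my part):

function sum_le_chunks(x::AbstractVector{T}, a; chunk::Int = 64) where {T}
    s = zero(T)
    i = firstindex(x)
    stop = lastindex(x)
    while i <= stop
        j = min(i + chunk - 1, stop)
        @inbounds @simd for n in i:j   # SIMD-friendly sum within the chunk
            s += x[n]
        end
        s > a && return false          # early exit is valid because all x are positive
        i = j + 1
    end
    return true
end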


I see, thank you! I indeed missed the type instability there; no clue why I didn’t spot that…