A weird operation to optimize

Hey,

The current performance bottleneck of my application is the following operation:

function weird_cumsum!(out, in, I, J)
    # Equivalent to `out .= cumsum(in[I])[J]`, but in place and
    # hopefully a bit faster.
    #
    # Assumptions (stated by the problem):
    #   * the index vectors I and J are non-decreasing,
    #   * length(J) == length(in) == length(out) > length(I),
    #   * any precomputation depending only on I and J may be treated as free.

    # Phase 1: prefix sums of in[I], stored in the front of `out`.
    out[1] = in[I[1]]
    for idx in 2:length(I)
        out[idx] = out[idx - 1] + in[I[idx]]
    end
    # Phase 2: backward gather through J. Running backwards matters:
    # with J[l] <= l (per the discussion below), each read still hits a
    # prefix sum that has not yet been overwritten.
    for pos in reverse(eachindex(J))
        out[pos] = out[J[pos]]
    end
end

# here is some random data: 
N = 10000
M = 6000 # smaller than N
in = rand(N)              # NOTE: `in` shadows Base.in in this scope
I =  sort(rand(1:N , M))  # non-decreasing indices into `in`
J = sort(rand(1:M, N))    # non-decreasing indices into cumsum(in[I])
out = similar(in)
weird_cumsum!(out,in,I,J)

With any precomputation on I, J that you like, can this be optimized further? I am not sure it can be reduced to one loop only, but maybe there is still smart stuff to do here.

1 Like

This is 2.5 times faster on my device already:

"""
Compute `out .= cumsum(in[I])[J]` in place, carrying the running sum in a
local variable so the prefix-sum loop never re-reads `out` from memory.

Fixes an off-by-one in the original: `inter` was seeded with `in[I[1]]` *and*
then incremented by `in[I[k]]` (rather than `in[I[k+1]]`) after storing, so
`in[I[1]]` was counted twice and every later prefix sum was shifted by one
term. (The forum's `==` check did not catch this because both sides returned
the same `out` array.)
"""
function weird_cumsum_inter!(out, in, I, J)
    inter = zero(eltype(out))
    for k in eachindex(I)
        inter += in[I[k]]   # inter == sum(in[I[1:k]])
        out[k] = inter
    end
    # Backward gather: J is non-decreasing with J[l] <= l, so each read
    # hits a prefix sum that has not yet been overwritten.
    for l in length(J):-1:1
        out[l] = out[J[l]]
    end
    return out
end

but let me play around a little bit more!

I’ve tried a few things so far that seem to speed up the computations further:

using Pkg; Pkg.activate("."); Pkg.instantiate()

using BenchmarkTools, Random
Random.seed!(42)

"""Baseline: compute `out .= cumsum(in[I])[J]` in place, without temporaries."""
function weird_cumsum!(out, in, I, J)
    # Prefix sums of in[I] go into out[1:length(I)].
    out[1] = in[I[1]]
    for i in 2:length(I)
        out[i] = out[i-1] + in[I[i]]
    end
    # Backward gather through J (J is non-decreasing with J[l] <= l,
    # so reads stay ahead of writes).
    l = length(J)
    while l >= 1
        out[l] = out[J[l]]
        l -= 1
    end
    return out
end

"""By carrying the intermediate sum in a variable, I think we avoid reading some values from memory + the first loop becomes simpler."""
"""
By carrying the intermediate sum in a variable, we avoid re-reading values
from memory and the first loop becomes simpler.

Fixed version: the original seeded `inter` with `in[I[1]]` and then added
`in[I[k]]` (not `in[I[k+1]]`) after storing, double-counting `in[I[1]]` and
shifting every subsequent prefix sum by one term.
"""
function weird_cumsum_inter!(out, in, I, J)
    inter = zero(eltype(out))
    for k in eachindex(I)
        inter += in[I[k]]   # inter == sum(in[I[1:k]])
        out[k] = inter
    end
    # Backward gather: J[l] <= l guarantees reads hit untouched prefix sums.
    for l in length(J):-1:1
        out[l] = out[J[l]]
    end
    return out
end

"""Using an auxiliary variable to save the cumulative sum allows a nicer second for-loop. """
"""
Compute `out .= cumsum(in[I])[J]` using the scratch buffer `aux`
(length >= length(I)) for the prefix sums, which lets the gather loop run
forward instead of backward.

Fixes the same off-by-one as in `weird_cumsum_inter!`: the original counted
`in[I[1]]` twice and dropped the final term, shifting every prefix sum.
"""
function weird_cumsum_inter_aux!(aux, out, in, I, J)
    inter = zero(eltype(out))
    for k in eachindex(I)
        inter += in[I[k]]
        aux[k] = inter      # aux[k] == sum(in[I[1:k]])
    end
    # Forward gather is safe because `aux` and `out` are distinct buffers.
    for l in eachindex(J)
        out[l] = aux[J[l]]
    end
    return out
end

"""
This makes better use of the problem specifics but I don't think the added overhead easily outperforms the previous one. 
(This is currently slightly slower for me!)
I do think it is possible to do the same thing but looping over J on the outside, but again I'm not sure if that is faster.
"""
"""
Fused single pass: build each prefix sum of `in[I]` and immediately write it
to every `out` slot whose `J` entry points at it. Requires `J` non-decreasing
with values in `1:length(I)`. `aux` is still filled (kept for interface
compatibility with the other variants).

Fixes two defects in the original: the accumulator double-counted `in[I[1]]`
(same off-by-one as the other `inter` variants), and `break`-ing before the
inner `while` left every `out[l]` with `J[l] == length(I)` unwritten while
`j_index` could also run past the end of `J`.
"""
function weird_cumsum_inter_aux_incr!(aux, out, in, I, J)
    inter = zero(eltype(out))
    j = firstindex(J)
    for k in eachindex(I)
        inter += in[I[k]]   # inter == sum(in[I[1:k]])
        aux[k] = inter
        # Emit the current prefix sum for every J entry equal to k.
        while j <= lastindex(J) && J[j] == k
            out[j] = inter
            j += 1
        end
    end
    return out
end


"""
Correctness check + benchmarks on a random instance satisfying the stated
invariants: `length(J) == length(in) == length(out) > length(I)`, I and J
non-decreasing.
"""
function test()
    # here is some random data:
    N = 10000
    M = 6000 # smaller than N
    in = rand(N)
    I = sort(rand(1:N , M))
    J = sort(rand(1:M, N))
    out = similar(in)
    aux = similar(in)

    # All variants return the SAME shared `out` buffer, so comparing their
    # return values directly (`f(out,...) == g(out,...)`) compares an array
    # with itself and always prints `true`. Snapshot each result with `copy`.
    ref = copy(weird_cumsum!(out, in, I, J))
    display(ref == copy(weird_cumsum_inter!(out, in, I, J)))
    display(ref == copy(weird_cumsum_inter_aux!(aux, out, in, I, J)))
    display(ref == copy(weird_cumsum_inter_aux_incr!(aux, out, in, I, J)))

    display(@benchmark weird_cumsum!($out, $in, $I, $J))
    display(@benchmark weird_cumsum_inter!($out, $in, $I, $J))
    display(@benchmark weird_cumsum_inter_aux!($aux, $out, $in, $I, $J))
    display(@benchmark weird_cumsum_inter_aux_incr!($aux, $out, $in, $I, $J))
end
test()

which gives:

true
true
true
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  26.208 ΞΌs … 81.375 ΞΌs  β”Š GC (min … max): 0.00% … 0.00%
 Time  (median):     26.375 ΞΌs              β”Š GC (median):    0.00%
 Time  (mean Β± Οƒ):   26.629 ΞΌs Β±  1.966 ΞΌs  β”Š GC (mean Β± Οƒ):  0.00% Β± 0.00%

  β–…β–ˆβ–„          ▁                                              ▁
  β–ˆβ–ˆβ–ˆβ–„β–„β–„β–„β–„β–„β–…β–…β–…β–ˆβ–ˆβ–‡β–†β–…β–ƒβ–…β–„β–…β–†β–ˆβ–‡β–„β–…β–ƒβ–ƒβ–…β–‡β–…β–…β–β–ƒβ–…β–„β–ƒβ–„β–β–…β–…β–†β–†β–„β–β–„β–ƒβ–„β–β–…β–…β–„β–„β–…β–„β–„β–…β–ƒβ–„ β–ˆ
  26.2 ΞΌs      Histogram: log(frequency) by time      33.2 ΞΌs <

 Memory estimate: 0 bytes, allocs estimate: 0.
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  10.791 ΞΌs … 51.292 ΞΌs  β”Š GC (min … max): 0.00% … 0.00%
 Time  (median):     11.208 ΞΌs              β”Š GC (median):    0.00%
 Time  (mean Β± Οƒ):   11.300 ΞΌs Β±  1.492 ΞΌs  β”Š GC (mean Β± Οƒ):  0.00% Β± 0.00%

  ▁▆▄  β–ˆβ–†β–‚ β–„β–‚                                                 β–‚
  β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‡β–…β–…β–…β–†β–†β–…β–„β–ƒβ–„β–…β–…β–ƒβ–…β–ƒβ–„β–ƒβ–β–β–„β–β–„β–ƒβ–„β–β–„β–β–„β–„β–…β–„β–„β–β–„β–ƒβ–β–β–„β–„β–ƒβ–„β–β–ƒβ–β–„β–ƒβ–…β–„ β–ˆ
  10.8 ΞΌs      Histogram: log(frequency) by time      15.1 ΞΌs <

 Memory estimate: 0 bytes, allocs estimate: 0.
BenchmarkTools.Trial: 10000 samples with 3 evaluations per sample.
 Range (min … max):  8.430 ΞΌs …  22.944 ΞΌs  β”Š GC (min … max): 0.00% … 0.00%
 Time  (median):     8.695 ΞΌs               β”Š GC (median):    0.00%
 Time  (mean Β± Οƒ):   8.791 ΞΌs Β± 699.367 ns  β”Š GC (mean Β± Οƒ):  0.00% Β± 0.00%

  ▄▁   β–ˆβ–ˆβ–ƒ  β–‚β–ƒ                                                β–‚
  β–ˆβ–ˆβ–ƒβ–„β–β–ˆβ–ˆβ–ˆβ–†β–ƒβ–ˆβ–ˆβ–…β–†β–…β–…β–…β–…β–‡β–„β–ƒβ–„β–…β–…β–„β–…β–„β–„β–†β–†β–‡β–…β–…β–…β–†β–„β–†β–…β–…β–„β–†β–†β–…β–…β–†β–„β–„β–…β–…β–„β–…β–„β–β–ƒβ–…β–ƒβ–„β–…β–„ β–ˆ
  8.43 ΞΌs      Histogram: log(frequency) by time      11.1 ΞΌs <

 Memory estimate: 0 bytes, allocs estimate: 0.
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  18.750 ΞΌs … 64.667 ΞΌs  β”Š GC (min … max): 0.00% … 0.00%
 Time  (median):     19.625 ΞΌs              β”Š GC (median):    0.00%
 Time  (mean Β± Οƒ):   19.841 ΞΌs Β±  1.696 ΞΌs  β”Š GC (mean Β± Οƒ):  0.00% Β± 0.00%

             ▁▁ ▁    β–†β–ˆβ–„β–                                      
  β–‚β–ƒβ–…β–…β–„β–ƒβ–‚β–‚β–„β–†β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–†β–…β–‡β–ˆβ–ˆβ–ˆβ–ˆβ–‡β–„β–ƒβ–ƒβ–ƒβ–ƒβ–ƒβ–„β–„β–„β–ƒβ–ƒβ–‚β–ƒβ–„β–…β–„β–ƒβ–ƒβ–ƒβ–ƒβ–ƒβ–ƒβ–ƒβ–‚β–ƒβ–ƒβ–ƒβ–‚β–‚β–‚β–ƒβ–‚β–β–β– β–ƒ
  18.8 ΞΌs         Histogram: frequency by time        21.2 ΞΌs <

 Memory estimate: 0 bytes, allocs estimate: 0.

i.e. on my device, this brings it down to 8 ΞΌs from 26. This is the best I can come up with for now!

Is this correct?

Without J having specific structure, this could cause weird duplication of entries?

@Dan Yep its correct, J has a specific structure (in fact J[k] <= k \forall k)

@JADekker Thanks for the experiments! Let me take a bit of time to process them..

1 Like

Just to clarify, if length(in) = 10, then J could be something like:

J = [ 1, 1, 1, 2, 2, 3, 4, 4, 5, 6]

Essentially, a subsequence of cumsum-ed in. In this case, running through J from the start and calculating the cumsum up to the next required index, without ever writing intermediates to memory should be faster.

1 Like

This version is a bit more idiomatic:

"""An even faster function"""
"""Prefix-sum `in[I]` into `aux`, then gather `aux[J]` into `out`."""
function weird_cumsum_inter_aux_2!(aux, out, in, I, J)
    # Accumulate-then-store: total is the running sum of in[I].
    total = zero(eltype(in))
    for (k, src) in enumerate(I)
        total += in[src]
        aux[k] = total
    end
    # out[l] = aux[J[l]] for every l.
    map!(j -> aux[j], out, J)
    return out
end

but has similar runtime for me

1 Like

Yes that is the case. In fact J[end] is about half as large as length(J) == length(in).

1 Like

So definitely no need to calculate cumsum beyond entry I[J[end]]. That’s a 2x speedup right there.

1 Like

This is another version:

"""
Compute `out[i] = cumsum(in[I])[J[i]]` in a single pass over `J`, extending
the running sum only as far as each `J[i]` requires — so the prefix sums are
never computed past `I[J[end]]`. Requires `J` non-decreasing with values in
`1:length(I)`.

Changes vs. the original: `lastj` is tracked directly as `J[i]`, removing the
fragile `local j` / `for outer j` pair (which leaves `j` undefined if the
first inner range is ever empty), and the function now returns `out` like the
sibling variants (the original implicitly returned `nothing`, which made the
thread's `==` comparison against it meaningless).
"""
function weird_cumsum2!(out, in, I, J)
    s = zero(eltype(out))
    lastj = 0
    for i in eachindex(J)
        # Extend the prefix sum of in[I] up to index J[i].
        for j in lastj+1:J[i]
            s += in[I[j]]
        end
        lastj = J[i]
        out[i] = s
    end
    return out
end

@JADekker , if you benchmark it against your versions, we can find if it is faster.

It’s slightly slower on my device, but I don’t get the same results as the other functions

Anyway:

"""
One pass over `J`: extend the running sum of `in[I]` only up to each `J[i]`
(never past `I[J[end]]`) and store it. Requires `J` non-decreasing with
values in `1:length(I)`.

Fixed vs. the original: `lastj = J[i]` replaces the fragile
`local j` / `for outer j` pattern, and `out` is returned for consistency with
the other variants (the original returned `nothing`).
"""
function weird_cumsum2!(out, in, I, J)
    s = zero(eltype(out))
    lastj = 0
    for i in eachindex(J)
        for j in lastj+1:J[i]
            s += in[I[j]]
        end
        lastj = J[i]
        out[i] = s
    end
    return out
end

gives

BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):   9.041 ΞΌs … 74.625 ΞΌs  β”Š GC (min … max): 0.00% … 0.00%
 Time  (median):      9.375 ΞΌs              β”Š GC (median):    0.00%
 Time  (mean Β± Οƒ):   10.172 ΞΌs Β±  2.756 ΞΌs  β”Š GC (mean Β± Οƒ):  0.00% Β± 0.00%

  β–ˆβ–†β–ˆβ–ƒ                     β–‚β–„   β–„β–ƒ                            ▁
  β–ˆβ–ˆβ–ˆβ–ˆβ–‡β–ˆβ–ˆβ–‡β–†β–†β–…β–„β–…β–„β–†β–…β–…β–…β–…β–…β–…β–‡β–‡β–…β–‡β–ˆβ–ˆβ–‡β–†β–†β–ˆβ–ˆβ–‡β–…β–„β–…β–†β–…β–„β–ƒβ–„β–„β–ˆβ–ˆβ–…β–‚β–…β–„β–‚β–ƒβ–„β–„β–„β–ƒβ–„β–…β–‚β–‚β–„ β–ˆ
  9.04 ΞΌs      Histogram: log(frequency) by time      18.3 ΞΌs <

 Memory estimate: 0 bytes, allocs estimate: 0.
1 Like

Full code:

using Pkg; Pkg.activate("."); Pkg.instantiate()

using BenchmarkTools, Random
Random.seed!(42)

"""Reference implementation: `out .= cumsum(in[I])[J]` in place, no temporaries."""
function weird_cumsum!(out, in, I, J)
    # Phase 1: prefix sums of in[I] into out[1:length(I)].
    acc = in[I[1]]
    out[1] = acc
    for i in 2:length(I)
        acc += in[I[i]]
        out[i] = acc
    end
    # Phase 2: backward gather — with J non-decreasing and J[l] <= l,
    # each read hits a prefix sum not yet overwritten.
    for l in reverse(eachindex(J))
        out[l] = out[J[l]]
    end
    return out
end

"""By carrying the intermediate sum in a variable, I think we avoid reading some values from memory + the first loop becomes simpler."""
"""
Carry the intermediate sum in a variable so the prefix-sum loop avoids
re-reading `out` from memory.

Fixed: the original seeded `inter` with `in[I[1]]` and then incremented by
`in[I[k]]` (instead of `in[I[k+1]]`) after storing, double-counting
`in[I[1]]` and shifting every later prefix sum by one term.
"""
function weird_cumsum_inter!(out, in, I, J)
    inter = zero(eltype(out))
    for k in eachindex(I)
        inter += in[I[k]]   # inter == sum(in[I[1:k]])
        out[k] = inter
    end
    # Backward gather: J[l] <= l keeps reads ahead of writes.
    for l in length(J):-1:1
        out[l] = out[J[l]]
    end
    return out
end

"""Using an auxiliary variable to save the cumulative sum allows a nicer second for-loop. """
"""
Use the scratch buffer `aux` (length >= length(I)) for the prefix sums so the
gather loop can run forward.

Fixed: the original accumulation double-counted `in[I[1]]` and never added
the last term, shifting every prefix sum by one.
"""
function weird_cumsum_inter_aux!(aux, out, in, I, J)
    inter = zero(eltype(out))
    for k in eachindex(I)
        inter += in[I[k]]
        aux[k] = inter      # aux[k] == sum(in[I[1:k]])
    end
    # Forward gather is safe: aux and out are distinct buffers.
    for l in eachindex(J)
        out[l] = aux[J[l]]
    end
    return out
end

"""An even faster function"""
function weird_cumsum_inter_aux_2!(aux, out, in, I, J)
    inter = zero(eltype(in)) # Initialize inter with zero of the element type
    @inbounds for k in eachindex(I)
        inter += in[I[k]]
        aux[k] = inter
    end
    @inbounds for l in eachindex(J)
        out[l] = aux[J[l]]
    end
    return out
end

"""
Single pass over `J`, extending the running sum of `in[I]` only as far as
each `J[i]` requires (never past `I[J[end]]`). Requires `J` non-decreasing
with values in `1:length(I)`.

Fixed vs. the original: `lastj = J[i]` replaces the fragile
`local j` / `for outer j` pattern, and `out` is returned for consistency with
the sibling variants (the original returned `nothing`, which broke the
equality check in `test()`).
"""
function weird_cumsum2!(out, in, I, J)
    s = zero(eltype(out))
    lastj = 0
    for i in eachindex(J)
        # Extend the prefix sum of in[I] up to index J[i].
        for j in lastj+1:J[i]
            s += in[I[j]]
        end
        lastj = J[i]
        out[i] = s
    end
    return out
end

"""
This makes better use of the problem specifics but I don't think the added overhead easily outperforms the previous one. 
(This is currently slightly slower for me!)
I do think it is possible to do the same thing but looping over J on the outside, but again I'm not sure if that is faster.
"""
"""
Fused single pass: build each prefix sum of `in[I]` and immediately write it
to every `out` slot whose `J` entry points at it. Requires `J` non-decreasing
with values in `1:length(I)`. `aux` is still filled (kept for interface
compatibility).

Fixes two defects in the original: the accumulator double-counted `in[I[1]]`
(the shared off-by-one), and `break`-ing before the inner `while` left every
`out[l]` with `J[l] == length(I)` unwritten while `j_index` could also
overrun `J`.
"""
function weird_cumsum_inter_aux_incr!(aux, out, in, I, J)
    inter = zero(eltype(out))
    j = firstindex(J)
    for k in eachindex(I)
        inter += in[I[k]]   # inter == sum(in[I[1:k]])
        aux[k] = inter
        # Emit the current prefix sum for every J entry equal to k.
        while j <= lastindex(J) && J[j] == k
            out[j] = inter
            j += 1
        end
    end
    return out
end


"""
Correctness check + benchmarks on a random instance satisfying the stated
invariants: `length(J) == length(in) == length(out) > length(I)`, I and J
non-decreasing.
"""
function test()
    # here is some random data:
    N = 10000
    M = 6000 # smaller than N
    in = rand(N)
    I = sort(rand(1:N , M))
    J = sort(rand(1:M, N))
    out = similar(in)
    aux = similar(in)

    # Every variant mutates (and most return) the SAME shared `out` buffer,
    # so `f(out,...) == g(out,...)` would compare an array with itself and
    # always print `true`. Snapshot each result with `copy` instead.
    # (`weird_cumsum2!` returned `nothing` originally, which also made its
    # direct comparison meaningless.)
    ref = copy(weird_cumsum!(out, in, I, J))
    display(ref == copy(weird_cumsum_inter!(out, in, I, J)))
    display(ref == copy(weird_cumsum_inter_aux!(aux, out, in, I, J)))
    display(ref == copy(weird_cumsum_inter_aux_2!(aux, out, in, I, J)))
    display(ref == copy(weird_cumsum_inter_aux_incr!(aux, out, in, I, J)))
    weird_cumsum2!(out, in, I, J)
    display(ref == out)

    #display(@benchmark weird_cumsum!($out,$in,$I,$J))
    #display(@benchmark weird_cumsum_inter!($out,$in,$I,$J))
    #display(@benchmark weird_cumsum_inter_aux!($aux,$out,$in,$I,$J))
    display(@benchmark weird_cumsum_inter_aux_2!($aux,$out,$in,$I,$J))
    display(@benchmark weird_cumsum2!($out,$in,$I,$J))
    #display(@benchmark weird_cumsum_inter_aux_incr!($aux,$out,$in,$I,$J))
end

test()

Adapting your solutions to my application (this MWE is not exactly what I needed, but whatever) I do get a significant improvement. This function moved from 80% of runtime to 20%, which is great. Thanks a lot :slight_smile:

3 Likes

This is another version, in an attempt to make the code neater:

"""
Neater formulation: compute only the prefix sums that `J` actually reads
(up to `J[end]`) with `cumsum!`, then gather in place.

Fixed: the original ended with `@views reverse(out) .= out[reverse(J)]`, but
`@views` only rewrites indexing expressions (`A[...]`) — `reverse(out)` is an
ordinary call that allocates a reversed *copy*, so the broadcast wrote into a
temporary that was immediately discarded and `out` was never gathered (this
was also the source of the mystery allocations). The backward loop below does
the gather correctly: with `J` non-decreasing and `J[l] <= l`, each read hits
a prefix sum not yet overwritten.
"""
function weird_cumsum3!(out, in, I, J)
    # Only the first J[end] prefix sums are ever read through J.
    Jrng = 1:J[end]
    @views out[Jrng] .= in[I[Jrng]]
    @views cumsum!(out[Jrng], out[Jrng])
    # Backward in-place gather (cannot run forward: earlier slots would be
    # overwritten before later slots read them).
    for l in lastindex(J):-1:firstindex(J)
        out[l] = out[J[l]]
    end
    return nothing
end

It suffers from some tiny allocations. Not sure where they stem from. But is still quite fast. Perhaps for large lengths, cumsum! might have optimizations.