# Full code:
# Activate the project environment in the current directory and install its dependencies.
using Pkg; Pkg.activate("."); Pkg.instantiate()
using BenchmarkTools, Random
# Fixed seed so the random data generated in `test()` is reproducible across runs.
Random.seed!(42)
"""
    weird_cumsum!(out, in, I, J)

Reference implementation: set `out[l] = sum(in[I[m]] for m in 1:J[l])` for every
`l in eachindex(J)` and return `out`.

The cumulative sum of `in[I]` is first written into `out[1:length(I)]` and then
scattered in place to `out[1:length(J)]`, so `out` needs at least
`max(length(I), length(J))` elements. `J` must be sorted, with values in
`1:length(I)`; `I` and `J` are assumed to be 1-based.
"""
function weird_cumsum!(out, in, I, J)
    # Phase 1: out[k] = in[I[1]] + ... + in[I[k]], reading the previous partial sum back.
    if !isempty(I)
        out[1] = in[I[1]]
        for k in 2:length(I)
            out[k] = out[k-1] + in[I[k]]
        end
    end
    # Phase 2: out[l] = out[J[l]] in place. A backward sweep is only safe for entries with
    # J[l] < l; when J[l] > l, slot J[l] would already be clobbered. Because J is sorted,
    # any p = J[l] > l has J[p] >= p, so it is untouched by the backward sweep and is
    # still intact when the forward sweep reaches l. Entries with J[l] == l are no-ops.
    for l in length(J):-1:1
        if J[l] < l
            out[l] = out[J[l]]
        end
    end
    # J[l] > l needs l < J[l] <= length(I), so only the first length(I) entries can qualify.
    for l in 1:min(length(J), length(I))
        if J[l] > l
            out[l] = out[J[l]]
        end
    end
    return out
end
"""
    weird_cumsum_inter!(out, in, I, J)

Set `out[l] = sum(in[I[m]] for m in 1:J[l])` for every `l in eachindex(J)` and return `out`.

Carrying the running sum in a local variable avoids reading partial sums back from memory
and makes the first loop simpler. The cumulative sum is stored in `out[1:length(I)]` and
then scattered in place, so `out` needs at least `max(length(I), length(J))` elements.
`J` must be sorted, with values in `1:length(I)`; `I` and `J` are assumed to be 1-based.
"""
function weird_cumsum_inter!(out, in, I, J)
    # Phase 1: out[k] = in[I[1]] + ... + in[I[k]]; add the term before storing.
    inter = zero(eltype(in))
    for k in eachindex(I)
        inter += in[I[k]]
        out[k] = inter
    end
    # Phase 2: out[l] = out[J[l]] in place. Backward sweep for J[l] < l, then forward
    # sweep for J[l] > l: with J sorted, any p = J[l] > l has J[p] >= p, so slot p is
    # never clobbered before it is read. Entries with J[l] == l are already correct.
    for l in length(J):-1:1
        if J[l] < l
            out[l] = out[J[l]]
        end
    end
    for l in 1:min(length(J), length(I))
        if J[l] > l
            out[l] = out[J[l]]
        end
    end
    return out
end
"""
    weird_cumsum_inter_aux!(aux, out, in, I, J)

Set `out[l] = sum(in[I[m]] for m in 1:J[l])` for every `l in eachindex(J)` and return `out`.

The cumulative sum of `in[I]` is kept in the separate buffer `aux` (at least `length(I)`
elements), which allows a simpler second loop: nothing it writes can clobber a value it
still has to read, so `J` need not be sorted. Values of `J` must lie in `eachindex(I)`.
"""
function weird_cumsum_inter_aux!(aux, out, in, I, J)
    # aux[k] = in[I[1]] + ... + in[I[k]]; add the term before storing.
    inter = zero(eltype(in))
    for k in eachindex(I)
        inter += in[I[k]]
        aux[k] = inter
    end
    for l in eachindex(J) # If you convert J to a vector of indices, you can perhaps make this even faster!
        out[l] = aux[J[l]]
    end
    return out
end
"""
    weird_cumsum_inter_aux_2!(aux, out, in, I, J)

Faster variant of `weird_cumsum_inter_aux!`: set `out[l] = sum(in[I[m]] for m in 1:J[l])`
for every `l in eachindex(J)` and return `out`, storing the cumulative sum of `in[I]` in `aux`.

Every index used by the loops is validated once up front, so the `@inbounds` loops cannot
access memory out of bounds. Throws `BoundsError` if `aux` is shorter than `I`, `out` is
shorter than `J`, `I` has an invalid index into `in`, or `J` has an invalid index into `aux`.
Values of `J` should lie in `eachindex(I)`; other slots of `aux` hold no computed sum.
"""
function weird_cumsum_inter_aux_2!(aux, out, in, I, J)
    # One validation pass per index set; this is what makes @inbounds below sound.
    checkbounds(aux, eachindex(I))
    checkbounds(out, eachindex(J))
    checkbounds(in, I)
    checkbounds(aux, J)
    inter = zero(eltype(in)) # Initialize inter with zero of the element type
    @inbounds for k in eachindex(I)
        inter += in[I[k]]
        aux[k] = inter
    end
    @inbounds for l in eachindex(J)
        out[l] = aux[J[l]]
    end
    return out
end
"""
    weird_cumsum2!(out, in, I, J)

Set `out[i] = sum(in[I[m]] for m in 1:J[i])` for every `i in eachindex(J)` and return `out`.

Loops over `J` on the outside and extends one running sum only over the positions of `I`
not yet added. No buffer for the full cumulative sum is needed, and `out` only needs
`length(J)` elements. `J` must be sorted, with values in `eachindex(I)`.
"""
function weird_cumsum2!(out, in, I, J)
    s = zero(eltype(out))
    lastj = firstindex(I) - 1  # last position of I already added into s
    for i in eachindex(J)
        j = J[i]
        for m in lastj+1:j
            s += in[I[m]]
        end
        # Equivalent to the old `for outer j` trick, but safe when the range is empty.
        lastj = max(lastj, j)
        out[i] = s
    end
    return out
end
"""
    weird_cumsum_inter_aux_incr!(aux, out, in, I, J)

Set `out[l] = sum(in[I[m]] for m in 1:J[l])` for every `l in eachindex(J)` and return `out`.
Also store the cumulative sum of `in[I]` in `aux`.

This makes better use of the problem specifics: since `J` is sorted, each prefix sum is
copied to every matching entry of `out` as soon as it is computed, in one pass over `I`.
The extra bookkeeping may make it slower than the two-loop variants, so benchmark before
relying on it. The same idea can also be written with the loop over `J` on the outside
(see `weird_cumsum2!`).

`J` must be sorted, with values in `eachindex(I)`. Entries of `out` whose `J` value
exceeds `lastindex(I)` are left untouched.
"""
function weird_cumsum_inter_aux_incr!(aux, out, in, I, J)
    inter = zero(eltype(in))
    l = firstindex(J)    # next entry of out/J still waiting for its prefix sum
    lend = lastindex(J)
    for k in eachindex(I)
        inter += in[I[k]]
        aux[k] = inter
        # J is sorted, so every pending entry asking for prefix k (or earlier) is due now.
        # The `l <= lend` guard keeps us from reading past the end of J.
        while l <= lend && J[l] <= k
            out[l] = inter
            l += 1
        end
    end
    return out
end
"""
    test()

Check every variant against the reference `weird_cumsum!` on random data, displaying one
`Bool` per variant. Then display benchmarks of `weird_cumsum_inter_aux_2!` and `weird_cumsum2!`.
"""
function test()
    # here is some random data: I picks M positions of `in`, J picks N positions of I
    N = 10000
    M = 6000 # smaller than N
    in = rand(N)
    I = sort(rand(1:N, M))
    J = sort(rand(1:M, N))
    out = similar(in)
    aux = similar(in)
    # Every kernel mutates and returns `out`. The reference must therefore be copied:
    # comparing two calls that both return `out` would compare the array with itself.
    expected = copy(weird_cumsum!(out, in, I, J))
    # Poison `out` before each run so that entries a kernel fails to write cannot match.
    function check(kernel!, args...)
        fill!(out, NaN)
        return kernel!(args...) == expected
    end
    display(check(weird_cumsum_inter!, out, in, I, J))
    display(check(weird_cumsum_inter_aux!, aux, out, in, I, J))
    display(check(weird_cumsum_inter_aux_2!, aux, out, in, I, J))
    display(check(weird_cumsum_inter_aux_incr!, aux, out, in, I, J))
    display(check(weird_cumsum2!, out, in, I, J))
    #display(@benchmark weird_cumsum!($out,$in,$I,$J))
    #display(@benchmark weird_cumsum_inter!($out,$in,$I,$J))
    #display(@benchmark weird_cumsum_inter_aux!($aux,$out,$in,$I,$J))
    display(@benchmark weird_cumsum_inter_aux_2!($aux, $out, $in, $I, $J))
    display(@benchmark weird_cumsum2!($out, $in, $I, $J))
    #display(@benchmark weird_cumsum_inter_aux_incr!($aux,$out,$in,$I,$J))
    return nothing
end
test()