Thanks @pitsianis and @mikmoore – the algorithmic improvements are exactly what I was trying to do but couldn’t quite figure out. Breaking down the problem into parts helped a lot.
I made a running sum of `a` for the `b` contribution and made a fast dot product with LoopVectorization for the `c` contribution (maybe there are better convolution tools?). `foo4!` is now ~54% faster than `foo1!`.
# Accumulate the `b` and `c` contributions into `out` in place.
#
# For every i in 1:N (N = length(out)):
#   out[i] += b[i] * sum(a[max(i, j_min):end])               (b contribution)
#   out[i] += sum(a[j] * c[i - j + N] for j in max(i, j_min):N)   (c contribution)
#
# The b part keeps a running tail sum of `a`; the c part is a plain double loop.
# Returns `nothing`; `out` is modified.
function foo2!(out, a, b, c, j_min)
    N = length(out)

    # --- b contribution -------------------------------------------------
    # Tail sum a[j_min] + ... + a[end]; constant for the first j_min rows.
    tail = sum(@view a[j_min:end])
    @inbounds for row in 1:j_min
        out[row] += tail * b[row]
    end
    # Past j_min, peel one leading element off the tail per row.
    @inbounds for row in (j_min + 1):N
        tail -= a[row - 1]
        out[row] += tail * b[row]
    end

    # --- c contribution -------------------------------------------------
    # Column `col` of `a` touches rows begin:col, reading c at row - col + N.
    @inbounds for col in eachindex(a)[j_min:end]
        weight = a[col]
        shift = N - col
        for row in eachindex(out, c)[begin:col]
            out[row] += weight * c[row + shift]
        end
    end

    return nothing
end
using LoopVectorization
# Reversed dot product used for the `c` contribution:
#   returns sum over i in 0:n-1 of x[x_start + i] * y[y_end - i],
# i.e. `x` is walked forwards from `x_start` while `y` is walked backwards
# from `y_end`, for `n` terms.
#
# Arguments:
#   x, x_start : array and the first index read from it
#   y, y_end   : array and the first (highest) index read from it
#   n          : number of terms
#
# NOTE(review): `@turbo` comes from LoopVectorization; it presumably skips bounds
# checks, so callers must guarantee x_start:x_start+n-1 and y_end-n+1:y_end are
# valid indices — confirm against the LoopVectorization docs.
function myconv(x, x_start, y, y_end, n)
# Accumulator starts as the Float64 literal 0.0.
out = 0.0
@turbo for i in 0:n-1
out += x[x_start+i] * y[y_end-i]
end
return out
end
# Accumulate the `b` and `c` contributions into `out` in place.
#
# For every i in 1:N (N = length(out)):
#   out[i] += b[i] * sum(a[max(i, j_min):end])                    (b contribution)
#   out[i] += sum(a[j] * c[i - j + N] for j in max(i, j_min):N)   (c contribution)
#
# The b part uses a running tail sum of `a`; the c part delegates each row to
# `myconv`, a reversed dot product.
#
# Arguments:
#   out          : output vector, updated with `+=`
#   a, b, c      : input vectors, indexed 1:N like `out`
#   j_min        : first index of `a` that contributes; must satisfy 1 <= j_min <= N
#
# Returns `nothing`. Throws `ArgumentError` if `j_min` is outside 1:N.
function foo3!(out, a, b, c, j_min)
    N = length(out)
    # The loops below run under @inbounds, so reject an out-of-range j_min up front.
    1 <= j_min <= N || throw(ArgumentError("j_min must be in 1:$N, got $j_min"))

    ## b section
    a_sum = @views sum(a[j_min:end]) # running sum, constant for rows 1:j_min
    @inbounds for i in 1:j_min
        out[i] += a_sum * b[i]
    end
    @inbounds for i in j_min+1:N
        a_sum -= a[i-1] # drop the element that no longer contributes to row i
        out[i] += a_sum * b[i]
    end

    ## c section
    # Rows up to j_min all use the full window a[j_min:N].
    @inbounds for i in 1:j_min
        out[i] += myconv(a, j_min, c, N - j_min + i, N - j_min + 1)
    end
    # Rows past j_min use the shrinking window a[i:N]. The loop runs through
    # i = N (a one-term product) instead of adding a[end] * c[end] afterwards:
    # an unconditional tail term would be counted twice when j_min == N, because
    # the first loop has then already handled row N.
    @inbounds for i in j_min+1:N
        out[i] += myconv(a, i, c, N, N - i + 1)
    end
    return nothing
end
# Accumulate the `b` and `c` contributions into `out` in place, fusing both
# contributions into a single pass over the rows.
#
# For every i in 1:N (N = length(out)):
#   out[i] += b[i] * sum(a[max(i, j_min):end])
#           + sum(a[j] * c[i - j + N] for j in max(i, j_min):N)
#
# The b term uses a running tail sum of `a`; the c term is computed per row by
# `myconv`, a reversed dot product; the two are combined with `muladd`.
#
# Arguments:
#   out          : output vector, updated with `+=`
#   a, b, c      : input vectors, indexed 1:N like `out`
#   j_min        : first index of `a` that contributes; must satisfy 1 <= j_min <= N
#
# Returns `nothing`. Throws `ArgumentError` if `j_min` is outside 1:N.
function foo4!(out, a, b, c, j_min)
    N = length(out)
    # The loops below run under @inbounds, so reject an out-of-range j_min up front.
    1 <= j_min <= N || throw(ArgumentError("j_min must be in 1:$N, got $j_min"))

    ## combined b and c
    a_sum = @views sum(a[j_min:end]) # running sum, constant for rows 1:j_min
    @inbounds for i in 1:j_min
        out[i] += muladd(
            a_sum,
            b[i],
            myconv(a, j_min, c, N - j_min + i, N - j_min + 1),
        )
    end
    # Runs through i = N rather than special-casing the last row afterwards:
    # an unconditional tail would count row N twice (with a wrongly decremented
    # a_sum) when j_min == N, and would index a[end-1] out of bounds when N == 1.
    @inbounds for i in j_min+1:N
        a_sum -= a[i-1] # drop the element that no longer contributes to row i
        out[i] += muladd(
            a_sum,
            b[i],
            myconv(a, i, c, N, N - i + 1),
        )
    end
    return nothing
end
# Correctness check. `N`, `a`, `b`, `c`, `j_min` and `foo1!` are defined outside
# this excerpt.
# One zero-initialised output buffer per implementation (trailing `;` suppresses
# REPL output).
out1 = zeros(N);
out2 = zeros(N);
out3 = zeros(N);
out4 = zeros(N);
# Run each implementation once on the same inputs; foo2!–foo4! add their result
# into the buffer they are given.
foo1!(out1, a, b, c, j_min, N);
foo2!(out2, a, b, c, j_min);
foo3!(out3, a, b, c, j_min);
foo4!(out4, a, b, c, j_min);
# Chained comparison: true only if each neighbouring pair is approximately equal.
out1 ≈ out2 ≈ out3 ≈ out4 # true
using BenchmarkTools
# Timings with BenchmarkTools' @btime; every argument is `$`-interpolated.
# NOTE: foo2!–foo4! update `out` with `+=`, so the repeated evaluations done by
# @btime keep adding into out2–out4; after benchmarking, those buffers no longer
# hold a single-call result.
@btime $foo1!($out1, $a, $b, $c, $j_min, $N) # 17.333 μs (0 allocations: 0 bytes) (different computer)
@btime $foo2!($out2, $a, $b, $c, $j_min) # 13.416 μs (0 allocations: 0 bytes)
@btime $foo3!($out3, $a, $b, $c, $j_min) # 11.166 μs (0 allocations: 0 bytes)
@btime $foo4!($out4, $a, $b, $c, $j_min) # 11.250 μs (0 allocations: 0 bytes)