using LoopVectorization, StrideArraysCore
using StrideArraysCore: static, static_length, zero_offsets
"""
    conv3!(c, a, b)

In-place "full" 1-D convolution: fills `c` with the convolution of `a` and
`b`, i.e. `c[i] = Σ_k b[k] * a[i-k]`, treating out-of-range entries of `a`
as zero.  Assumes `length(c) == length(a) + length(b) - 1` — NOTE(review):
not checked here; confirm callers guarantee this (the tail-loop mask below
relies on it).

Strategy: the bounds checks are hoisted out of the hot loop by splitting the
output index range into a masked head, a branch-free core, and a masked tail,
so only the short edge loops pay for the conditional loads.
"""
function conv3!(
_c::AbstractVector{T},
_a::AbstractVector{T},
_b::AbstractVector{T},
) where {T}
# Normalize all three vectors to 0-based indexing so the `i - k` index
# arithmetic below needs no offset corrections.
c = zero_offsets(_c)
a = zero_offsets(_a)
b = zero_offsets(_b)
I = static_length(c)  # output length (compile-time `StaticInt` when available)
K = static_length(b)  # kernel length
# J == I - K + 1; equals length(a) under the full-convolution size assumption.
J = I - K + static(1)
# Convolution is commutative: swap arguments so `b` is the shorter vector,
# which maximizes the span of the branch-free core loop below.
J < K && return conv3!(_c, _b, _a)
# Head (i = 0 .. K-2): `i - k` can be negative, so mask those loads of `a`
# to zero via the `i >= k` predicate.
@turbo for i = 0:K-2
s = zero(T)
for k = 0:K-1
ib0 = (i >= k)
oa = ib0 ? a[i-k] : zero(T)
s += b[k] * oa
end
c[i] = s
end
# Core (i = K-1 .. J-1): full overlap, `0 <= i-k < J` always holds, so no
# masking is needed — this is the hot, branch-free loop.
@turbo inline=true for i = K-1:J-1
s = zero(T)
for k = 0:K-1
s += b[k] * a[i-k]
end
c[i] = s
end
# Tail (i = J .. I-1): `i - k` can run past the end of `a`; the `i < J+k`
# predicate (i.e. `i - k < J`) masks those loads to zero.
@turbo for i = J:I-1
s = zero(T)
for k = 0:K-1
ib0 = (i < J+k)
oa = ib0 ? a[i-k] : zero(T)
s += b[k] * oa
end
c[i] = s
end
end
I get
julia> @benchmark conv3!($vO, $vA, $vB)
BenchmarkTools.Trial: 9568 samples with 188 evaluations.
 Range (min … max):  546.271 ns … 595.362 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     547.266 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   547.850 ns ±   1.916 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

 [histogram omitted — log(frequency) by time, sharply peaked at the minimum]

  546 ns          Histogram: log(frequency) by time          554 ns <
Memory estimate: 0 bytes, allocs estimate: 0.
julia> @benchmark _Conv1D!($vO, $vA, $vB)
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
 Range (min … max):  1.251 μs … 1.526 μs   ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     1.255 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.256 μs ± 13.999 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

 [histogram omitted — frequency by time, sharply peaked at the minimum]

  1.25 μs        Histogram: frequency by time         1.35 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
julia> vOc = similar(vO);
julia> _Conv1D!(vOc, vA, vB);
julia> conv3!(vO, vA, vB)
julia> vOc ≈ vO
true
This is a modification of the SimpleChains.jl code.
The basic idea is to hoist the checks out of the loop bodies, so that the main loop doesn't have them.
I imagine we can probably do better. The main loop looks like
.LBB0_27: # %L708
# Parent Loop BB0_26 Depth=1
# => This Inner Loop Header: Depth=2
vbroadcastsd zmm8, qword ptr [rdx]
vfmadd231pd zmm0, zmm8, zmmword ptr [rdi + 8*rax - 448] # zmm0 = (zmm8 * mem) + zmm0
vfmadd231pd zmm1, zmm8, zmmword ptr [rdi + 8*rax - 384] # zmm1 = (zmm8 * mem) + zmm1
vfmadd231pd zmm2, zmm8, zmmword ptr [rdi + 8*rax - 320] # zmm2 = (zmm8 * mem) + zmm2
vfmadd231pd zmm3, zmm8, zmmword ptr [rdi + 8*rax - 256] # zmm3 = (zmm8 * mem) + zmm3
vfmadd231pd zmm4, zmm8, zmmword ptr [rdi + 8*rax - 192] # zmm4 = (zmm8 * mem) + zmm4
vfmadd231pd zmm5, zmm8, zmmword ptr [rdi + 8*rax - 128] # zmm5 = (zmm8 * mem) + zmm5
vfmadd231pd zmm6, zmm8, zmmword ptr [rdi + 8*rax - 64] # zmm6 = (zmm8 * mem) + zmm6
vfmadd231pd zmm7, zmm8, zmmword ptr [rdi + 8*rax] # zmm7 = (zmm8 * mem) + zmm7
add rdx, 8
dec rax
jne .LBB0_27
So we have 1 broadcast and loads of loads+fmas.
However, it's easy to see why it's around twice as fast as the original `_Conv1D!`
implementation:
for idxA in 1:lenA
vO[idxA + idxB - 1] += vA[idxA] * vB[idxB];
end
is going to have 2 loads per fma. Twice as many loads, twice as slow.