Hi everyone,
First of all, the motivation for this question is the StaticArrays issue "3x3 matrix multiply could potentially be faster".
I showed that matrix multiplication can become substantially faster across a range of sizes:
(size(m3), size(m1), size(m2)) = ((2, 2), (2, 2), (2, 2))
MMatrix Multiplication:
3.135 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
1.584 ns (0 allocations: 0 bytes)
fastmul!:
1.896 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0; 0.0 0.0]
(size(m3), size(m1), size(m2)) = ((3, 3), (3, 3), (3, 3))
MMatrix Multiplication:
7.383 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
8.274 ns (0 allocations: 0 bytes)
fastmul!:
3.237 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]
(size(m3), size(m1), size(m2)) = ((4, 4), (4, 4), (4, 4))
MMatrix Multiplication:
10.524 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
3.528 ns (0 allocations: 0 bytes)
fastmul!:
4.154 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0]
(size(m3), size(m1), size(m2)) = ((5, 5), (5, 5), (5, 5))
MMatrix Multiplication:
19.038 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
22.561 ns (0 allocations: 0 bytes)
fastmul!:
5.894 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0]
(size(m3), size(m1), size(m2)) = ((6, 6), (6, 6), (6, 6))
MMatrix Multiplication:
30.316 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
39.103 ns (0 allocations: 0 bytes)
fastmul!:
7.837 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0]
(size(m3), size(m1), size(m2)) = ((7, 7), (7, 7), (7, 7))
MMatrix Multiplication:
51.105 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
62.420 ns (0 allocations: 0 bytes)
fastmul!:
11.871 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
(size(m3), size(m1), size(m2)) = ((8, 8), (8, 8), (8, 8))
MMatrix Multiplication:
36.552 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
12.940 ns (0 allocations: 0 bytes)
fastmul!:
12.794 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
(size(m3), size(m1), size(m2)) = ((9, 9), (9, 9), (9, 9))
MMatrix Multiplication:
68.433 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
73.896 ns (0 allocations: 0 bytes)
fastmul!:
24.042 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
(size(m3), size(m1), size(m2)) = ((10, 10), (10, 10), (10, 10))
MMatrix Multiplication:
106.568 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
109.546 ns (0 allocations: 0 bytes)
fastmul!:
31.296 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
(size(m3), size(m1), size(m2)) = ((11, 11), (11, 11), (11, 11))
MMatrix Multiplication:
161.298 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
151.703 ns (0 allocations: 0 bytes)
fastmul!:
38.405 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
(size(m3), size(m1), size(m2)) = ((12, 12), (12, 12), (12, 12))
MMatrix Multiplication:
210.829 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
216.941 ns (0 allocations: 0 bytes)
fastmul!:
47.986 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
(size(m3), size(m1), size(m2)) = ((13, 13), (13, 13), (13, 13))
MMatrix Multiplication:
315.835 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
302.625 ns (0 allocations: 0 bytes)
fastmul!:
65.856 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
(size(m3), size(m1), size(m2)) = ((14, 14), (14, 14), (14, 14))
MMatrix Multiplication:
466.087 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
396.795 ns (0 allocations: 0 bytes)
fastmul!:
66.755 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
(size(m3), size(m1), size(m2)) = ((15, 14), (15, 15), (15, 14))
MMatrix Multiplication:
548.775 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
450.668 ns (0 allocations: 0 bytes)
fastmul!:
67.804 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
(size(m3), size(m1), size(m2)) = ((16, 14), (16, 16), (16, 14))
MMatrix Multiplication:
545.207 ns (0 allocations: 0 bytes)
SMatrix Multiplication:
83.752 ns (0 allocations: 0 bytes)
fastmul!:
65.457 ns (0 allocations: 0 bytes)
m3 - s3 = [0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0]
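(For reference, the comparisons above can be reproduced along these lines with BenchmarkTools; fastmul! itself is not shown in this post, so this sketch only covers the two StaticArrays baselines.)

using StaticArrays, LinearAlgebra, BenchmarkTools

s1 = @SMatrix randn(5, 5)
s2 = @SMatrix randn(5, 5)
m1 = MMatrix(s1); m2 = MMatrix(s2)
m3 = similar(m1)               # preallocated output for the in-place version

@btime mul!($m3, $m1, $m2)     # "MMatrix Multiplication"
@btime $s1 * $s2               # "SMatrix Multiplication"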
We can already use the fastmul! function as soon as we're willing to add dependencies on CpuId and SIMD/SIMDPirates, plus a little extra code to loop over the kernel whenever the matrices are larger than the kernel size (i.e., roughly 8xN * Nx6 for AVX2, or 16xN * Nx14 for AVX-512).
However, it only works for the heap-allocated MMatrix, meaning that using it forces us to write code like mul!(C, A, B), after having taken care to preallocate C, instead of the simpler C = A * B. Worse still, if we want to use automatic differentiation libraries, many of these currently do not support mutating arguments, forcing us to constantly trigger the garbage collector (or switch to the potentially slower StaticArray methods).
The simple reason we need MMatrices instead of SMatrices is that I don't know how to load a vector from an SMatrix using llvm intrinsics. This is necessary so that I can use masked load and store operations. In short, for multiplication to be fast, we need to be able to vectorize it efficiently, and masked load/store operations let us do precisely that, even when the number of rows of A in C = A * B isn't a multiple of the computer's vector width.
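To make the remainder arithmetic concrete (a small sketch; 8 doubles per register is an AVX-512 assumption):

W = 8                     # lanes per register (assumption: AVX-512 doubles)
M = 13                    # rows of A
r = M % W                 # 5 rows left over after one full vector of 8
mask = UInt8(2^r - 1)     # 0b00011111: only the first 5 lanes load/store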
But surely loading from an SMatrix must be possible somehow? For example:
julia> using StaticArrays
julia> a = @SVector randn(24);
julia> b = @SVector randn(24);
julia> @code_llvm hcat(a, b)
; Function hcat
; Location: /home/chriselrod/.julia/packages/StaticArrays/WmJnA/src/linalg.jl:120
define void @julia_hcat_-777564679({ [48 x double] }* noalias nocapture sret, { [24 x double] } addrspace(11)* nocapture nonnull readonly dereferenceable(192), { [24 x double] } addrspace(11)* nocapture nonnull readonly dereferenceable(192)) {
top:
; Function _hcat; {
; Location: /home/chriselrod/.julia/packages/StaticArrays/WmJnA/src/linalg.jl:124
; Function macro expansion; {
; Location: /home/chriselrod/.julia/packages/StaticArrays/WmJnA/src/linalg.jl:135
; Function getindex; {
; Location: /home/chriselrod/.julia/packages/StaticArrays/WmJnA/src/SVector.jl:37
; Function getindex; {
; Location: tuple.jl:24
%3 = getelementptr { [24 x double] }, { [24 x double] } addrspace(11)* %1, i64 0, i32 0, i64 8
%4 = getelementptr { [24 x double] }, { [24 x double] } addrspace(11)* %1, i64 0, i32 0, i64 16
%5 = getelementptr { [24 x double] }, { [24 x double] } addrspace(11)* %2, i64 0, i32 0, i64 8
%6 = getelementptr { [24 x double] }, { [24 x double] } addrspace(11)* %2, i64 0, i32 0, i64 16
;}}
%7 = bitcast { [24 x double] } addrspace(11)* %1 to <8 x i64> addrspace(11)*
%8 = load <8 x i64>, <8 x i64> addrspace(11)* %7, align 8
%9 = bitcast double addrspace(11)* %3 to <8 x i64> addrspace(11)*
%10 = load <8 x i64>, <8 x i64> addrspace(11)* %9, align 8
%11 = bitcast double addrspace(11)* %4 to <8 x i64> addrspace(11)*
%12 = load <8 x i64>, <8 x i64> addrspace(11)* %11, align 8
%13 = bitcast { [24 x double] } addrspace(11)* %2 to <8 x i64> addrspace(11)*
%14 = load <8 x i64>, <8 x i64> addrspace(11)* %13, align 8
%15 = bitcast double addrspace(11)* %5 to <8 x i64> addrspace(11)*
%16 = load <8 x i64>, <8 x i64> addrspace(11)* %15, align 8
%17 = bitcast double addrspace(11)* %6 to <8 x i64> addrspace(11)*
%18 = load <8 x i64>, <8 x i64> addrspace(11)* %17, align 8
;}}
%19 = bitcast { [48 x double] }* %0 to <8 x i64>*
store <8 x i64> %8, <8 x i64>* %19, align 8
%.sroa.0.sroa.9.0..sroa.0.0..sroa_cast1.sroa_idx58 = getelementptr inbounds { [48 x double] }, { [48 x double] }* %0, i64 0, i32 0, i64 8
%20 = bitcast double* %.sroa.0.sroa.9.0..sroa.0.0..sroa_cast1.sroa_idx58 to <8 x i64>*
store <8 x i64> %10, <8 x i64>* %20, align 8
%.sroa.0.sroa.17.0..sroa.0.0..sroa_cast1.sroa_idx66 = getelementptr inbounds { [48 x double] }, { [48 x double] }* %0, i64 0, i32 0, i64 16
%21 = bitcast double* %.sroa.0.sroa.17.0..sroa.0.0..sroa_cast1.sroa_idx66 to <8 x i64>*
store <8 x i64> %12, <8 x i64>* %21, align 8
%.sroa.0.sroa.25.0..sroa.0.0..sroa_cast1.sroa_idx74 = getelementptr inbounds { [48 x double] }, { [48 x double] }* %0, i64 0, i32 0, i64 24
%22 = bitcast double* %.sroa.0.sroa.25.0..sroa.0.0..sroa_cast1.sroa_idx74 to <8 x i64>*
store <8 x i64> %14, <8 x i64>* %22, align 8
%.sroa.0.sroa.33.0..sroa.0.0..sroa_cast1.sroa_idx82 = getelementptr inbounds { [48 x double] }, { [48 x double] }* %0, i64 0, i32 0, i64 32
%23 = bitcast double* %.sroa.0.sroa.33.0..sroa.0.0..sroa_cast1.sroa_idx82 to <8 x i64>*
store <8 x i64> %16, <8 x i64>* %23, align 8
%.sroa.0.sroa.41.0..sroa.0.0..sroa_cast1.sroa_idx90 = getelementptr inbounds { [48 x double] }, { [48 x double] }* %0, i64 0, i32 0, i64 40
%24 = bitcast double* %.sroa.0.sroa.41.0..sroa.0.0..sroa_cast1.sroa_idx90 to <8 x i64>*
store <8 x i64> %18, <8 x i64>* %24, align 8
ret void
}
hcat seems to do precisely this. For example, it loads elements 9-16 from a with:
%3 = getelementptr { [24 x double] }, { [24 x double] } addrspace(11)* %1, i64 0, i32 0, i64 8
%9 = bitcast double addrspace(11)* %3 to <8 x i64> addrspace(11)*
%10 = load <8 x i64>, <8 x i64> addrspace(11)* %9, align 8
and elements 17-24 from a with:
%4 = getelementptr { [24 x double] }, { [24 x double] } addrspace(11)* %1, i64 0, i32 0, i64 16
%11 = bitcast double addrspace(11)* %4 to <8 x i64> addrspace(11)*
%12 = load <8 x i64>, <8 x i64> addrspace(11)* %11, align 8
etc.
This doesn't look so hard. Just showing simple code (rather than a generated function):
julia> using StaticArrays
julia> const Vec{N,T} = NTuple{N,Core.VecElement{T}}
Tuple{Vararg{VecElement{T},N}} where T where N
julia> @inline function svload(A::SArray{S,T,SN,L}, i::Int) where {S,T,SN,L}
Base.llvmcall(("",
"""%elptr = getelementptr { [24 x double] }, { [24 x double] }* %0, i64 0, i32 0, i64 %1
%ptr = bitcast double addrspace(11)* %elptr to <8 x double>*
%res = load <8 x double>, <8 x double>* %ptr, align 8
ret <8 x double> %res"""),
Vec{8, Float64}, Tuple{SArray{S, T, SN, L}, Int}, A, i)
end
svload (generic function with 1 method)
julia> a = @SVector randn(24);
julia> svload(a, 7)
ERROR: error compiling svload: Failed to parse LLVM Assembly:
julia: llvmcall:3:62: error: '%0' defined with type '{ [24 x double] }'
%elptr = getelementptr { [24 x double] }, { [24 x double] }* %0, i64 0, i32 0, i64 %1
^
Stacktrace:
[1] top-level scope at none:0
Somehow, when calling hcat, Julia interprets the StaticArray as a pointer to its elements, but when I try to copy the code, it instead interprets the SArray{S, T, SN, L} type declaration as '{ [L x T] }'. Can I instead correctly declare it as '{ [L x T] }*', or extract the address to get '{ [L x T] }*' from '{ [L x T] }'?
I understand, of course, that getting a pointer to a stack-allocated struct doesn't make sense within Julia, but (based on the function hcat) it seems like it ought to work within a brief function?
Once that's done, I also have no idea what is going on with %.sroa.0.sroa.9.0..sroa.0.0..sroa_cast1.sroa_idx58, nor how the function actually returns the resulting SMatrix: ret void looks like it doesn't return anything at all. So I'd also have no idea how to proceed from there to actually put the resulting StaticArray together.
Unfortunately, I would also need the store operations, because we require the masked stores (for all cases where size(A,1) % VECTOR_WIDTH != 0).
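For reference, here is a hedged sketch of what such a masked store could look like through llvmcall, given a raw pointer (so, as with fastmul! above, it works for an MMatrix but not an SMatrix). It uses LLVM's @llvm.masked.store intrinsic with the typed-pointer IR of current LLVM, and the lane mask matches the remainder arithmetic sketched earlier; maskstore8! is a hypothetical name.

# Vec{N,T} is the NTuple{N,Core.VecElement{T}} alias defined earlier.
# Stores only the lanes selected by `mask`, so writes never run past
# the end of a column when size(A,1) % 8 != 0.
@inline function maskstore8!(ptr::Ptr{Float64}, v::Vec{8,Float64}, mask::UInt8)
    Base.llvmcall((
        # declaration of the LLVM masked-store intrinsic for <8 x double>
        "declare void @llvm.masked.store.v8f64.p0v8f64(<8 x double>, <8 x double>*, i32, <8 x i1>)",
        """
        %p = inttoptr i64 %0 to <8 x double>*
        %m = bitcast i8 %2 to <8 x i1>
        call void @llvm.masked.store.v8f64.p0v8f64(<8 x double> %1, <8 x double>* %p, i32 8, <8 x i1> %m)
        ret void
        """),
        Cvoid, Tuple{Ptr{Float64}, Vec{8,Float64}, UInt8}, ptr, v, mask)
end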
I think it would be great if we could get such a huge performance boost on top of StaticArrays, even if we have to load another library with a few more dependencies on top of it (one that replaces the * and LinearAlgebra.mul! methods defined in StaticArrays).
EDIT:
It looks like inttoptr is one of the things I am looking for.
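For example, a minimal sketch of the inttoptr idiom (vload8 is a hypothetical name; the Ref box is only there to obtain an address, since pointer_from_objref does not work on immutables, and it allocates, which is exactly what a real implementation would need to avoid):

using StaticArrays

# Vec{N,T} as defined above. The pointer is passed to llvmcall as an
# integer (i64) and converted back inside the IR with inttoptr.
@inline function vload8(ptr::Ptr{Float64})
    Base.llvmcall("""
        %p = inttoptr i64 %0 to <8 x double>*
        %v = load <8 x double>, <8 x double>* %p, align 8
        ret <8 x double> %v
        """, Vec{8,Float64}, Tuple{Ptr{Float64}}, ptr)
end

a = @SVector randn(24)
r = Ref(a)                                 # heap box, just to get an address
GC.@preserve r begin
    p = Ptr{Float64}(pointer_from_objref(r))
    v = vload8(p + 8 * sizeof(Float64))    # elements 9-16, like hcat above
end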