How to make the most of SIMD.jl when the number of data elements is not divisible by the SIMD width

Consider the following MWE, which simply stores the elementwise product of two vectors in c (i.e. c = a .* b).

Here I want to process the data in SIMD-width chunks and do the operation on SIMD vectors. But the “naive solution” I came up with is around 2x slower than the non-SIMD version (there was an error in the code; I have left it in, with a comment, for reference). I would like one that is at least as fast. In this example we could use bigger vectors (10k is probably too small), but the result is the same: 2x slower.

using SIMD
using BenchmarkTools
using Random
Random.seed!(1234)

function c_a_times_b!(c::Array{T}, a::Array{T}, b::Array{T}) where T
    @assert length(c) == length(a) == length(b)
    for i in 1:length(c)
        c[i] = a[i]*b[i]
    end
end

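# "Incorrect": ignores the remainder, so when length(a) is not a multiple of N the last
# vload/vstore runs past the end of the arrays (unchecked because of @inbounds)
function c_a_times_b_SIMD_incorrect!(c::Array{T}, a::Array{T}, b::Array{T}, v_type::Type{Vec{N,T}}) where {N, T}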
    @assert length(a) == length(b) == length(c)
    
    @inbounds for i in 1:N:length(a)
        a_chunk = vload(v_type, a, i) 
        b_chunk = vload(v_type, b, i) 
        a_chunk *=  b_chunk
        vstore(a_chunk, c, i)
    end
end

function c_a_times_b_SIMD_correct!(c, a, b)
    N = 8
    T = eltype(c)
    n_elements = length(a)
    n_remaining = mod(n_elements, N)
    n_first = n_elements - n_remaining
    vtype = Vec{N,T}
    
    @inbounds for i in 1:N:n_first
        a_chunk = vload(vtype, a, i) 
        b_chunk = vload(vtype, b, i) 
        a_chunk *=  b_chunk
        vstore(a_chunk,c,i)
    end
    
    @inbounds for i in n_remaining:n_elements # BUG: should be (n_first+1):n_elements (when n_remaining == 0 this even writes c[0]); kept as-is so the timings below are reproducible
        c[i] = a[i]*b[i]
    end
end



T = Float64
M = 10_000
a = Vector{T}(1:M)
b = Vector{T}(1:M)
c = zeros(T,M)

println("no SIMD\t", @benchmark c_a_times_b!($c, $a, $b))
println("SIMD ignoring remainder\t", @benchmark c_a_times_b_SIMD_incorrect!($c, $a, $b, Vec{8,Float64}))
println("SIMD with remainder\t", @benchmark c_a_times_b_SIMD_correct!($c, $a, $b))

a = Vector{T}(1:M+10)
b = Vector{T}(1:M+10)
c = zeros(T,M+10)

println("no SIMD\t",@benchmark c_a_times_b!($c, $a, $b))
println("SIMD ignoring remainder\t", @benchmark c_a_times_b_SIMD_incorrect!($c, $a, $b, Vec{8,Float64}))
println("SIMD with remainder\t", @benchmark c_a_times_b_SIMD_correct!($c, $a, $b))

Which prints

no SIMD	Trial(2.083 μs)
SIMD ignoring remainder	Trial(2.546 μs)
SIMD with remainder	Trial(5.119 μs)
no SIMD	Trial(2.329 μs)
SIMD ignoring remainder	Trial(2.625 μs)
SIMD with remainder	Trial(4.708 μs)

Here we can see that adding a loop to cover the remaining elements of the input arrays makes the method almost 2x slower than the non-SIMD version.

Is there a “standard” practice to avoid this?

I would like to see different ways to solve this issue. Leveraging libraries with macros is a welcome solution, but I wonder how to achieve similar performance with the “basic SIMD” intrinsics.


I don’t know what the best option is, but I think you want to do a masked SIMD operation for the remainder, e.g. something like

function c_a_times_b_SIMD_correct!(c, a, b, ::Type{vtype}) where {N, T, vtype<:Vec{N, T}}
    n_elements = length(a)
    n_remaining = mod(n_elements, N)
    n_first = n_elements - n_remaining
    
    @inbounds for i in 1:N:n_first
        a_chunk = vload(vtype, a, i) 
        b_chunk = vload(vtype, b, i) 
        a_chunk *=  b_chunk
        vstore(a_chunk,c,i)
    end
    i_final = n_first + 1
    mask = Vec(ntuple(i -> i <= n_remaining, Val(N)))
    @inbounds begin 
        a_chunk = vload(vtype, a, i_final, mask)
        b_chunk = vload(vtype, b, i_final, mask)
        a_chunk *= b_chunk
        vstore(a_chunk, c, i_final, mask)
    end
end

However, I’m not sure if this is the right way to do it, and I don’t know if @inbounds interacts correctly with the masked operations, so I’m not certain that this won’t cause a segfault or something.

Hopefully someone who knows more can come along and explain.
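One way to gain some confidence in the masked approach is to compare it against plain broadcasting for lengths with and without a remainder. Below is a minimal sketch along the lines of the answer above; the name `times_masked!`, the `Val(N)` width argument and the early skip when the length is an exact multiple of N are my own additions. It uses the same masked `vload(vtype, a, i, mask)` / `vstore(v, c, i, mask)` calls as above and assumes that masked-off lanes never touch memory, which is how LLVM's masked load/store intrinsics are specified. Whether SIMD.jl's bounds check covers the full `i:i+N-1` range even when a mask is given is worth confirming in its source, which is why the tail stays under `@inbounds`.

```julia
using SIMD

# Hypothetical name; masked-tail variant of the answer above.
function times_masked!(c::Vector{T}, a::Vector{T}, b::Vector{T}, ::Val{N}) where {N, T}
    n = length(c)
    @assert length(a) == n && length(b) == n
    V = Vec{N, T}
    n_first = n - rem(n, N)                 # elements covered by full vectors
    @inbounds for i in 1:N:n_first
        vstore(vload(V, a, i) * vload(V, b, i), c, i)
    end
    n_rem = n - n_first
    if n_rem > 0                            # skip the masked step if there is no tail
        i = n_first + 1
        mask = Vec(ntuple(k -> k <= n_rem, Val(N)))   # lanes 1..n_rem map to elements i:n
        # Assumption: the mask keeps the actual memory accesses inside i:n
        @inbounds vstore(vload(V, a, i, mask) * vload(V, b, i, mask), c, i, mask)
    end
    return c
end
```

Running this for lengths both below and above N, and for exact multiples of N, exercises every path (empty main loop, empty tail, both).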


Firstly, SIMD.jl has an indexing interface which is much nicer than the vload/vstore stuff.

Secondly, I would just shift the last iteration so that it, too, can be SIMD with the same length. It's best to avoid switching between SIMD and scalar code. That is, don't stop N samples short of the end and handle the tail separately; start the last chunk at length - N + 1 instead. There will then be an overlap region where the same multiplication is done twice.

Thirdly, how do you know the original loop isn’t being simd-vectorized by the compiler? Check with @code_llvm for example.
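The second point (shifting the last iteration) can be sketched as follows, assuming SIMD.jl's `vload`/`vstore` as used in the question. `times_overlap!` is a hypothetical name; every memory access is a full `Vec{N,T}`, and the final chunk is moved left so that it ends exactly at the last element. This relies on `c` not aliasing `a` or `b`: if it did, recomputing an overlapped element would read an input that had already been overwritten.

```julia
using SIMD

# Hypothetical name. All accesses are full Vec{N,T}; the last chunk is shifted
# so that it ends exactly at n and may overlap the previous chunk.
# Assumes c does not alias a or b.
function times_overlap!(c::Vector{T}, a::Vector{T}, b::Vector{T}, ::Val{N}) where {N, T}
    n = length(c)
    @assert length(a) == n && length(b) == n
    if n < N                      # too short for even one full vector
        c .= a .* b
        return c
    end
    V = Vec{N, T}
    i = 1
    @inbounds while i <= n - N    # full chunks that end strictly before n
        vstore(vload(V, a, i) * vload(V, b, i), c, i)
        i += N
    end
    j = n - N + 1                 # final chunk covers j:n
    @inbounds vstore(vload(V, a, j) * vload(V, b, j), c, j)
    return c
end
```

With n = 18 and N = 8 the chunks start at 1, 9 and 11, so elements 11 to 16 are multiplied twice; that is harmless here because each c[k] depends only on a[k] and b[k].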


About (1): I have no issues with vload/vstore; to me they are clear and concise. But maybe there is a more performant way to do this?

About (2): I am not sure how to do what you propose (“shift the last iteration so that it, too, can be simd with the same length”). If I vload() at position end - 4 with a vector length of 8, I get an error. Do you mean loading the last N elements, starting at end - N + 1, even if some of them were already processed in the previous iteration of the loop?

About (3): maybe it is being SIMD-vectorized, but the point of the post is to understand how to write SIMD code by hand for a simple example such as this one.

julia> @code_llvm c_a_times_b!(c, a, b)
; Function Signature: c_a_times_b!(Array{Float64, 1}, Array{Float64, 1}, Array{Float64, 1})
;  @ REPL[5]:1 within `c_a_times_b!`
define nonnull ptr @"japi1_c_a_times_b!_5121"(ptr %"function::Core.Function", ptr noalias nocapture noundef readonly %"args::Any[]", i32 %"nargs::UInt32") #0 {
top:
  %gcframe1 = alloca [3 x ptr], align 16
  call void @llvm.memset.p0.i64(ptr align 16 %gcframe1, i8 0, i64 24, i1 true)
  %stackargs = alloca ptr, align 8
  store volatile ptr %"args::Any[]", ptr %stackargs, align 8
  %"new::Tuple60" = alloca [1 x i64], align 8
  %pgcstack = call ptr inttoptr (i64 6868749296 to ptr)(i64 263) #11
  store i64 4, ptr %gcframe1, align 16
  %task.gcstack = load ptr, ptr %pgcstack, align 8
  %frame.prev = getelementptr inbounds ptr, ptr %gcframe1, i64 1
  store ptr %task.gcstack, ptr %frame.prev, align 8
  store ptr %gcframe1, ptr %pgcstack, align 8
  %0 = load ptr, ptr %"args::Any[]", align 8
  %1 = getelementptr inbounds ptr, ptr %"args::Any[]", i64 1
  %2 = load ptr, ptr %1, align 8
;  @ REPL[5]:2 within `c_a_times_b!`
; ┌ @ essentials.jl:11 within `length`
   %3 = getelementptr inbounds i8, ptr %2, i64 16
   %.size.sroa.0.0.copyload = load i64, ptr %3, align 8
   %4 = getelementptr inbounds i8, ptr %0, i64 16
   %.size3.sroa.0.0.copyload = load i64, ptr %4, align 8
; └
; ┌ @ promotion.jl:639 within `==`
   %.not = icmp eq i64 %.size3.sroa.0.0.copyload, %.size.sroa.0.0.copyload
; └
  br i1 %.not, label %L15, label %L106

L15:                                              ; preds = %top
  %5 = getelementptr inbounds ptr, ptr %"args::Any[]", i64 2
  %6 = load ptr, ptr %5, align 8
; ┌ @ essentials.jl:11 within `length`
   %7 = getelementptr inbounds i8, ptr %6, i64 16
   %.size6.sroa.0.0.copyload = load i64, ptr %7, align 8
; └
; ┌ @ promotion.jl:639 within `==`
   %.not81 = icmp eq i64 %.size.sroa.0.0.copyload, %.size6.sroa.0.0.copyload
; └
  br i1 %.not81, label %L17, label %L106

L17:                                              ; preds = %L15
;  @ REPL[5]:3 within `c_a_times_b!`
; ┌ @ range.jl:5 within `Colon`
; │┌ @ range.jl:408 within `UnitRange`
; ││┌ @ range.jl:419 within `unitrange_last`
     %value_phi13 = call i64 @llvm.smax.i64(i64 %.size.sroa.0.0.copyload, i64 0)
; └└└
; ┌ @ range.jl:904 within `iterate`
; │┌ @ range.jl:681 within `isempty`
; ││┌ @ operators.jl:379 within `>`
; │││┌ @ int.jl:83 within `<`
      %8 = icmp slt i64 %.size.sroa.0.0.copyload, 1
; └└└└
  br i1 %8, label %L105, label %L37.preheader

L37.preheader:                                    ; preds = %L17
  %9 = load ptr, ptr %2, align 8
  %10 = load ptr, ptr %6, align 8
  %11 = load ptr, ptr %0, align 8
;  @ REPL[5]:4 within `c_a_times_b!`
; ┌ @ essentials.jl:916 within `getindex`
   %12 = add nuw i64 %.size.sroa.0.0.copyload, 1
   %13 = add nsw i64 %value_phi13, -1
   %umin87 = call i64 @llvm.umin.i64(i64 %.size.sroa.0.0.copyload, i64 %13)
   %14 = add i64 %umin87, 1
   %min.iters.check = icmp ult i64 %14, 11
   br i1 %min.iters.check, label %scalar.ph, label %vector.memcheck

vector.memcheck:                                  ; preds = %L37.preheader
   %15 = shl i64 %umin87, 3
   %16 = add i64 %15, 8
   %uglygep = getelementptr i8, ptr %11, i64 %16
   %uglygep82 = getelementptr i8, ptr %9, i64 %16
   %uglygep83 = getelementptr i8, ptr %10, i64 %16
   %bound0 = icmp ult ptr %11, %uglygep82
   %bound1 = icmp ult ptr %9, %uglygep
   %found.conflict = and i1 %bound0, %bound1
   %bound084 = icmp ult ptr %11, %uglygep83
   %bound185 = icmp ult ptr %10, %uglygep
   %found.conflict86 = and i1 %bound084, %bound185
   %conflict.rdx = or i1 %found.conflict, %found.conflict86
   br i1 %conflict.rdx, label %scalar.ph, label %vector.ph

vector.ph:                                        ; preds = %vector.memcheck
   %n.mod.vf = and i64 %14, 7
   %17 = icmp eq i64 %n.mod.vf, 0
   %18 = select i1 %17, i64 8, i64 %n.mod.vf
   %n.vec = sub i64 %14, %18
   %ind.end = add i64 %n.vec, 1
   br label %vector.body

vector.body:                                      ; preds = %vector.body, %vector.ph
   %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
; │ @ essentials.jl:917 within `getindex`
   %19 = getelementptr inbounds double, ptr %9, i64 %index
   %wide.load = load <2 x double>, ptr %19, align 8
   %20 = getelementptr inbounds double, ptr %19, i64 2
   %wide.load88 = load <2 x double>, ptr %20, align 8
   %21 = getelementptr inbounds double, ptr %19, i64 4
   %wide.load89 = load <2 x double>, ptr %21, align 8
   %22 = getelementptr inbounds double, ptr %19, i64 6
   %wide.load90 = load <2 x double>, ptr %22, align 8
   %23 = getelementptr inbounds double, ptr %10, i64 %index
   %wide.load91 = load <2 x double>, ptr %23, align 8
   %24 = getelementptr inbounds double, ptr %23, i64 2
   %wide.load92 = load <2 x double>, ptr %24, align 8
   %25 = getelementptr inbounds double, ptr %23, i64 4
   %wide.load93 = load <2 x double>, ptr %25, align 8
   %26 = getelementptr inbounds double, ptr %23, i64 6
   %wide.load94 = load <2 x double>, ptr %26, align 8
; └
; ┌ @ float.jl:493 within `*`
   %27 = fmul <2 x double> %wide.load, %wide.load91
   %28 = fmul <2 x double> %wide.load88, %wide.load92
   %29 = fmul <2 x double> %wide.load89, %wide.load93
   %30 = fmul <2 x double> %wide.load90, %wide.load94
; └
; ┌ @ array.jl:976 within `setindex!`
   %31 = getelementptr inbounds double, ptr %11, i64 %index
   store <2 x double> %27, ptr %31, align 8
   %32 = getelementptr inbounds double, ptr %31, i64 2
   store <2 x double> %28, ptr %32, align 8
   %33 = getelementptr inbounds double, ptr %31, i64 4
   store <2 x double> %29, ptr %33, align 8
   %34 = getelementptr inbounds double, ptr %31, i64 6
   store <2 x double> %30, ptr %34, align 8
   %index.next = add nuw i64 %index, 8
   %35 = icmp eq i64 %index.next, %n.vec
   br i1 %35, label %scalar.ph, label %vector.body

scalar.ph:                                        ; preds = %vector.body, %vector.memcheck, %L37.preheader
   %bc.resume.val = phi i64 [ 1, %L37.preheader ], [ 1, %vector.memcheck ], [ %ind.end, %vector.body ]
; └
; ┌ @ essentials.jl:916 within `getindex`
   br label %L37

L37:                                              ; preds = %L90, %scalar.ph
   %value_phi17 = phi i64 [ %43, %L90 ], [ %bc.resume.val, %scalar.ph ]
   %exitcond.not = icmp eq i64 %value_phi17, %12
   br i1 %exitcond.not, label %L50, label %L90

L50:                                              ; preds = %L37
   store i64 %12, ptr %"new::Tuple60", align 8
   call void @j_throw_boundserror_5138(ptr nonnull %2, ptr nocapture nonnull readonly %"new::Tuple60") #6
   unreachable

L90:                                              ; preds = %L37
   %36 = add nsw i64 %value_phi17, -1
; │ @ essentials.jl:917 within `getindex`
   %37 = getelementptr inbounds double, ptr %9, i64 %36
   %38 = load double, ptr %37, align 8
   %39 = getelementptr inbounds double, ptr %10, i64 %36
   %40 = load double, ptr %39, align 8
; └
; ┌ @ float.jl:493 within `*`
   %41 = fmul double %38, %40
; └
; ┌ @ array.jl:976 within `setindex!`
   %42 = getelementptr inbounds double, ptr %11, i64 %36
   store double %41, ptr %42, align 8
; └
;  @ REPL[5]:5 within `c_a_times_b!`
; ┌ @ range.jl:908 within `iterate`
; │┌ @ promotion.jl:639 within `==`
    %.not74.not = icmp eq i64 %value_phi17, %value_phi13
; │└
   %43 = add nuw i64 %value_phi17, 1
; └
  br i1 %.not74.not, label %L105, label %L37

L105:                                             ; preds = %L90, %L17
  %jl_nothing = load ptr, ptr @jl_nothing, align 8
  %frame.prev101 = load ptr, ptr %frame.prev, align 8
  store ptr %frame.prev101, ptr %pgcstack, align 8
  ret ptr %jl_nothing

L106:                                             ; preds = %L15, %top
;  @ REPL[5]:2 within `c_a_times_b!`
  %44 = call [1 x ptr] @j_AssertionError_5140(ptr nonnull @"jl_global#5141.jit")
  %gc_slot_addr_0 = getelementptr inbounds ptr, ptr %gcframe1, i64 2
  %45 = extractvalue [1 x ptr] %44, 0
  store ptr %45, ptr %gc_slot_addr_0, align 16
  %ptls_field = getelementptr inbounds ptr, ptr %pgcstack, i64 2
  %ptls_load = load ptr, ptr %ptls_field, align 8
  %"box::AssertionError" = call noalias nonnull align 8 dereferenceable(16) ptr @ijl_gc_pool_alloc_instrumented(ptr %ptls_load, i32 752, i32 16, i64 5573085232) #9
  %"box::AssertionError.tag_addr" = getelementptr inbounds i64, ptr %"box::AssertionError", i64 -1
  store atomic i64 5573085232, ptr %"box::AssertionError.tag_addr" unordered, align 8
  store ptr %45, ptr %"box::AssertionError", align 8
  call void @ijl_throw(ptr nonnull %"box::AssertionError")
  unreachable
}

The other version

@code_llvm c_a_times_b_SIMD_correct!(c, a, b)
; Function Signature: c_a_times_b_SIMD_correct!(Array{Float64, 1}, Array{Float64, 1}, Array{Float64, 1})
;  @ REPL[8]:1 within `c_a_times_b_SIMD_correct!`
define nonnull ptr @"japi1_c_a_times_b_SIMD_correct!_5238"(ptr %"function::Core.Function", ptr noalias nocapture noundef readonly %"args::Any[]", i32 %"nargs::UInt32") #0 {
L23:
  %stackargs = alloca ptr, align 8
  store volatile ptr %"args::Any[]", ptr %stackargs, align 8
  %0 = load ptr, ptr %"args::Any[]", align 8
  %1 = getelementptr inbounds ptr, ptr %"args::Any[]", i64 1
  %2 = load ptr, ptr %1, align 8
  %3 = getelementptr inbounds ptr, ptr %"args::Any[]", i64 2
  %4 = load ptr, ptr %3, align 8
;  @ REPL[8]:4 within `c_a_times_b_SIMD_correct!`
; ┌ @ essentials.jl:11 within `length`
   %5 = getelementptr inbounds i8, ptr %2, i64 16
   %6 = load i64, ptr %5, align 8
; └
;  @ REPL[8]:5 within `c_a_times_b_SIMD_correct!`
; ┌ @ int.jl:287 within `mod`
; │┌ @ div.jl:321 within `fld`
; ││┌ @ div.jl:355 within `div` @ div.jl:310 @ int.jl:295
     %7 = sdiv i64 %6, 4
; │││ @ div.jl:356 within `div`
; │││┌ @ int.jl:139 within `signbit`
; ││││┌ @ int.jl:83 within `<`
       %8 = icmp slt i64 %6, 0
; │││└└
; │││┌ @ int.jl:88 within `*`
      %9 = shl nsw i64 %7, 2
; │││└
; │││┌ @ operators.jl:277 within `!=`
; ││││┌ @ promotion.jl:639 within `==`
       %10 = icmp ne i64 %9, %6
; │││└└
; │││┌ @ bool.jl:38 within `&`
      %11 = and i1 %8, %10
; │││└
; │││┌ @ int.jl:1011 within `-`
; ││││┌ @ int.jl:546 within `rem`
; │││││┌ @ number.jl:7 within `convert`
; ││││││┌ @ boot.jl:892 within `Int64`
; │││││││┌ @ boot.jl:819 within `toInt64`
          %.neg = sext i1 %11 to i64
; ││││└└└└
; ││││ @ int.jl:1013 within `-` @ int.jl:86
      %12 = add nsw i64 %7, %.neg
; │└└└
; │┌ @ int.jl:88 within `*`
    %13 = shl i64 %12, 2
; │└
; │┌ @ int.jl:86 within `-`
    %14 = sub i64 %6, %13
; └└
;  @ REPL[8]:9 within `c_a_times_b_SIMD_correct!`
; ┌ @ range.jl:22 within `Colon`
; │┌ @ range.jl:24 within `_colon`
; ││┌ @ range.jl:384 within `StepRange` @ range.jl:329
; │││┌ @ range.jl:347 within `steprange_last`
; ││││┌ @ operators.jl:379 within `>`
; │││││┌ @ int.jl:83 within `<`
        %15 = icmp sgt i64 %13, 1
; ││││└└
      br i1 %15, label %L56, label %L231

L56:                                              ; preds = %L23
; ││││ @ range.jl:363 within `steprange_last`
      %value_phi208 = add nsw i64 %13, -3
; └└└└
; ┌ @ range.jl:904 within `iterate`
; │┌ @ range.jl:678 within `isempty`
; ││┌ @ bool.jl:38 within `&`
     %16 = icmp slt i64 %13, 4
; └└└
  br i1 %16, label %L231, label %L70

L70:                                              ; preds = %L70, %L56
  %value_phi21 = phi i64 [ %26, %L70 ], [ 1, %L56 ]
;  @ REPL[8]:10 within `c_a_times_b_SIMD_correct!`
; ┌ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:58 within `vload` @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:58 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:58 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:60
; │┌ @ abstractarray.jl:1232 within `pointer` @ abstractarray.jl:1229
; ││┌ @ pointer.jl:65 within `cconvert`
     %17 = load ptr, ptr %2, align 8
; ││└
; ││ @ abstractarray.jl:1232 within `pointer`
; ││┌ @ abstractarray.jl:1236 within `_memory_offset`
; │││┌ @ int.jl:88 within `*`
      %18 = shl i64 %value_phi21, 3
      %19 = add nsw i64 %18, -8
; ││└└
; ││┌ @ pointer.jl:316 within `+`
     %20 = getelementptr i8, ptr %17, i64 %19
; │└└
; │ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:58 within `vload` @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:58 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:58 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:61 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:50
; │┌ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/LLVM_intrinsics.jl:470 within `load`
; ││┌ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/LLVM_intrinsics.jl:479 within `macro expansion`
     %res.i = load <4 x double>, ptr %20, align 8
; └└└
;  @ REPL[8]:11 within `c_a_times_b_SIMD_correct!`
; ┌ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:58 within `vload` @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:58 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:58 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:60
; │┌ @ abstractarray.jl:1232 within `pointer` @ abstractarray.jl:1229
; ││┌ @ pointer.jl:65 within `cconvert`
     %21 = load ptr, ptr %4, align 8
; ││└
; ││ @ abstractarray.jl:1232 within `pointer`
; ││┌ @ pointer.jl:316 within `+`
     %22 = getelementptr i8, ptr %21, i64 %19
; │└└
; │ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:58 within `vload` @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:58 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:58 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:61 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:50
; │┌ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/LLVM_intrinsics.jl:470 within `load`
; ││┌ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/LLVM_intrinsics.jl:479 within `macro expansion`
     %res.i221 = load <4 x double>, ptr %22, align 8
; └└└
;  @ REPL[8]:12 within `c_a_times_b_SIMD_correct!`
; ┌ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/simdvec.jl:264 within `mul_fast`
; │┌ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/LLVM_intrinsics.jl:220 within `fmul`
; ││┌ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/LLVM_intrinsics.jl:229 within `macro expansion`
     %23 = fmul fast <4 x double> %res.i221, %res.i
; └└└
;  @ REPL[8]:13 within `c_a_times_b_SIMD_correct!`
; ┌ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:80 within `vstore` @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:80 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:80 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:82
; │┌ @ abstractarray.jl:1232 within `pointer` @ abstractarray.jl:1229
; ││┌ @ pointer.jl:65 within `cconvert`
     %24 = load ptr, ptr %0, align 8
; ││└
; ││ @ abstractarray.jl:1232 within `pointer`
; ││┌ @ pointer.jl:316 within `+`
     %25 = getelementptr i8, ptr %24, i64 %19
; │└└
; │ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:80 within `vstore` @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:80 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:80 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:83 @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/arrayops.jl:73
; │┌ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/LLVM_intrinsics.jl:505 within `store`
; ││┌ @ /Users/dbuchaca/.julia/packages/SIMD/cST3l/src/LLVM_intrinsics.jl:515 within `macro expansion`
     store <4 x double> %23, ptr %25, align 8
; └└└
;  @ REPL[8]:14 within `c_a_times_b_SIMD_correct!`
; ┌ @ range.jl:908 within `iterate`
; │┌ @ promotion.jl:639 within `==`
    %.not.not = icmp eq i64 %value_phi21, %value_phi208
; │└
   %26 = add i64 %value_phi21, 4
; └
  br i1 %.not.not, label %L231, label %L70

L231:                                             ; preds = %L70, %L56, %L23
;  @ REPL[8]:16 within `c_a_times_b_SIMD_correct!`
; ┌ @ range.jl:5 within `Colon`
; │┌ @ range.jl:408 within `UnitRange`
; ││┌ @ range.jl:419 within `unitrange_last`
; │││┌ @ operators.jl:426 within `>=`
; ││││┌ @ int.jl:514 within `<=`
       %.not = icmp sgt i64 %14, %6
; │││└└
     br i1 %.not, label %L316, label %L248.preheader

L248.preheader:                                   ; preds = %L231
     %27 = load ptr, ptr %2, align 8
     %28 = load ptr, ptr %4, align 8
     %29 = load ptr, ptr %0, align 8
; └└└
;  @ REPL[8]:19 within `c_a_times_b_SIMD_correct!`
  %min.iters.check = icmp ult i64 %13, 10
  br i1 %min.iters.check, label %scalar.ph, label %vector.memcheck

vector.memcheck:                                  ; preds = %L248.preheader
  %30 = shl i64 %6, 3
  %31 = add i64 %30, -8
  %32 = shl i64 %12, 5
  %33 = sub i64 %31, %32
  %uglygep = getelementptr i8, ptr %29, i64 %33
  %uglygep227 = getelementptr i8, ptr %29, i64 %30
  %uglygep228 = getelementptr i8, ptr %27, i64 %33
  %uglygep229 = getelementptr i8, ptr %27, i64 %30
  %uglygep230 = getelementptr i8, ptr %28, i64 %33
  %uglygep231 = getelementptr i8, ptr %28, i64 %30
  %bound0 = icmp ult ptr %uglygep, %uglygep229
  %bound1 = icmp ult ptr %uglygep228, %uglygep227
  %found.conflict = and i1 %bound0, %bound1
  %bound0232 = icmp ult ptr %uglygep, %uglygep231
  %bound1233 = icmp ult ptr %uglygep230, %uglygep227
  %found.conflict234 = and i1 %bound0232, %bound1233
  %conflict.rdx = or i1 %found.conflict, %found.conflict234
  br i1 %conflict.rdx, label %scalar.ph, label %vector.ph

vector.ph:                                        ; preds = %vector.memcheck
  %n.vec = and i64 %13, -8
  %ind.end = add i64 %14, %n.vec
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %vector.ph
  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
  %offset.idx = add i64 %14, %index
;  @ REPL[8]:18 within `c_a_times_b_SIMD_correct!`
; ┌ @ essentials.jl:917 within `getindex`
   %34 = add i64 %offset.idx, -1
   %35 = getelementptr inbounds double, ptr %27, i64 %34
   %wide.load = load <2 x double>, ptr %35, align 8
   %36 = getelementptr inbounds double, ptr %35, i64 2
   %wide.load235 = load <2 x double>, ptr %36, align 8
   %37 = getelementptr inbounds double, ptr %35, i64 4
   %wide.load236 = load <2 x double>, ptr %37, align 8
   %38 = getelementptr inbounds double, ptr %35, i64 6
   %wide.load237 = load <2 x double>, ptr %38, align 8
   %39 = getelementptr inbounds double, ptr %28, i64 %34
   %wide.load238 = load <2 x double>, ptr %39, align 8
   %40 = getelementptr inbounds double, ptr %39, i64 2
   %wide.load239 = load <2 x double>, ptr %40, align 8
   %41 = getelementptr inbounds double, ptr %39, i64 4
   %wide.load240 = load <2 x double>, ptr %41, align 8
   %42 = getelementptr inbounds double, ptr %39, i64 6
   %wide.load241 = load <2 x double>, ptr %42, align 8
; └
; ┌ @ float.jl:493 within `*`
   %43 = fmul <2 x double> %wide.load, %wide.load238
   %44 = fmul <2 x double> %wide.load235, %wide.load239
   %45 = fmul <2 x double> %wide.load236, %wide.load240
   %46 = fmul <2 x double> %wide.load237, %wide.load241
; └
; ┌ @ array.jl:976 within `setindex!`
   %47 = getelementptr inbounds double, ptr %29, i64 %34
   store <2 x double> %43, ptr %47, align 8
   %48 = getelementptr inbounds double, ptr %47, i64 2
   store <2 x double> %44, ptr %48, align 8
   %49 = getelementptr inbounds double, ptr %47, i64 4
   store <2 x double> %45, ptr %49, align 8
   %50 = getelementptr inbounds double, ptr %47, i64 6
   store <2 x double> %46, ptr %50, align 8
   %index.next = add nuw i64 %index, 8
   %51 = icmp eq i64 %index.next, %n.vec
   br i1 %51, label %scalar.ph, label %vector.body

scalar.ph:                                        ; preds = %vector.body, %vector.memcheck, %L248.preheader
   %bc.resume.val = phi i64 [ %14, %L248.preheader ], [ %14, %vector.memcheck ], [ %ind.end, %vector.body ]
; └
;  @ REPL[8]:19 within `c_a_times_b_SIMD_correct!`
  br label %L248

L248:                                             ; preds = %L248, %scalar.ph
  %value_phi131 = phi i64 [ %59, %L248 ], [ %bc.resume.val, %scalar.ph ]
;  @ REPL[8]:18 within `c_a_times_b_SIMD_correct!`
; ┌ @ essentials.jl:917 within `getindex`
   %52 = add i64 %value_phi131, -1
   %53 = getelementptr inbounds double, ptr %27, i64 %52
   %54 = load double, ptr %53, align 8
   %55 = getelementptr inbounds double, ptr %28, i64 %52
   %56 = load double, ptr %55, align 8
; └
; ┌ @ float.jl:493 within `*`
   %57 = fmul double %54, %56
; └
; ┌ @ array.jl:976 within `setindex!`
   %58 = getelementptr inbounds double, ptr %29, i64 %52
   store double %57, ptr %58, align 8
; └
;  @ REPL[8]:19 within `c_a_times_b_SIMD_correct!`
; ┌ @ range.jl:908 within `iterate`
; │┌ @ promotion.jl:639 within `==`
    %.not223.not = icmp eq i64 %value_phi131, %6
; │└
   %59 = add i64 %value_phi131, 1
; └
  br i1 %.not223.not, label %L316, label %L248

L316:                                             ; preds = %L248, %L231
  %jl_nothing = load ptr, ptr @jl_nothing, align 8
  ret ptr %jl_nothing
}

I think I figured it out: there was an error in my code, and the remainder loop was actually computing almost all coordinates (it started at n_remaining instead of n_first + 1). With that fixed, the runtimes are more reasonable. I have also given the methods better names to make them more readable.

Also, parametrizing the function by the SIMD width (c_a_times_b_SIMD_handmade_parametrized!) makes it slightly faster than c_a_times_b_SIMD_handmade!.

I guess one could use metaprogramming to unroll the remainder loop, (n_first+1):n_elements, given the SIMD width and the length of the array. Does anyone know of examples along those lines, or, more generally, examples of generated code to look at?

The updated MWE below produces:

Length divisible by SIMD width
c_a_times_b!	Trial(313.750 μs)
c_a_times_b_SIMD!	Trial(309.250 μs)
c_a_times_b_SIMD_handmade!	Trial(314.667 μs)
c_a_times_b_SIMD_handmade_parametrized!	Trial(300.458 μs)

Length NOT divisible by SIMD width
c_a_times_b!	Trial(306.291 μs)
c_a_times_b_SIMD!	Trial(307.542 μs)
c_a_times_b_SIMD_handmade!	Trial(324.708 μs)
c_a_times_b_SIMD_handmade_parametrized!	Trial(303.625 μs)

Updated MWE


using SIMD
using BenchmarkTools
using Random
Random.seed!(1234)

function c_a_times_b!(c::Array{T}, a::Array{T}, b::Array{T}) where T
    @assert length(c) == length(a) == length(b)
    for i in 1:length(c)
        c[i] = a[i]*b[i]
    end
end

function c_a_times_b_SIMD!(c::Array{T}, a::Array{T}, b::Array{T}) where T
    @assert length(c) == length(a) == length(b)
    @simd for i in 1:length(c)
        c[i] = a[i]*b[i]
    end
end


function c_a_times_b_SIMD_handmade!(c, a, b)
    N = 4
    T = eltype(c)
    n_elements = length(a)
    n_remaining = mod(n_elements, N)
    n_first = n_elements - n_remaining
    vtype = Vec{N,T}
    
    @inbounds @fastmath for i in 1:N:n_first
        a_chunk = vload(vtype, a, i) 
        b_chunk = vload(vtype, b, i) 
        a_chunk *=  b_chunk
        vstore(a_chunk,c,i)
    end
    
    @inbounds for i in (n_first+1):n_elements
        c[i] = a[i]*b[i]
    end
end

function c_a_times_b_SIMD_handmade_parametrized!(c, a, b, ::Val{W}) where {W}
    T = eltype(c)
    vtype = Vec{W, T}
    len_a = length(a)
    i_last = len_a - rem(len_a, W)

    @inbounds @simd for i in 1:W:i_last
        va = vload(vtype, a, i)
        vb = vload(vtype, b, i)
        vc = va * vb
        vstore(vc, c, i)
    end

    @inbounds for i in (i_last + 1):len_a
        c[i] = a[i] * b[i]
    end
end


T = Float64
M = 1_000_000
a = Vector{T}(1:M)
b = Vector{T}(1:M)
c = zeros(T,M)
println("Length divisible by SIMD width")
println("c_a_times_b!\t", @benchmark c_a_times_b!($c, $a, $b))
println("c_a_times_b_SIMD!\t", @benchmark c_a_times_b_SIMD!($c, $a, $b))
println("c_a_times_b_SIMD_handmade!\t", @benchmark c_a_times_b_SIMD_handmade!($c, $a, $b))
println("c_a_times_b_SIMD_handmade_parametrized!\t", @benchmark c_a_times_b_SIMD_handmade_parametrized!($c, $a, $b, Val(8)))

a = Vector{T}(1:M+10)
b = Vector{T}(1:M+10)
c = zeros(T,M+10)

println("\nLength NOT divisible by SIMD width")
println("c_a_times_b!\t", @benchmark c_a_times_b!($c, $a, $b))
println("c_a_times_b_SIMD!\t", @benchmark c_a_times_b_SIMD!($c, $a, $b))
println("c_a_times_b_SIMD_handmade!\t", @benchmark c_a_times_b_SIMD_handmade!($c, $a, $b))
println("c_a_times_b_SIMD_handmade_parametrized!\t", @benchmark c_a_times_b_SIMD_handmade_parametrized!($c, $a, $b, Val(8)))
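On the metaprogramming question: one option is a `@generated` function that, for a compile-time width W, emits W-1 guarded scalar statements instead of a loop. The sketch below is only an illustration (`tail_unrolled!` is a hypothetical name, not part of SIMD.jl) and uses nothing beyond Base Julia. Whether it actually beats the plain remainder loop has to be benchmarked, since LLVM may already unroll a loop known to run fewer than W times.

```julia
# Hypothetical helper: handles the tail c[start:end] (at most W-1 elements)
# with straight-line code generated once per width W.
@generated function tail_unrolled!(c::AbstractVector, a::AbstractVector,
                                   b::AbstractVector, start::Int, ::Val{W}) where {W}
    stmts = Expr[]
    for k in 0:(W - 2)
        # Each generated statement handles one element and exits once past the end
        push!(stmts, quote
            idx = start + $k
            idx > length(c) && return c
            @inbounds c[idx] = a[idx] * b[idx]
        end)
    end
    return Expr(:block, stmts..., :(return c))
end
```

In c_a_times_b_SIMD_handmade_parametrized! it would replace the scalar loop with a call like `tail_unrolled!(c, a, b, i_last + 1, Val(W))`; `@code_lowered` on that call shows the generated straight-line body.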

It’s still much nicer to be able to write

c[i] = a[i] * b[i]

instead, which you can do if i is a VecRange.
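For reference, here is a sketch of the same kernel written with `VecRange`, following the `lane = VecRange{N}(0); xs[lane + i]` pattern from the SIMD.jl README. The function name and the `Val(N)` argument are illustrative; the tail is still an ordinary scalar loop.

```julia
using SIMD

# Hypothetical name; same kernel, written with SIMD.jl's VecRange indexing.
function times_vecrange!(c::Vector{T}, a::Vector{T}, b::Vector{T}, ::Val{N}) where {N, T}
    n = length(c)
    @assert length(a) == n && length(b) == n
    lane = VecRange{N}(0)           # lane + i refers to elements i:i+N-1
    i_last = n - rem(n, N)          # last index covered by full vectors
    @inbounds for i in 1:N:i_last
        c[lane + i] = a[lane + i] * b[lane + i]   # Vec{N,T} load, multiply, store
    end
    @inbounds for i in (i_last + 1):n             # scalar tail (fewer than N elements)
        c[i] = a[i] * b[i]
    end
    return c
end
```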

Exactly. Switching between simd and scalar operations can incur a performance cost.

Sure, but when comparing performance with the ‘naive’ loop, it’s good to be aware that the naive version might actually be quite optimized.
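The `@code_llvm` dump for `c_a_times_b!` posted above already contains a `vector.body` block with `<2 x double>` loads, `fmul`s and stores, i.e. the compiler auto-vectorized the naive loop. A small sketch for checking this programmatically instead of reading the whole dump, using `code_llvm` from the InteractiveUtils standard library. Searching for these names is a heuristic: they come from LLVM's loop vectorizer and may change between LLVM versions or CPUs.

```julia
using InteractiveUtils

function c_a_times_b!(c::Array{T}, a::Array{T}, b::Array{T}) where T
    @assert length(c) == length(a) == length(b)
    for i in 1:length(c)
        c[i] = a[i]*b[i]
    end
end

# Capture the optimized LLVM IR as a string, without source-line comments
ir = sprint(io -> code_llvm(io, c_a_times_b!,
                            Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}};
                            debuginfo = :none))

# Heuristic: the loop vectorizer emits a "vector.body" block and <k x double> types
vectorized = occursin("vector.body", ir) || occursin(r"<\d+ x double>", ir)
println("auto-vectorized: ", vectorized)
```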
