Is it ok to iterate inside a `@simd` loop?

I’m delightfully surprised that this works:

function mr(f, op, A, n)
    a1, s = iterate(A)
    a2, s = iterate(A, s)
    v = op(f(a1), f(a2))
    @simd for _ in 3:n
        ai, s = iterate(A, s)
        v = op(v, f(ai))
    end
    return v
end

This seems to break @simd’s rules in two ways:

  • s depends upon prior iterations of the loop
  • destructuring the result of iterate introduces an (unhandled and implicit) error branch

And yet it works! It spits out SIMD code for stateless iterators and straight-line (scalar) loops for stateful ones. In fact, it has exactly the same performance as the @inbounds getindex version. What magic is this? Can I count on it? Interestingly, this is even more robust than the iterable form, @simd for a in A.

julia> using BenchmarkTools

julia> function mr(f, op, A, n)
           a1, s = iterate(A)
           a2, s = iterate(A, s)
           v = op(f(a1), f(a2))
           @simd for _ in 3:n
               ai, s = iterate(A, s)
               v = op(v, f(ai))
           end
           return v
       end
mr (generic function with 1 method)

julia> A = rand(10000); A32 = rand(Float32, 20000);

julia> @btime mr(identity, +, $A, length($A))
  1.142 μs (0 allocations: 0 bytes)
4978.86295198483

julia> @btime mr(identity, +, $A32, length($A32))
  1.146 μs (0 allocations: 0 bytes)
10014.671f0

julia> @btime mr(identity, +, $(a for a in A), length($A))
  1.146 μs (0 allocations: 0 bytes)
4978.86295198483

julia> const B = A; const B32 = A32;

julia> @btime mr(identity, +, $(B[i] for i in collect(eachindex(A))), length($A))
  9.208 μs (0 allocations: 0 bytes)
4978.862951984815

julia> @btime mr(identity, +, $(B32[i] for i in collect(eachindex(A32))), length($A32))
  18.541 μs (0 allocations: 0 bytes)
10014.671f0

julia> open("A.txt", "w") do f
           for a in A
               println(f, a)
           end
       end

julia> mr(x->parse(Float64, x), +, eachline("A.txt"), length(A))
4978.862951984815

julia> function mr2(f, op, A, n)
           a1 = @inbounds A[1]
           a2 = @inbounds A[2]
           v = op(f(a1), f(a2))
           @simd for i in 3:n
               ai = @inbounds A[i]
               v = op(v, f(ai))
           end
           return v
       end
mr2 (generic function with 1 method)

julia> @btime mr2(identity, +, $A, length($A))
  1.146 μs (0 allocations: 0 bytes)
4978.86295198483

I don’t know what’s happening. Maybe the compiler can detect and optimize away the dependency on prior iterations? However, this might not be safe if the reduction operator is not commutative and associative.
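A Base-only illustration of the associativity concern: floating-point + is commutative but not associative, and @simd allows the reduction to be reassociated. That is most likely why the vectorized runs above print 4978.86295198483 while the sequential ones (the indexed generators and eachline) print 4978.862951984815. The helper names seqsum and simdsum are hypothetical.

```julia
# Floating-point addition is not associative: grouping changes the rounding.
left = (0.1 + 0.2) + 0.3    # 0.6000000000000001
right = 0.1 + (0.2 + 0.3)   # 0.6
println(left == right)      # false

# Hypothetical helper: strictly sequential, left-to-right sum.
function seqsum(xs)
    s = zero(eltype(xs))
    for x in xs
        s += x
    end
    return s
end

# Hypothetical helper: the same reduction under @simd.
function simdsum(xs)
    s = zero(eltype(xs))
    @simd for i in eachindex(xs)
        s += @inbounds xs[i]  # @simd lets the compiler reorder these additions
    end
    return s
end

xs = rand(10_000)
println(seqsum(xs) ≈ simdsum(xs))   # equal up to rounding, not necessarily bit-for-bit
```

With an integer element type the two functions agree exactly, since integer addition is associative.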

I would use Transducers.jl’s Map with a fold for this case. That seems more robust than relying on something like this.
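Transducers.jl is an external package and is not used below. As a Base-only point of comparison: mapfoldl guarantees left associativity, while mapreduce leaves the association order unspecified, so it should only be given associative operators. Both accept arbitrary iterators, including generators.

```julia
xs = [1, 2, 3]

# mapfoldl is left-associative: (1 - 2) - 3
r_foldl = mapfoldl(identity, -, xs)
println(r_foldl)   # -4

# mapreduce may reassociate, so use it with associative operators such as +.
r_reduce = mapreduce(identity, +, xs)
println(r_reduce)  # 6

# Both work on plain iterators, not just arrays: 1 + 4 + 9 + 16
r_gen = mapfoldl(x -> x^2, +, (i for i in 1:4))
println(r_gen)     # 30
```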

The mr method is specialized for +, iterate is inlined, and iterate for vectors is very simple. So it’s not very hard for the compiler to recognize that the loop iterations can in fact run in parallel. For the more complicated generators, you can see that it’s not nearly as fast.
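To see why the state is easy to eliminate for a Vector: iterate on an Array uses the index of the next element as its state, so after inlining, s is just an ordinary induction variable. The comments reflect Base's current behavior for Array; other iterators may use very different states.

```julia
v = [10, 20, 30]

# For an Array, the state is simply the index of the next element.
p1 = iterate(v)       # (10, 2)
p2 = iterate(v, 2)    # (20, 3)

# Past the end, iterate returns nothing; destructuring that is the error branch.
p4 = iterate(v, 4)    # nothing

println((p1, p2, p4))
```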

I don’t think this violates the rules of plain @simd, only those of @simd ivdep. The requirement for plain @simd is

  • It is safe to execute iterations in arbitrary or overlapping order, with special consideration for reduction variables.

which is trivially satisfied by your example because your code is invariant under permutations of the iterations.

If the loop-carried dependency due to s is non-trivial (i.e., the iterator is stateful), @simd will probably not be able to vectorize the loop, but it’s still safe to use as long as you don’t add ivdep.
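A quick check that the loop still gives correct results with stateful iterators, in the spirit of the eachline example from the question. The IOBuffer input and the Iterators.Stateful wrapper are choices made here so the snippet runs without a file; whether either case vectorizes is not shown.

```julia
# The function from the question, copied so this block runs on its own.
function mr(f, op, A, n)
    a1, s = iterate(A)
    a2, s = iterate(A, s)
    v = op(f(a1), f(a2))
    @simd for _ in 3:n
        ai, s = iterate(A, s)
        v = op(v, f(ai))
    end
    return v
end

# eachline keeps its position inside the stream; its state is just `nothing`.
lines = eachline(IOBuffer("1\n2\n3\n4\n"))
r_lines = mr(x -> parse(Float64, x), +, lines, 4)
println(r_lines)     # 10.0

# Iterators.Stateful also advances internally on every iterate call.
r_stateful = mr(identity, +, Iterators.Stateful(1:10), 10)
println(r_stateful)  # 55
```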

Yeah, I get that iterate is really basic for Array, and I can kinda/sorta see the justification that s doesn’t really introduce a loop dependency after inlining. But I am still amazed that the error branch disappears from the loop. It’s still there, though! Somehow Julia figures out that I didn’t really want to iterate over 3:n; I really wanted min(n, length(A)) - 2 iterations. Then, after executing those iterations, it checks whether it needs to throw an error. Very clever. Julia v1.10 doesn’t do this, but 1.11 does.
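A hand-written sketch of the transformation described above, for a Vector: run the min(n, length(A)) - 2 iterations that cannot fail inside the @simd loop, then check afterwards whether the original would have run off the end. mr_split is a hypothetical name, and its error type is illustrative; it does not reproduce the exact error the original throws.

```julia
# The function from the question, for comparison.
function mr(f, op, A, n)
    a1, s = iterate(A)
    a2, s = iterate(A, s)
    v = op(f(a1), f(a2))
    @simd for _ in 3:n
        ai, s = iterate(A, s)
        v = op(v, f(ai))
    end
    return v
end

# Hypothetical sketch of what the optimized code appears to do for a Vector.
function mr_split(f, op, A::Vector, n)
    v = op(f(A[1]), f(A[2]))   # bounds-checked, like the two leading iterate calls
    m = min(n, length(A))      # last index reachable without running off the end
    @simd for i in 3:m         # min(n, length(A)) - 2 iterations that cannot fail
        ai = @inbounds A[i]
        v = op(v, f(ai))
    end
    # Deferred failure: the original would hit `nothing` at index length(A) + 1.
    n > length(A) && throw(ArgumentError("iterator exhausted before n elements"))
    return v
end

println(mr_split(identity, +, [1, 2, 3, 4], 4))  # 10, same as mr
```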

@code_llvm output

Things are a little out of order: the top block checks that A has at least two elements, and then guard_exit90 runs. It loads the first two elements, sets up the loop, and sets %20 to, effectively, max(n, 2) - 2.

julia> @code_llvm debuginfo=:none mr(identity, +, rand(Float32, 10), 10)
; Function Signature: mr(typeof(Base.identity), typeof(Base.:(+)), Array{Float32, 1}, Int64)
define float @julia_mr_3653(ptr noundef nonnull align 8 dereferenceable(24) %"A::Array", i64 signext %"n::Int64") local_unnamed_addr #0 {
top:
  %"A::Array.size_ptr" = getelementptr inbounds i8, ptr %"A::Array", i64 16
  %"A::Array.size.0.copyload" = load i64, ptr %"A::Array.size_ptr", align 8
  switch i64 %"A::Array.size.0.copyload", label %guard_exit90 [
    i64 0, label %L37
    i64 1, label %L87
  ]

L37:                                              ; preds = %top
  call void @j_indexed_iterate_3655(i64 signext 1) #7
  unreachable

L87:                                              ; preds = %top
  call void @j_indexed_iterate_3655(i64 signext 1) #7
  unreachable

L122.lr.ph:                                       ; preds = %guard_exit90
  %0 = add nsw i64 %"A::Array.size.0.copyload", -2
  %isnotneg.inv = icmp slt i64 %"A::Array.size.0.copyload", 0
  %1 = select i1 %isnotneg.inv, i64 0, i64 %0
  %smin155 = call i64 @llvm.smin.i64(i64 %1, i64 %20)
  %.not = icmp slt i64 %smin155, 1
  br i1 %.not, label %postloop, label %L122.preheader

L122.preheader:                                   ; preds = %L122.lr.ph
  %2 = add nuw i64 %smin155, 3
  %invariant.gep = getelementptr i8, ptr %memoryref_data, i64 -4
  %min.iters.check = icmp ult i64 %smin155, 16
  br i1 %min.iters.check, label %L122, label %vector.ph

vector.ph:                                        ; preds = %L122.preheader
  %n.vec = and i64 %smin155, 9223372036854775792
  %ind.end = or disjoint i64 %n.vec, 3
  %3 = insertelement <4 x float> <float poison, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, float %19, i64 0
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %vector.ph
  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
  %vec.phi = phi <4 x float> [ %3, %vector.ph ], [ %8, %vector.body ]
  %vec.phi169 = phi <4 x float> [ <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %vector.ph ], [ %9, %vector.body ]
  %vec.phi170 = phi <4 x float> [ <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %vector.ph ], [ %10, %vector.body ]
  %vec.phi171 = phi <4 x float> [ <float -0.000000e+00, float -0.000000e+00, float -0.000000e+00, float -0.000000e+00>, %vector.ph ], [ %11, %vector.body ]
  %offset.idx = or disjoint i64 %index, 3
  %4 = getelementptr float, ptr %invariant.gep, i64 %offset.idx
  %5 = getelementptr i8, ptr %4, i64 16
  %6 = getelementptr i8, ptr %4, i64 32
  %7 = getelementptr i8, ptr %4, i64 48
  %wide.load = load <4 x float>, ptr %4, align 4
  %wide.load172 = load <4 x float>, ptr %5, align 4
  %wide.load173 = load <4 x float>, ptr %6, align 4
  %wide.load174 = load <4 x float>, ptr %7, align 4
  %8 = fadd reassoc contract <4 x float> %vec.phi, %wide.load
  %9 = fadd reassoc contract <4 x float> %vec.phi169, %wide.load172
  %10 = fadd reassoc contract <4 x float> %vec.phi170, %wide.load173
  %11 = fadd reassoc contract <4 x float> %vec.phi171, %wide.load174
  %index.next = add nuw i64 %index, 16
  %12 = icmp eq i64 %index.next, %n.vec
  br i1 %12, label %middle.block, label %vector.body

middle.block:                                     ; preds = %vector.body
  %bin.rdx = fadd reassoc contract <4 x float> %9, %8
  %bin.rdx175 = fadd reassoc contract <4 x float> %10, %bin.rdx
  %bin.rdx176 = fadd reassoc contract <4 x float> %11, %bin.rdx175
  %13 = call reassoc contract float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> %bin.rdx176)
  %cmp.n = icmp eq i64 %smin155, %n.vec
  br i1 %cmp.n, label %main.exit.selector, label %L122

L122:                                             ; preds = %L122, %middle.block, %L122.preheader
  %value_phi36152 = phi i64 [ %15, %L122 ], [ 3, %L122.preheader ], [ %ind.end, %middle.block ]
  %value_phi35151 = phi float [ %16, %L122 ], [ %19, %L122.preheader ], [ %13, %middle.block ]
  %gep = getelementptr float, ptr %invariant.gep, i64 %value_phi36152
  %14 = load float, ptr %gep, align 4
  %15 = add nuw i64 %value_phi36152, 1
  %16 = fadd reassoc contract float %value_phi35151, %14
  %exitcond.not = icmp eq i64 %15, %2
  br i1 %exitcond.not, label %main.exit.selector, label %L122

L176:                                             ; preds = %postloop
  call void @j_indexed_iterate_3655(i64 signext 1) #7
  unreachable

L195:                                             ; preds = %L122.postloop, %main.exit.selector, %guard_exit90
  %value_phi65 = phi float [ %19, %guard_exit90 ], [ %.lcssa168, %main.exit.selector ], [ %32, %L122.postloop ]
  ret float %value_phi65

guard_exit90:                                     ; preds = %top
  %memoryref_data = load ptr, ptr %"A::Array", align 8
  %17 = load float, ptr %memoryref_data, align 4
  %memoryref_data24 = getelementptr inbounds i8, ptr %memoryref_data, i64 4
  %18 = load float, ptr %memoryref_data24, align 4
  %19 = fadd float %17, %18
  %value_phi34 = call i64 @llvm.smax.i64(i64 %"n::Int64", i64 2)
  %20 = add nsw i64 %value_phi34, -2
  %.not150 = icmp sgt i64 %"n::Int64", 2
  br i1 %.not150, label %L122.lr.ph, label %L195

main.exit.selector:                               ; preds = %L122, %middle.block
  %.lcssa168 = phi float [ %13, %middle.block ], [ %16, %L122 ]
  %21 = icmp slt i64 %1, %20
  br i1 %21, label %postloop, label %L195

postloop:                                         ; preds = %main.exit.selector, %L122.lr.ph
  %value_phi37153.copy = phi i64 [ 0, %L122.lr.ph ], [ %smin155, %main.exit.selector ]
  %value_phi36152.copy = phi i64 [ 3, %L122.lr.ph ], [ %2, %main.exit.selector ]
  %value_phi35151.copy = phi float [ %19, %L122.lr.ph ], [ %.lcssa168, %main.exit.selector ]
  %22 = add i64 %value_phi36152.copy, -1
  %umax = call i64 @llvm.umax.i64(i64 %"A::Array.size.0.copyload", i64 %22)
  %23 = add i64 %umax, 1
  %24 = sub i64 %23, %value_phi36152.copy
  %25 = add nuw nsw i64 %value_phi37153.copy, 1
  %smax163 = call i64 @llvm.smax.i64(i64 %20, i64 %25)
  %26 = xor i64 %value_phi37153.copy, -1
  %27 = add nsw i64 %smax163, %26
  %28 = freeze i64 %27
  %.not164.not = icmp ugt i64 %24, %28
  br i1 %.not164.not, label %L122.postloop.preheader, label %L176

L122.postloop.preheader:                          ; preds = %postloop
  %invariant.gep166 = getelementptr i8, ptr %memoryref_data, i64 -4
  br label %L122.postloop

L122.postloop:                                    ; preds = %L122.postloop, %L122.postloop.preheader
  %value_phi37153.postloop = phi i64 [ %29, %L122.postloop ], [ %value_phi37153.copy, %L122.postloop.preheader ]
  %value_phi36152.postloop = phi i64 [ %31, %L122.postloop ], [ %value_phi36152.copy, %L122.postloop.preheader ]
  %value_phi35151.postloop = phi float [ %32, %L122.postloop ], [ %value_phi35151.copy, %L122.postloop.preheader ]
  %29 = add nuw nsw i64 %value_phi37153.postloop, 1
  %gep167 = getelementptr float, ptr %invariant.gep166, i64 %value_phi36152.postloop
  %30 = load float, ptr %gep167, align 4
  %31 = add i64 %value_phi36152.postloop, 1
  %32 = fadd reassoc contract float %value_phi35151.postloop, %30
  %.not.postloop = icmp slt i64 %29, %20
  br i1 %.not.postloop, label %L122.postloop, label %L195
}
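The trip-count arithmetic in the IR can be mirrored in a few lines of Julia. trip_counts is a hypothetical helper; the comments map each quantity to the IR value it imitates, and the sketch assumes length(A) >= 2, which the top block has already enforced. The constant 9223372036854775792 in vector.ph equals typemax(Int) & ~15, i.e. rounding down to a multiple of 16 floats (four <4 x float> accumulators per trip).

```julia
# Hypothetical helper mirroring the trip-count arithmetic in the IR above.
function trip_counts(len::Int, n::Int)
    avail = len < 0 ? 0 : len - 2       # %1: iterations the array can supply
    requested = max(n, 2) - 2           # %20: iterations the source loop asks for
    main = min(avail, requested)        # %smin155: iterations that cannot fail
    # vector.body handles 16 elements per trip, but only when main >= 16
    nvec = main < 16 ? 0 : main & (typemax(Int) & ~15)   # %n.vec
    tail = main - nvec                  # scalar iterations in block L122
    needs_postloop = avail < requested  # %21: run the checked loop that ends in the throw
    return (; main, nvec, tail, needs_postloop)
end

println(trip_counts(10000, 10000))
println(trip_counts(10, 12))
```

For a 10000-element array with n = 10000, this gives 9998 guaranteed iterations, 9984 of them in vector.body and 14 in the scalar L122 loop, and no checked postloop.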

Are you sure that it’s Julia doing the heavy lifting here, and not LLVM, perhaps helped by extra information that Julia passes to it? As far as I can tell, @simd internally is relatively “dumb”: it just introduces an additional loop and tells the compiler to attach some extra loop annotations to the LLVM IR. The requirements for @simd then follow from the semantics LLVM requires for valid uses of those annotations.
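One way to check that @simd is mostly a syntactic rewrite plus loop metadata is to expand it. In recent Julia versions the expansion wraps the body in an extra outer loop (via Base.simd_outer_range and Base.simd_inner_length) and marks the inner loop with an Expr(:loopinfo, ...) annotation such as julia.simdloop. The exact printed form varies between Julia versions, so treat these names as things to look for rather than a guarantee.

```julia
# Expand without evaluating; `s` does not need to be defined for this.
ex = @macroexpand @simd for i in 1:10
    s += i
end
println(ex)

# The ivdep variant adds one more annotation to the same loopinfo marker.
ex_ivdep = @macroexpand @simd ivdep for i in 1:10
    s += i
end
println(ex_ivdep)
```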

No, I don’t mean to assign credit or blame to any particular part of the language or compiler. I just mean that Julia v1.11 has these behaviors, by whatever magic.

Ah, yeah, I’m also impressed by the optimizations that make vectorization successful in this case.

I just meant to point out that you’re not violating @simd’s rules in the sense of your use being illegal, unsafe, or inconsistent. In other words, taken at face value, your code satisfies all the properties required for @simd to be safe and consistent, just not all of those required for it to be useful (that is, to actually vectorize).