Compiling to branch table

Is there a way to coax the compiler into compiling something like

const CODES = (2, 3, 5, 7, 11)

function findcode(x)
    index = findfirst(isequal(x), CODES)
    index ≡ nothing ? 0 : index
end

similarly to

function findcode2(x)
    if x == 2
        1
    elseif x == 3
        2
    elseif x == 5
        3
    elseif x == 7
        4
    elseif x == 11
        5
    else
        0
    end
end

other than direct code generation (e.g. a macro)?

I don’t know of anything except a macro. What will your codes consist of? Is this performance critical? You could probably gain some performance by a) customized binary search instead of a series of ifs and b) sorting your checks by likelihood of occurrence (e.g. if 7 is the most frequent value, then check that first).
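For reference, the macro route mentioned in the question could look roughly like the sketch below. The names `@findcode_chain` and `findcode_macro` are made up for illustration, and the codes have to be literals at the call site. Note that `x` is spliced into every comparison, so it should be a plain variable rather than an expression with side effects.

```julia
# Hypothetical macro: expands to a nested chain of `x == code ? index : ...`
# comparisons, equivalent to the hand-written if/elseif version.
macro findcode_chain(x, codes...)
    ex = 0                            # value returned when nothing matches
    for i in length(codes):-1:1       # build the chain from the innermost branch outward
        ex = :( $(esc(x)) == $(codes[i]) ? $i : $ex )
    end
    return ex
end

findcode_macro(x) = @findcode_chain(x, 2, 3, 5, 7, 11)
```

Reordering the literals at the call site is also an easy way to try the "most likely value first" idea, as long as the returned indices are adjusted to match.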

https://github.com/JuliaCollections/Memoize.jl?

That’s unlikely (only showing relevant part):

julia> @code_llvm findcode2(3)

; Function findcode2
; Location: REPL[15]:2
define i64 @julia_findcode2_1216615403(i64) {
top:
  switch i64 %0, label %L13 [
    i64 2, label %L3
    i64 3, label %L3.fold.split
    i64 5, label %L3.fold.split1
    i64 7, label %L3.fold.split2
  ]

How large is codes (how many elements) and how large is x (number of bits)?

For example, if x is Int8, then you can compare it to a 16 element codes in a single low-latency instruction (need to broadcast x into a register first). This will end up much faster (and branchfree).

Does anyone know the correct idiom to get llvm to emit the equivalent of _mm_movemask_epi8, without llvmcall into @llvm.x86.avx.movmsk?
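As a baseline, here is a portable sketch of the same idea written in plain Julia: compare `x` against all 16 lanes, collect the results into a 16-bit mask, and take the position of the lowest set bit. The name `find_16_portable` is hypothetical, and I am not claiming that LLVM turns this loop into a compare plus movemask; that would need to be checked with @code_native on the target machine.

```julia
# Hypothetical reference implementation: bit i-1 of `mask` is set
# when codes[i] == x, so the lowest set bit marks the first match.
function find_16_portable(x::Int8, codes::NTuple{16,Int8})
    mask = UInt16(0)
    for i in 1:16
        mask |= UInt16(codes[i] == x) << (i - 1)
    end
    # One-based index of the first match, or 0 if there is none.
    return mask == 0 ? 0 : trailing_zeros(mask) + 1
end

codes16 = ntuple(i -> Int8(i), 16)   # example table holding 1..16
find_16_portable(Int8(3), codes16)
```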

Here’s a simple recursive hack: (edit: cleaned up the code a bit)

using StaticNumbers

const CODES = (2, 3, 5, 7, 11)
const SRCODES = static.(reverse(CODES))

f(x) = 0
f(x, y, ys...) = x == y ? 1 + length(ys) : f(x, ys...)
findcode3(x) = f(x, SRCODES...)

@code_llvm findcode3(3)

Results in (only showing the switch statement):

define i64 @julia_findcode3_35284(i64) {
top:
  switch i64 %0, label %L13 [
    i64 11, label %L25
    i64 7, label %L25.fold.split
    i64 5, label %L25.fold.split5
    i64 3, label %L25.fold.split6
  ]

The function static turns the CODES into compile-time constants, which results in the switch statement being compiled specifically for this list. Edit: Let me rephrase that. They’re already compile-time constants, but the compiler may or may not treat them as constant when compiling f. In fact, after I re-wrote the code to make findcode3 pass the list instead of doing it from the REPL, constant propagation kicks in, and StaticNumbers is no longer needed.

Constant propagation works amazingly well in Julia, as I found out when I tried calling f from within a function. There’s no need for StaticNumbers here. The below is enough.

const CODES = (2, 3, 5, 7, 11)
const RCODES = reverse(CODES)

f(x) = 0
f(x, y, ys...) = x == y ? 1 + length(ys) : f(x, ys...)
findcode3(x) = f(x, RCODES...)

@code_llvm findcode3(3)
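To see why the tuple is reversed: the recursion returns `1 + length(ys)`, that is, one plus the number of elements still to the right of the match. In the reversed tuple that count equals the one-based position in the original tuple. A small self-contained check (the names `g`, `CODES_CHECK` and `findcode3_check` are only used here to avoid clashing with the definitions above):

```julia
const CODES_CHECK = (2, 3, 5, 7, 11)

g(x) = 0                                             # ran out of codes: no match
g(x, y, ys...) = x == y ? 1 + length(ys) : g(x, ys...)
findcode3_check(x) = g(x, reverse(CODES_CHECK)...)

# Trace for x = 3: g(3, 11, 7, 5, 3, 2) -> g(3, 7, 5, 3, 2) -> g(3, 5, 3, 2)
# -> g(3, 3, 2), which matches with ys = (2,) and returns 1 + 1 = 2.
for x in 0:12
    expected = something(findfirst(isequal(x), CODES_CHECK), 0)
    @assert findcode3_check(x) == expected
end
```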

I don’t think it is. What you have there is LLVM IR. As far as I know, once that code is lowered to assembly, the compiler is not going to turn it into a simple jump table, but into a series of compares and branches. (I’m not even sure a jump table would be better, since it might interfere with branch prediction.) Hard-coding a binary search should then give you better performance. Example:

@noinline function findcode2(x)
    if x == 2
        1
    elseif x == 3
        2
    elseif x == 5
        3
    elseif x == 7
        4
    elseif x == 11
        5
    else
        0
    end
end

@noinline function findcode_binary(x)
    if x < 5
        x == 2 && return 1
        x == 3 && return 2
    else
        x == 5 && return 3
        x == 7 && return 4
        x == 11 && return 5  # last code in CODES; not exercised by the n = 1:8 timing loop
    end
    return 0
end

Timings:

julia> @btime (for k = 1:1000; for n = 1:8; findcode2(n); end; end)
  15.194 μs (0 allocations: 0 bytes)

julia> @btime (for k = 1:1000; for n = 1:8; findcode_binary(n); end; end)
  10.840 μs (0 allocations: 0 bytes)

Even for 5 values, this approach is more efficient. If the table size grows, I’d expect the difference to be bigger.
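If the table does grow, writing the comparisons by hand gets tedious. A generic binary search over a sorted tuple might be a starting point; this is only a sketch, `findcode_bsearch` is a made-up name, and I have not benchmarked it against the hard-coded version, which may well be faster because its comparisons are against literals.

```julia
# Hypothetical generic variant: `codes` must be sorted in ascending order.
function findcode_bsearch(x, codes::Tuple)
    lo, hi = 1, length(codes)
    while lo <= hi
        mid = (lo + hi) >>> 1        # midpoint of the remaining range
        c = codes[mid]
        if c == x
            return mid               # one-based position of the match
        elseif c < x
            lo = mid + 1             # continue in the upper half
        else
            hi = mid - 1             # continue in the lower half
        end
    end
    return 0                         # not found
end

findcode_bsearch(7, (2, 3, 5, 7, 11))
```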

A lookup-table would be another option. However, it will require a memory access, and won’t be branch free anyway if you need to do bounds checks (default value).

If you want to go the vector route and are happy with x86 specific code, here you go:

using SIMD

@inline function cmp_16xi8_i32(v1::Vec{16, Int8}, v2::Vec{16, Int8})
    Base.llvmcall(("declare i32 @llvm.x86.sse2.pmovmskb.128(<16 x i8>) nounwind readnone", "
        %cmpres = icmp eq <16 x i8> %0, %1
        %cmpres_s = sext <16 x i1> %cmpres to <16 x i8>
        %res = call i32 @llvm.x86.sse2.pmovmskb.128(<16 x i8> %cmpres_s)
        ret i32 %res"),
        UInt32, Tuple{NTuple{16, VecElement{Int8}}, NTuple{16, VecElement{Int8}}}, v1.elts, v2.elts)
end

@inline expand_v(x::Int8) = Vec{16, Int8}((x,x,x,x, x,x,x,x, x,x,x,x, x,x,x,x))

function vec_codes(codes::Vector) # utility function, not for runtime use
    codes = convert(Vector{Int8}, codes)
    ell = length(codes)
    0 < ell <= 16 || error("does not fit into vector")
    # Pad with the first code so that padding lanes never produce an earlier match.
    padding = ntuple(i -> codes[1], 16 - ell)
    return Vec{16, Int8}((ntuple(i -> codes[i], ell)..., padding...))
end

@inline function find_16(x, codes::Vec{16, Int8})
    xe = expand_v(x % Int8)
    cmpbits = cmp_16xi8_i32(xe, codes)
    cmpres = 1+ trailing_zeros(cmpbits % UInt16)
    return ifelse(cmpres==17, 0, cmpres)
end

This will find the earliest occurrence of an Int8 in a 16 x Int8, one-based, and return 0 if there is no occurrence. The function vec_codes is not fast; it just serves as a convenience constructor for the 16 x Int8, including padding.
Let’s test:

julia> vc = vec_codes(collect(1:15));

julia> for i=1:17
       @show i,find_16(i, vc)
       end
(i, find_16(i, vc)) = (1, 1)
(i, find_16(i, vc)) = (2, 2)
(i, find_16(i, vc)) = (3, 3)
(i, find_16(i, vc)) = (4, 4)
(i, find_16(i, vc)) = (5, 5)
(i, find_16(i, vc)) = (6, 6)
(i, find_16(i, vc)) = (7, 7)
(i, find_16(i, vc)) = (8, 8)
(i, find_16(i, vc)) = (9, 9)
(i, find_16(i, vc)) = (10, 10)
(i, find_16(i, vc)) = (11, 11)
(i, find_16(i, vc)) = (12, 12)
(i, find_16(i, vc)) = (13, 13)
(i, find_16(i, vc)) = (14, 14)
(i, find_16(i, vc)) = (15, 15)
(i, find_16(i, vc)) = (16, 0)
(i, find_16(i, vc)) = (17, 0)

And speed:

julia> @btime find_16($x, $vc);
  2.668 ns (0 allocations: 0 bytes)
julia> @code_native find_16(17, vc)
	.text
; Function find_16 {
; Location: REPL[5]:2
	vmovd	%edi, %xmm0
	vpbroadcastb	%xmm0, %xmm0
	vpcmpeqb	(%rsi), %xmm0, %xmm0
	vpmovmskb	%xmm0, %eax
	tzcntw	%ax, %ax
	addl	$1, %eax
	movzwl	%ax, %ecx
	xorl	%eax, %eax
	cmpl	$17, %ecx
	cmovneq	%rcx, %rax
	retq
	nopw	%cs:(%rax,%rax)
;}

5 cycles, no branch. When called in a loop:

julia> function mapfind!(dst, src, vc)
       @inbounds for i=1:length(dst)
       dst[i]=find_16(src[i], vc)
       end
       nothing
       end
julia> src=rand(Int8, 10_000); dst=collect(1:10_000);

julia> @btime mapfind!($dst, $src, $vc);
  18.146 μs (0 allocations: 0 bytes)

3.5 cycles per find.

PS. If you are on AMD instead of Intel, then it is probably better to use leading_zeros, and reorganize vec_codes.

PPS.

julia> src .=rand(1:16, length(src));

julia> @btime broadcast!($findcode_binary, $dst, $src);
  14.003 μs (0 allocations: 0 bytes)

This is somewhat amazing, but the SSE2 solution works always, for length(codes) up to 16, and does not need to know codes at compile time. It was cribbed from https://github.com/armon/libart/blob/master/src/art.c, where it is used to quickly traverse 16-ary tree nodes.

Impressive work @foobar_lv2!

For completeness, here’s a lookup-table version:

const CODES = (2, 3, 5, 7, 11)

lookuptable(v) = (a = zeros(Int, maximum(v)); for (i,x) in enumerate(v); a[x]=i; end; a)
const LUT = lookuptable(CODES)

function findcode_lookuptable(x)
    x < 1 || x > length(LUT) ? 0 : LUT[x]
end
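A slightly shorter way to write the bounds-checked lookup is `get` on the array, which returns a default for out-of-range indices. The sketch below uses its own names (`CODES_LUT`, `build_lut`, `LUT_SKETCH`, `findcode_get`) so it can be pasted next to the version above; I have not timed it, so treat it as a readability variant rather than a faster one.

```julia
const CODES_LUT = (2, 3, 5, 7, 11)

# Build a dense table: lut[code] = position of the code, 0 everywhere else.
function build_lut(codes)
    lut = zeros(Int, maximum(codes))
    for (i, c) in enumerate(codes)
        lut[c] = i
    end
    return lut
end

const LUT_SKETCH = build_lut(CODES_LUT)

# `get` falls back to 0 when x is outside 1:length(LUT_SKETCH).
findcode_get(x) = get(LUT_SKETCH, x, 0)
```

As with the version above, this only makes sense when the codes are small positive integers, since the table has `maximum(codes)` entries.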

Sample data:

src = rand(1:16, 10_000);
dst = similar(src);

Timings below for the four versions presented in this thread, without the @noinline, and with Julia started with -O3:

julia> @btime broadcast!($findcode_original, $dst, $src);
  50.333 μs (0 allocations: 0 bytes)

julia> @btime broadcast!($findcode_lookuptable, $dst, $src);
  39.822 μs (0 allocations: 0 bytes)
  
julia> @btime broadcast!($findcode_simd, $dst, $src);
  11.496 μs (0 allocations: 0 bytes)

julia> @btime broadcast!($findcode_binary, $dst, $src);
  3.754 μs (0 allocations: 0 bytes)

Last year, poring over the output of @code_native I discovered that my code of the form

    if i == 1
        dothis()
    elseif i == 2
        dotheother()
    elseif i == 3
        dothethirdthing()
    end

was in fact compiling to a jump-table. (This was causing a bug in my program because there was an error in LLVM address computation. That LLVM bug has been fixed, I believe.) Here is the old thread:

If you don’t care about values out of bounds, there’s this:

findcode_shift(x) = (x+1)>>1-x>>3
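In case the formula looks like magic: `>>` binds tighter than `-` in Julia, so it reads as `((x+1)>>1) - (x>>3)`. The first term maps 2, 3, 5, 7, 11 to 1, 2, 3, 4, 6, and the second term subtracts 1 only for 11. A small sketch that prints both terms (`shift_sketch` is just a fully parenthesized copy for illustration):

```julia
# Same arithmetic as findcode_shift, with explicit parentheses.
shift_sketch(x) = ((x + 1) >> 1) - (x >> 3)

for x in (2, 3, 5, 7, 11)
    println(x, ": ", (x + 1) >> 1, " - ", x >> 3, " = ", shift_sketch(x))
end

# Values that are not codes are NOT mapped to 0; for example 4 collides with 3.
shift_sketch(4) == shift_sketch(3)
```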

Testing it (compare with timings in my previous post):

julia> findcode_shift.(CODES)
(1, 2, 3, 4, 5)

julia> @btime broadcast!($findcode_shift, $dst, $src);
  2.313 μs (0 allocations: 0 bytes)

That’s 0.67 cycles per find! That time includes memory read/writes though, which are not free:

julia> @btime broadcast!(identity, $dst, $src);
  1.615 μs (0 allocations: 0 bytes)

So the overhead of findcode_shift compared to simply copying the elements is 0.2 cycles! (calculated as 2.9*(2313-1615)/10000 for a 2.9 GHz CPU and 10k elements)
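Spelled out, the cycle arithmetic is just clock frequency times time divided by element count. A tiny helper, assuming the 2.9 GHz clock and the 10k-element arrays mentioned above (`cycles_per_element` is a made-up name):

```julia
# Cycles per element = GHz (cycles per ns) * time in ns / number of elements.
cycles_per_element(time_ns, n; ghz = 2.9) = ghz * time_ns / n

total    = cycles_per_element(2313, 10_000)          # findcode_shift, about 0.67
overhead = cycles_per_element(2313 - 1615, 10_000)   # minus the plain copy, about 0.20
```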

For even better performance, we can use a packed version of the bit shifting solution to convert 8 codes at a time:

const ONE = 0x101010101010101
const SEVEN = 0x707070707070707

findcode_shift_packed(x) = ((x+ONE)>>1-x>>3&ONE)&SEVEN
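Why this works on eight bytes at once, as far as I can tell: adding ONE never carries between bytes (the largest byte is 11), the shifts do leak bits from the neighbouring byte but only into the high bits, `&ONE` keeps just the "is this 11" bit for the subtraction, and the final `&SEVEN` throws the leaked high bits away. A single-word check, with hypothetical names (`ONE_S`, `SEVEN_S`, `packed_sketch`) and fully parenthesized:

```julia
const ONE_S   = 0x0101010101010101   # 0x01 in every byte
const SEVEN_S = 0x0707070707070707   # 0x07 in every byte

packed_sketch(x::UInt64) = (((x + ONE_S) >> 1) - ((x >> 3) & ONE_S)) & SEVEN_S

codes8 = UInt8[2, 3, 5, 7, 11, 2, 3, 5]
word   = reinterpret(UInt64, codes8)[1]               # pack 8 codes into one word
out    = collect(reinterpret(UInt8, [packed_sketch(word)]))  # unpack the 8 results
```

Like findcode_shift, this assumes every byte holds a valid code.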

With the codes stored in an Int8 array, we can wrap it as an Int64 array to operate on it efficiently (same memory):

src8 = rand(Int8[2,3,5,7,11], 10_000);
src64 = unsafe_wrap(Array, convert(Ptr{Int64}, pointer(src8)), length(src8)>>3);
dst64 = similar(src64);

broadcast!(findcode_shift_packed, dst64, src64);

dst8 = unsafe_wrap(Array, convert(Ptr{Int8}, pointer(dst64)), length(dst64)<<3);
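If the raw pointers make you nervous, `reinterpret` gives the same 8-bytes-per-word view without `unsafe_wrap`. This is only a sketch with its own names (`packed_r`, `src8_r` and so on); broadcasting over a reinterpreted array may be slower than over a plain `Array`, so the timings in this post should not be assumed to carry over.

```julia
const ONE_R   = 0x0101010101010101
const SEVEN_R = 0x0707070707070707
packed_r(x) = (((x + ONE_R) >> 1) - ((x >> 3) & ONE_R)) & SEVEN_R

src8_r  = rand(Int8[2, 3, 5, 7, 11], 10_000)   # length must be a multiple of 8
src64_r = reinterpret(Int64, src8_r)           # same memory, viewed as 1250 words
dst64_r = Vector{Int64}(undef, length(src64_r))

broadcast!(packed_r, dst64_r, src64_r)

dst8_r = reinterpret(Int8, dst64_r)            # view the results byte by byte

# Cross-check every element against the straightforward search.
all(dst8_r[i] == findfirst(==(src8_r[i]), (2, 3, 5, 7, 11)) for i in eachindex(src8_r))
```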

Verification:

julia> for i=1:6; print("$(src8[i])->$(dst8[i])  "); end
3->2  5->3  7->4  2->1  3->2  11->5

Timings:

julia> @btime broadcast!($findcode_shift_packed, $dst64, $src64);
  207.870 ns (0 allocations: 0 bytes)
  
julia> @btime broadcast!(identity, $dst64, $src64);
  101.603 ns (0 allocations: 0 bytes)

code_native shows that this method uses 256-bit AVX2 instructions, operating on 32 codes in parallel. The timings above show that ~16.7 codes are handled per clock cycle. Correcting for the overhead of copying elements, ~32.0 codes are handled per clock cycle. That is 460 times faster than the original implementation!

Hey @bennedich , since you obviously share the hobby of micro-optimizations, I have a challenge for you (sorry for derailing the conversation):

findall of BitArray really needs some love. Logical indexing by BitArray will get a 5x speedup in https://github.com/JuliaLang/julia/pull/29746, to 3 cycles/selected index (on Intel Broadwell) when most indices are selected (we only pay a branch miss on every 64 bits of the BitArray).

We could use the same code in findall(B::BitVector), but I have no idea how to do fast findall(B::BitMatrix): We need to produce cartesian, not linear indices.

On the other hand, we can go crazy here: We own all of the context, not a single call to code we do not control. Nobody prevents us from e.g. first collecting linear indices and then batch-converting them in-place to cartesian by using a magic AVX bit-manipulation something (as long as we find a pure llvm idiom; declare ?? @llvm.x86.?? is obviously not admissible in Base, since julia supports more architectures than x86). Even allocation of temporaries can probably amortize. Integer division is unlikely to pay off, though.
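To pin down what "collect linear indices first, then batch-convert" means semantically, here is a deliberately naive two-pass sketch. The name `findall_cartesian_sketch` is made up, and the conversion goes through `CartesianIndices`, which relies on exactly the integer division that is unlikely to pay off, so this is a correctness reference and not the fast version being asked for.

```julia
# Pass 1: linear indices of the set bits. Pass 2: convert them to Cartesian indices.
function findall_cartesian_sketch(B::BitMatrix)
    lin = findall(vec(B))             # Vector{Int} of linear indices
    ci  = CartesianIndices(B)
    return [ci[i] for i in lin]       # same column-major order as findall(B)
end

B = [1 0 1; 0 1 1] .== 1              # small example BitMatrix
findall_cartesian_sketch(B) == findall(B)
```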

Challenge accepted!

I made an attempt here. Would appreciate if you’d like to take a look @foobar_lv2 and see if you can improve it further. There’s nothing crazy really, I basically just took your work and extended it to more dimensions.

Is anyone able to change the post that is marked as the solution? Unless I’ve misunderstood, my benchmark of the accepted suggestion is:

@btime broadcast!($findcode3, $dst, $src);
  99.561 μs (0 allocations: 0 bytes)

compared to bennedich’s suggestion:

@btime broadcast!($findcode_lookuptable, $dst, $src);
  73.449 μs (0 allocations: 0 bytes)

I’m not sure if the findcode_shift_packed meets the criteria since it appears that you’d need to add some conversion and checking overhead to make the results equivalent/comparable.

I picked the solution because it represents the sweet spot for me between simplicity and speed.

Hmm, working backwards I see that for Julia 1.0.2 your original code appears to be the fastest!?:

@btime broadcast!($findcode, $dst64, $src64);
  10.332 μs (0 allocations: 0 bytes)

You seem to be using the 64 bit packed arrays in that last benchmark, which are 8 times smaller than the original arrays, hence your code is ~8 times faster.

By the way, a lot of interesting discussion spawned from this question, including this topic which suggests that we should have been using larger vectors in these benchmarks. For 100k vectors, I get these timings:

julia> src = rand(1:16, 100_000);

julia> dst = similar(src);

julia> @btime broadcast!($findcode, $dst, $src);
  510.262 μs (0 allocations: 0 bytes)

julia> @btime broadcast!($findcode3, $dst, $src);
  637.381 μs (0 allocations: 0 bytes)

julia> @btime broadcast!($findcode_lookuptable, $dst, $src);
  398.854 μs (0 allocations: 0 bytes)

So indeed the original solution seems to be faster than the accepted solution (but not as fast as a look-up-table), and much more readable IMO.