Compiling to branch table


#1

Is there a way to coax the compiler into compiling something like

const CODES = (2, 3, 5, 7, 11)

function findcode(x)
    index = findfirst(isequal(x), CODES)
    index ≡ nothing ? 0 : index
end

similarly to

function findcode2(x)
    if x == 2
        1
    elseif x == 3
        2
    elseif x == 5
        3
    elseif x == 7
        4
    elseif x == 11
        5
    else
        0
    end
end

other than direct code generation (e.g. a macro)?


#2

I don’t know of anything except a macro. What will your codes consist of? Is this performance critical? You could probably gain some performance by (a) a customized binary search instead of a series of ifs and (b) sorting your checks by likelihood of occurrence (e.g. if 7 is the most frequent value, check it first).
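For comparison, here is a minimal sketch of the macro route the question wants to avoid. The macro name `@codechain` and the function `findcode_macro` are hypothetical; the macro expands a literal list of codes into the same nested comparison chain as `findcode2`, so it only works when the codes are written out literally at the call site.

```julia
# Hypothetical macro: @codechain(x, c1, c2, ...) expands to
#   x == c1 ? 1 : x == c2 ? 2 : ... : 0
# The codes must be literals at the call site; x is evaluated only once.
macro codechain(x, codes...)
    xv = gensym("x")
    ex = :(0)                          # value when nothing matches
    for i in length(codes):-1:1        # build the chain from the inside out
        ex = :($xv == $(codes[i]) ? $i : $ex)
    end
    return :(let $xv = $(esc(x))
        $ex
    end)
end

findcode_macro(x) = @codechain(x, 2, 3, 5, 7, 11)
```

Whether LLVM then lowers the chain to a `switch` is the same question as for `findcode2`.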


#3

https://github.com/JuliaCollections/Memoize.jl?


#4

That’s unlikely (only showing the relevant part):

julia> @code_llvm findcode2(3)

; Function findcode2
; Location: REPL[15]:2
define i64 @julia_findcode2_1216615403(i64) {
top:
  switch i64 %0, label %L13 [
    i64 2, label %L3
    i64 3, label %L3.fold.split
    i64 5, label %L3.fold.split1
    i64 7, label %L3.fold.split2
  ]

#5

How large is codes (how many elements) and how large is x (number of bits)?

For example, if x is an Int8, you can compare it against a 16-element codes vector in a single low-latency instruction (x needs to be broadcast into a register first). This ends up much faster, and branch-free.

Does anyone know the correct idiom to get LLVM to emit the equivalent of _mm_movemask_epi8 without an llvmcall into @llvm.x86.avx.movmsk?
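A portable way to write the 16-way comparison described above in plain Julia, without `llvmcall`, is to build the 16-bit match mask with scalar code and leave vectorization to LLVM. This is a sketch under that assumption: the hypothetical `find16_portable` gives correct results on any platform, but whether it compiles to a broadcast, a `pcmpeqb` and a `pmovmskb` depends on the Julia/LLVM version, so check with `@code_native` before relying on it.

```julia
# Hypothetical portable sketch: bit i-1 of `mask` is set when codes[i] == x.
# The loop has a fixed trip count of 16, so LLVM can fully unroll it and may
# (or may not) turn it into a vector compare plus movemask.
function find16_portable(x::Int8, codes::NTuple{16, Int8})
    mask = 0x0000                        # UInt16 accumulator
    for i in 1:16
        mask |= UInt16(codes[i] == x) << (i - 1)
    end
    # trailing_zeros(0x0000) == 16, so an empty mask maps to 0 (not found)
    return ifelse(mask == 0x0000, 0, trailing_zeros(mask) + 1)
end

codes = ntuple(i -> Int8(i), 16)         # example codes 1..16
find16_portable(Int8(5), codes)          # returns 5
```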


#6

Here’s a simple recursive hack: (edit: cleaned up the code a bit)

using StaticNumbers

const CODES = (2, 3, 5, 7, 11)
const SRCODES = static.(reverse(CODES))

f(x) = 0
f(x, y, ys...) = x == y ? 1 + length(ys) : f(x, ys...)
findcode3(x) = f(x, SRCODES...)

@code_llvm findcode3(3)

Results in (only showing the switch statement):

define i64 @julia_findcode3_35284(i64) {
top:
  switch i64 %0, label %L13 [
    i64 11, label %L25
    i64 7, label %L25.fold.split
    i64 5, label %L25.fold.split5
    i64 3, label %L25.fold.split6
  ]

The function static turns the elements of CODES into compile-time constants, so the switch statement is compiled specifically for this list. Edit: let me rephrase that. They are already compile-time constants, but the compiler may or may not treat them as constant when compiling f. In fact, after I rewrote the code so that findcode3 passes the list (instead of doing it from the REPL), constant propagation kicks in and StaticNumbers is no longer needed.


#7

Constant propagation works amazingly well in Julia, as I found out when I tried calling f from within a function. There’s no need for StaticNumbers here. The following is enough:

const CODES = (2, 3, 5, 7, 11)
const RCODES = reverse(CODES)

f(x) = 0
f(x, y, ys...) = x == y ? 1 + length(ys) : f(x, ys...)
findcode3(x) = f(x, RCODES...)

@code_llvm findcode3(3)
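The same recursive unrolling also works without `reverse`, by carrying the current position as an extra argument. This is a hypothetical variant (`g` and `findcode4` are new names) that relies on the same constant propagation; inspect it with `@code_llvm` to confirm that it still becomes a `switch` on your Julia version.

```julia
const CODES = (2, 3, 5, 7, 11)

# g(x, i, y, ys...) returns the position of x in (y, ys...), counting y as
# position i, or 0 if x is absent.
g(x, i) = 0                                          # ran out of codes
g(x, i, y, ys...) = x == y ? i : g(x, i + 1, ys...)  # check head, recurse on tail
findcode4(x) = g(x, 1, CODES...)
```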

#8

I don’t think it is. What you have there is the LLVM IR. As far as I know, once that code is converted to assembly, the compiler is not going to turn it into a simple jump table but into a series of compares and branches. (I’m not even sure a jump table would be better, since it might interfere with branch prediction.) Hard-coding a binary search should then give you better performance. Example:

@noinline function findcode2(x)
    if x == 2
        1
    elseif x == 3
        2
    elseif x == 5
        3
    elseif x == 7
        4
    elseif x == 11
        5
    else
        0
    end
end

@noinline function findcode_binary(x)
    if x < 5
        x == 2 && return 1
        x == 3 && return 2
    else
        x == 5 && return 3
        x == 7 && return 4
        x == 11 && return 5
    end
    return 0
end

Timings:

julia> @btime (for k = 1:1000; for n = 1:8; findcode2(n); end; end)
  15.194 μs (0 allocations: 0 bytes)

julia> @btime (for k = 1:1000; for n = 1:8; findcode_binary(n); end; end)
  10.840 μs (0 allocations: 0 bytes)

Even for 5 values, this approach is more efficient. If the table size grows, I’d expect the difference to be bigger.
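If the table grows, hand-writing the nested ifs gets tedious. Below is a minimal sketch of a generic binary search over a sorted tuple of codes; `findcode_bsearch` is a hypothetical name, and unlike the hand-written version it is a loop, so whether LLVM unrolls it into straight-line compares for a constant tuple is not guaranteed.

```julia
# Hypothetical generic version: binary search over codes sorted in ascending order.
# Returns the one-based position of x in codes, or 0 if x is not present.
function findcode_bsearch(x, codes::NTuple{N, Int}) where {N}
    lo, hi = 1, N
    while lo <= hi
        mid = (lo + hi) >>> 1   # midpoint without overflow
        c = codes[mid]
        if c == x
            return mid
        elseif c < x
            lo = mid + 1        # x can only be in the upper half
        else
            hi = mid - 1        # x can only be in the lower half
        end
    end
    return 0
end

findcode_bsearch(7, (2, 3, 5, 7, 11))   # returns 4
```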

A lookup table would be another option. However, it requires a memory access, and it won’t be branch-free anyway if you need bounds checks (to return the default value).


#9

If you want to go the vector route and are happy with x86 specific code, here you go:

using SIMD

# Lane-wise equality of two 16 x Int8 vectors, packed into a bitmask
# (bit i-1 set iff lane i matches) via the x86 SSE2 pmovmskb intrinsic.
@inline function cmp_16xi8_i32(v1::Vec{16, Int8}, v2::Vec{16, Int8})
    Base.llvmcall(("declare i32 @llvm.x86.sse2.pmovmskb.128(<16 x i8>) nounwind readnone", "
        %cmpres = icmp eq <16 x i8> %0, %1
        %cmpres_s = sext <16 x i1> %cmpres to <16 x i8>
        %res = call i32 @llvm.x86.sse2.pmovmskb.128(<16 x i8> %cmpres_s)
        ret i32 %res"),
        UInt32, Tuple{NTuple{16, VecElement{Int8}}, NTuple{16, VecElement{Int8}}}, v1.elts, v2.elts)
end

# Broadcast a scalar Int8 into all 16 lanes
@inline expand_v(x::Int8) = Vec{16, Int8}((x,x,x,x, x,x,x,x, x,x,x,x, x,x,x,x))

function vec_codes(codes::Vector) # utility function, not for runtime use
    codes = convert(Vector{Int8}, codes)
    ell = length(codes)
    0 < ell <= 16 || error("does not fit into vector")
    # pad with codes[1]: a padding lane can only match when lane 1 matches too
    padding = ntuple(i -> codes[1], 16 - ell)
    return Vec{16, Int8}((ntuple(i -> codes[i], ell)..., padding...))
end

@inline function find_16(x, codes::Vec{16, Int8})
    xe = expand_v(x % Int8)                         # truncate x to Int8 and broadcast
    cmpbits = cmp_16xi8_i32(xe, codes)
    cmpres = 1 + trailing_zeros(cmpbits % UInt16)   # 17 if no lane matched
    return ifelse(cmpres == 17, 0, cmpres)
end

This finds the earliest occurrence of an Int8 in a 16 x Int8 vector (one-based), returning 0 if there is no occurrence. The function vec_codes is not fast; it just serves as a convenience constructor for the 16 x Int8 vector, including padding.
Let’s test:

julia> vc = vec_codes(collect(1:15));

julia> for i=1:17
       @show i,find_16(i, vc)
       end
(i, find_16(i, vc)) = (1, 1)
(i, find_16(i, vc)) = (2, 2)
(i, find_16(i, vc)) = (3, 3)
(i, find_16(i, vc)) = (4, 4)
(i, find_16(i, vc)) = (5, 5)
(i, find_16(i, vc)) = (6, 6)
(i, find_16(i, vc)) = (7, 7)
(i, find_16(i, vc)) = (8, 8)
(i, find_16(i, vc)) = (9, 9)
(i, find_16(i, vc)) = (10, 10)
(i, find_16(i, vc)) = (11, 11)
(i, find_16(i, vc)) = (12, 12)
(i, find_16(i, vc)) = (13, 13)
(i, find_16(i, vc)) = (14, 14)
(i, find_16(i, vc)) = (15, 15)
(i, find_16(i, vc)) = (16, 0)
(i, find_16(i, vc)) = (17, 0)

And speed:

julia> @btime find_16($x, $vc);
  2.668 ns (0 allocations: 0 bytes)
julia> @code_native find_16(17, vc)
	.text
; Function find_16 {
; Location: REPL[5]:2
	vmovd	%edi, %xmm0
	vpbroadcastb	%xmm0, %xmm0
	vpcmpeqb	(%rsi), %xmm0, %xmm0
	vpmovmskb	%xmm0, %eax
	tzcntw	%ax, %ax
	addl	$1, %eax
	movzwl	%ax, %ecx
	xorl	%eax, %eax
	cmpl	$17, %ecx
	cmovneq	%rcx, %rax
	retq
	nopw	%cs:(%rax,%rax)
;}

That is 5 cycles with no branch. When called in a loop:

julia> function mapfind!(dst, src, vc)
       @inbounds for i=1:length(dst)
       dst[i]=find_16(src[i], vc)
       end
       nothing
       end
julia> src=rand(Int8, 10_000); dst=collect(1:10_000);

julia> @btime mapfind!($dst, $src, $vc);
  18.146 μs (0 allocations: 0 bytes)

3.5 cycles per find.

PS. If you are on AMD instead of Intel, it is probably better to use leading_zeros and reorganize vec_codes.

PPS.

julia> src .=rand(1:16, length(src));

julia> @btime broadcast!($findcode_binary, $dst, $src);
  14.003 μs (0 allocations: 0 bytes)

This is somewhat amazing, but the SSE2 solution always works for length(codes) up to 16 and does not need to know codes at compile time. It was cribbed from https://github.com/armon/libart/blob/master/src/art.c, where it is used to quickly traverse 16-ary tree nodes.


#10

Impressive work @foobar_lv2!

For completeness, here’s a lookup-table version:

const CODES = (2, 3, 5, 7, 11)

function lookuptable(v)
    a = zeros(Int, maximum(v))   # a[x] = position of x in v, 0 if x is not a code
    for (i, x) in enumerate(v)
        a[x] = i
    end
    return a
end
const LUT = lookuptable(CODES)

function findcode_lookuptable(x)
    x < 1 || x > length(LUT) ? 0 : LUT[x]
end
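The explicit bounds check can also be written with Base’s `get(A, i, default)`, which returns `default` when `i` is not a valid index of the array `A`. A sketch that builds the same table with a comprehension; the name `findcode_get` is hypothetical, and it is not claimed to be faster, since it performs the same check.

```julia
const CODES = (2, 3, 5, 7, 11)

# LUT[k] = position of k in CODES, or 0 if k is not a code
const LUT = [something(findfirst(isequal(k), CODES), 0) for k in 1:maximum(CODES)]

# get returns the default (0) for any x outside 1:length(LUT)
findcode_get(x) = get(LUT, x, 0)
```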

Sample data:

src = rand(1:16, 10_000);
dst = similar(src);

Timings below for the four versions presented in this thread, without the @noinline, and with Julia started with -O3:

julia> @btime broadcast!($findcode_original, $dst, $src);
  50.333 μs (0 allocations: 0 bytes)

julia> @btime broadcast!($findcode_lookuptable, $dst, $src);
  39.822 μs (0 allocations: 0 bytes)
  
julia> @btime broadcast!($findcode_simd, $dst, $src);
  11.496 μs (0 allocations: 0 bytes)

julia> @btime broadcast!($findcode_binary, $dst, $src);
  3.754 μs (0 allocations: 0 bytes)

#11

Last year, while poring over the output of @code_native, I discovered that my code of the form

    if i == 1
        dothis()
    elseif i == 2
        dotheother()
    elseif i == 3
        dothethirdthing()
    end

was in fact compiling to a jump-table. (This was causing a bug in my program because there was an error in LLVM address computation. That LLVM bug has been fixed, I believe.) Here is the old thread:


#12

If you don’t care what happens for inputs that aren’t in CODES, there’s this:

findcode_shift(x) = (x+1)>>1 - x>>3

Testing it (compare with timings in my previous post):

julia> findcode_shift.(CODES)
(1, 2, 3, 4, 5)

julia> @btime broadcast!($findcode_shift, $dst, $src);
  2.313 μs (0 allocations: 0 bytes)

That’s 0.67 cycles per find! That time includes memory read/writes though, which are not free:

julia> @btime broadcast!(identity, $dst, $src);
  1.615 μs (0 allocations: 0 bytes)

So the overhead of findcode_shift compared with simply copying the elements is 0.2 cycles per element! (Calculated as 2.9*(2313-1615)/10000 for a 2.9 GHz CPU and 10k elements.)
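To see why the formula works, note that bit shifts bind more tightly than subtraction in Julia, so it reads `((x+1)>>1) - (x>>3)`. The first term maps 2, 3, 5, 7, 11 to 1, 2, 3, 4, 6, and the second is 1 only for 11 (the one code that is at least 8), pulling 6 down to 5. The loop below prints that breakdown; the extra input 4 shows that values outside CODES give arbitrary results.

```julia
findcode_shift(x) = (x+1)>>1 - x>>3   # same as ((x+1)>>1) - (x>>3)

for x in (2, 3, 5, 7, 11, 4)
    println(x, ": ", (x+1)>>1, " - ", x>>3, " = ", findcode_shift(x))
end
```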


#13

For even better performance, we can use a packed version of the bit-shifting solution to convert 8 codes (one per byte of a 64-bit word) at a time:

const ONE = 0x101010101010101
const SEVEN = 0x707070707070707

findcode_shift_packed(x) = ((x+ONE)>>1-x>>3&ONE)&SEVEN

With the codes stored in an Int8 array, we can wrap it as an Int64 array to operate on it efficiently (same memory):

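src8 = rand(Int8[2,3,5,7,11], 10_000);
# Reinterpret the same memory as Int64 (8 codes per element). The wrapped arrays do not
# keep src8/dst64 alive, so those must stay referenced while the wrappers are in use.
src64 = unsafe_wrap(Array, convert(Ptr{Int64}, pointer(src8)), length(src8)>>3);
dst64 = similar(src64);

broadcast!(findcode_shift_packed, dst64, src64);

dst8 = unsafe_wrap(Array, convert(Ptr{Int8}, pointer(dst64)), length(dst64)<<3);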

Verification:

julia> for i=1:6; print("$(src8[i])->$(dst8[i])  "); end
3->2  5->3  7->4  2->1  3->2  11->5

Timings:

julia> @btime broadcast!($findcode_shift_packed, $dst64, $src64);
  207.870 ns (0 allocations: 0 bytes)
  
julia> @btime broadcast!(identity, $dst64, $src64);
  101.603 ns (0 allocations: 0 bytes)

code_native shows that this method uses 256-bit AVX2 instructions, operating on 32 codes in parallel. The timings above show that ~16.7 codes are handled per clock cycle; correcting for the overhead of copying elements, ~32.0 codes are handled per clock cycle. That is 460 times faster than the original implementation!
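The same packed computation can also be written with `reinterpret`, which gives a `UInt64` view of the `Int8` data without any raw pointers. A sketch, assuming the array length is a multiple of 8; whether the reinterpreted views are as fast as the `unsafe_wrap` version is not established here and would need to be measured.

```julia
const ONE   = 0x0101010101010101   # 0x01 in every byte
const SEVEN = 0x0707070707070707   # 0x07 in every byte

# Per byte: ((b+1)>>1) - ((b>>3)&1); the final &SEVEN keeps the low 3 bits of
# each byte, which also clears the bit that >>1 shifts in from the next byte.
findcode_shift_packed(x) = ((x+ONE)>>1 - x>>3&ONE)&SEVEN

src8 = rand(Int8[2, 3, 5, 7, 11], 10_000)   # length must be a multiple of 8
dst8 = similar(src8)

# UInt64 views sharing memory with src8/dst8
src64 = reinterpret(UInt64, src8)
dst64 = reinterpret(UInt64, dst8)
broadcast!(findcode_shift_packed, dst64, src64)
```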


#14

Hey @bennedich, since you obviously share the hobby of micro-optimizations, I have a challenge for you (sorry for derailing the conversation):

findall of BitArray really needs some love. Logical indexing by BitArray will get a 5x speedup in https://github.com/JuliaLang/julia/pull/29746, to 3 cycles per selected index (on Intel Broadwell) when most indices are selected (we only pay a branch miss on every 64 bits of the BitArray).

We could use the same code in findall(B::BitVector), but I have no idea how to do a fast findall(B::BitMatrix): we need to produce Cartesian, not linear, indices.

On the other hand, we can go crazy here: we own all of the context, with not a single call to code we do not control. Nobody prevents us from, e.g., first collecting linear indices and then batch-converting them in place to Cartesian indices using some magic AVX bit-manipulation trick (as long as we find a pure LLVM idiom; declare ?? @llvm.x86.?? is obviously not admissible in Base, since Julia supports more architectures than x86). Even allocating temporaries can probably be amortized. Integer division is unlikely to pay off, though.
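One division-free way to do the batch conversion relies on findall on a BitVector returning linear indices in increasing order: walk the sorted indices and advance a column counter whenever an index passes the end of the current column. A plain-Julia baseline sketch with hypothetical helper names and no chunk-level tricks, not an optimized implementation:

```julia
# Hypothetical helper: convert sorted linear indices of an m-row matrix to
# CartesianIndex{2} without integer division, by tracking the current column.
function linear_to_cartesian_sorted(lin::AbstractVector{Int}, m::Int)
    out = Vector{CartesianIndex{2}}(undef, length(lin))
    col = 1
    colstart = 0                    # linear index just before column `col` begins
    for (k, l) in enumerate(lin)
        while l > colstart + m      # l lies beyond the current column
            colstart += m
            col += 1
        end
        out[k] = CartesianIndex(l - colstart, col)
    end
    return out
end

# findall on the BitVector view gives sorted linear indices
findall_cartesian(B::BitMatrix) = linear_to_cartesian_sorted(findall(vec(B)), size(B, 1))
```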


#15

Challenge accepted!

I made an attempt here. I’d appreciate it if you could take a look, @foobar_lv2, and see whether you can improve it further. There’s nothing crazy really; I basically just took your work and extended it to more dimensions.