How to shift bits faster

I thought I’d share something that I’ve run into before, and that I’ve seen a couple of times now on this forum (most recently in this topic).

If you want to shift the integer n to the left or right by k bits, the natural way to do it is simply n << k or n >> k. However, this is not the most efficient way, since Julia’s shift operator differs from the native shift instruction. On 64-bit CPUs such as x86-64, the native count of a 64-bit shift is masked with 63, so trying to shift left by 65 bits results in shifting left by 1 bit. In Julia, however, the actual count is used, so shifting left by 65 bits zeroes the result, and a negative count shifts in the opposite direction. This means that Julia’s shift operator doesn’t translate very well to native code:

julia> code_native((n,k) -> n << k, (Int,Int); syntax=:intel)
	xor    eax, eax
	cmp    rsi, 63
	shlx   rcx, rdi, rsi
	cmovbe rax, rcx
	mov    rcx, rsi
	neg    rcx
	cmp    rcx, 63
	jb     L29
	mov    cl, 63
L29:
	test   rsi, rsi
	sarx   rcx, rdi, rcx
	cmovs  rax, rcx
	ret
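
The difference in semantics is easy to confirm at the REPL. The following is a minimal sketch of the cases described above, using only Base; the extra compare and negate instructions in the listing correspond to the oversized and negative counts checked here.

```julia
# Julia uses the actual shift count, so counts >= 64 shift everything out.
@assert 1 << 65 == 0
# Masking the count with 63 mimics the native behavior: 65 & 63 == 1.
@assert 1 << (65 & 63) == 2
# A negative count shifts in the opposite direction.
@assert 8 << -2 == 2
# An oversized arithmetic right shift keeps only the sign.
@assert -8 >> 100 == -1
# An oversized logical right shift gives zero.
@assert -8 >>> 100 == 0
```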

Luckily, there’s an easy way to improve this: if you know the count is in the range 0 to 63, masking it to 6 bits (n << (k&63)) leaves the result unchanged and generates efficient native code:

julia> code_native((n,k) -> n << (k&63), (Int,Int); syntax=:intel)
	shlx	rax, rdi, rsi
	ret
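
The mask 63 is specific to 64-bit integers; for other widths the mask would be the bit width minus one. Below is a minimal sketch of generic helpers. The names FixedInt, masked_shl and masked_shr are made up for this example and are not part of Base, and whether the generated code is as short for every width has not been checked here.

```julia
# Hypothetical alias for the fixed-width integer types (same list as Base's gcd method).
const FixedInt = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128}

# Hypothetical helpers: shift with the count masked to the bit width of n.
# Only equivalent to << and >> when 0 <= k < 8 * sizeof(n).
masked_shl(n::FixedInt, k::Integer) = n << (k & (8 * sizeof(n) - 1))
masked_shr(n::FixedInt, k::Integer) = n >> (k & (8 * sizeof(n) - 1))
```

For counts outside the valid range these helpers wrap around like the hardware does, e.g. masked_shl(1, 65) gives 2 rather than 0, so they are only a drop-in replacement when the count is known to be in range.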

A (very artificial) benchmark:

julia> A = rand(0:63, 100_000);

julia> @btime (s = 0; @inbounds @simd for k = 1:length($A); s += k << $A[k]; end; s)
  42.788 μs (0 allocations: 0 bytes)
-2271119849451809947

julia> @btime (s = 0; @inbounds @simd for k = 1:length($A); s += k << ($A[k] & 63); end; s)
  15.195 μs (0 allocations: 0 bytes)
-2271119849451809947
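
For anyone who wants to repeat this, the two loops can also be wrapped in functions. This is only a sketch: the function names are made up, and the timings will depend on the CPU and Julia version.

```julia
# Hypothetical function versions of the two benchmark loops.
function sum_shifts(A::Vector{Int})
    s = 0
    @inbounds @simd for k in eachindex(A)
        s += k << A[k]
    end
    return s
end

function sum_shifts_masked(A::Vector{Int})
    s = 0
    @inbounds @simd for k in eachindex(A)
        s += k << (A[k] & 63)   # identical result as long as A[k] is in 0:63
    end
    return s
end

A = rand(0:63, 100_000)
@assert sum_shifts(A) == sum_shifts_masked(A)
# With BenchmarkTools loaded, compare:
#   @btime sum_shifts($A)
#   @btime sum_shifts_masked($A)
```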

This is unlikely to lead to any major performance improvements in your code (probably not even noticeable), but for anyone micro-optimizing and studying native code, this might be of interest.


Thankfully, the compiler seems good at figuring that sort of thing out in practice, e.g. your & 63 example, or when the shift is constant:

julia> leftshift18(x) = x << 18
leftshift18 (generic function with 1 method)

julia> rightshift18(x) = x >> 18
rightshift18 (generic function with 1 method)

julia> @code_native leftshift18(1)
	.text
; ┌ @ REPL[1]:1 within `leftshift18'
; │┌ @ int.jl:450 within `<<' @ REPL[1]:1
	shlq	$18, %rdi
; │└
	movq	%rdi, %rax
	retq
	nopl	(%rax,%rax)
; └

julia> @code_native rightshift18(1)
	.text
; ┌ @ REPL[2]:1 within `rightshift18'
; │┌ @ int.jl:448 within `>>' @ REPL[2]:1
	sarq	$18, %rdi
; │└
	movq	%rdi, %rax
	retq
	nopl	(%rax,%rax)
; └

Or in common usages like random number generators:

julia> using RandomNumbers

julia> @code_native RandomNumbers.PCG.pcg_output(one(UInt), RandomNumbers.PCG.PCG_XSH_RS)
	.text
; ┌ @ bases.jl:85 within `pcg_output'
; │┌ @ int.jl:448 within `>>' @ bases.jl:74
	movq	%rdi, %rax
	shrq	$22, %rax
; │└
; │┌ @ int.jl:321 within `xor'
	xorq	%rdi, %rax
; │└
; │ @ bases.jl:84 within `pcg_output'
; │┌ @ int.jl:448 within `>>' @ int.jl:442
	shrq	$61, %rdi
; │└
; │ @ bases.jl:86 within `pcg_output'
; │┌ @ int.jl:800 within `+' @ int.jl:53
	addl	$22, %edi
; │└
; │┌ @ int.jl:442 within `>>'
	shrxq	%rdi, %rax, %rax
; │└
	retq
	nopw	(%rax,%rax)
; └

The above function has quite a few shifts:

@inline function pcg_output(state::T, ::Type{PCG_XSH_RS}) where T <: Union{pcg_uints[2:end]...}
    return_bits = sizeof(T) << 2
    bits = return_bits << 1
    spare_bits = bits - return_bits
    op_bits = spare_bits - 5 >= 64 ? 5 :
              spare_bits - 4 >= 32 ? 4 :
              spare_bits - 3 >= 16 ? 3 :
              spare_bits - 2 >= 4  ? 2 :
              spare_bits - 1 >= 1  ? 1 : 0
    mask = (1 << op_bits) - 1
    xshift = op_bits + (return_bits + mask) >> 1
    rshift = op_bits != 0 ? (state >> (bits - op_bits)) & mask : 0 % T
    state = state ⊻ (state >> xshift)
    (state >> (spare_bits - op_bits - mask + rshift)) % half_width(T)
end
function foo(x::UInt64)
        x >> trailing_zeros(x)
end

julia> code_native(foo)
ERROR: no unique matching method found for the specified argument types
Stacktrace:
 [1] which(::Any, ::Any) at .\reflection.jl:926
 [2] _dump_function(::Any, ::Any, ::Bool, ::Bool, ::Bool, ::Bool, ::Symbol, ::Bool, ::Base.CodegenParams) at C:\cygwin\home\Administrator\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.0\InteractiveUtils\src\codeview.jl:64
 [3] _dump_function at C:\cygwin\home\Administrator\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.0\InteractiveUtils\src\codeview.jl:58 [inlined] (repeats 2 times)
 [4] #code_native#8 at C:\cygwin\home\Administrator\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.0\InteractiveUtils\src\codeview.jl:124 [inlined]
 [5] (::getfield(InteractiveUtils, Symbol("#kw##code_native")))(::NamedTuple{(:syntax,),Tuple{Symbol}}, ::typeof(code_native), ::Base.TTY, ::Function, ::Type) at .\none:0
 [6] #code_native#9(::Symbol, ::Function, ::Any, ::Any) at C:\cygwin\home\Administrator\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.0\InteractiveUtils\src\codeview.jl:126
 [7] code_native(::Any, ::Any) at C:\cygwin\home\Administrator\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.0\InteractiveUtils\src\codeview.jl:126 (repeats 2 times)
 [8] top-level scope at none:0

What am I doing wrong?

You omitted the argument type(s):

julia> code_native(foo, (UInt64,))
        .text
; Function foo {
; Location: REPL[10]:2
        pushq   %rbp
        movq    %rsp, %rbp
; Function trailing_zeros; {
; Location: int.jl:384
        tzcntq  %rcx, %rdx
;}
; Function >>; {
; Location: int.jl:448
; Function >>; {
; Location: int.jl:442
        shrxq   %rdx, %rcx, %rcx
        xorl    %eax, %eax
        cmpq    $63, %rdx
        cmovbeq %rcx, %rax
;}}
        popq    %rbp
        retq
        nopw    (%rax,%rax)
;}

Thanks @JeffreySarnoff. This looks better, thanks @bennedich:

function foo(x::UInt64)
        x >> (trailing_zeros(x)&63)
end

code_native(foo, (UInt64,); syntax=:intel)

       .text
; Function foo {
; Location: 5cardsranker.jl:157
        push    rbp
        mov     rbp, rsp
; Function trailing_zeros; {
; Location: int.jl:384
        tzcnt   rax, rcx
;}
; Function >>; {
; Location: int.jl:448
; Function >>; {
; Location: int.jl:442
        shrx    rax, rcx, rax
;}}
        pop     rbp
        ret
;}
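
One detail worth checking is x == 0, the only input where trailing_zeros returns 64 and the mask actually changes the count (64 & 63 == 0). Since shifting zero by any count gives zero, the two versions still agree. A minimal sketch, with the hypothetical names foo_plain and foo_masked:

```julia
foo_plain(x::UInt64) = x >> trailing_zeros(x)
foo_masked(x::UInt64) = x >> (trailing_zeros(x) & 63)

# Zero is the only UInt64 with 64 trailing zeros.
@assert trailing_zeros(UInt64(0)) == 64

# The masked and unmasked versions agree, including for zero.
for x in (UInt64(0), UInt64(1), UInt64(40), typemax(UInt64), UInt64(1) << 63)
    @assert foo_plain(x) == foo_masked(x)
end
```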

Yes, for constant shifts, the compiler seems to handle it well.

Keep in mind that most of that function (everything that only depends on the type) is evaluated at compile time. What you’re left with at runtime for a UInt type is this:

@inline function pcg_output_compiled(state::UInt)
    rshift = (state >> 61) & 7
    state = state ⊻ (state >> 22)
    (state >> (22 + rshift)) % UInt32
end

Note the & 7: it seems to help the compiler figure out that the shift count is within the valid range. Looking at some other methods in that file, e.g. pcg_rotr, the authors seem to have been well aware of this issue and done a good job optimizing the code.
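
To see where the constants 61, 7 and 22 come from, the type-only part of pcg_output can be evaluated on its own. This is a sketch that copies the arithmetic from the function quoted earlier in the thread; the function name and the field names of the returned tuple are made up for illustration.

```julia
# Recompute the type-only part of pcg_output (PCG_XSH_RS) for a given T.
function pcg_xsh_rs_constants(::Type{T}) where T <: Unsigned
    return_bits = sizeof(T) << 2
    bits = return_bits << 1
    spare_bits = bits - return_bits
    op_bits = spare_bits - 5 >= 64 ? 5 :
              spare_bits - 4 >= 32 ? 4 :
              spare_bits - 3 >= 16 ? 3 :
              spare_bits - 2 >= 4  ? 2 :
              spare_bits - 1 >= 1  ? 1 : 0
    mask = (1 << op_bits) - 1
    # In Julia, >> binds tighter than +, so this is op_bits + ((return_bits + mask) >> 1).
    xshift = op_bits + (return_bits + mask) >> 1
    return (rshift_count = bits - op_bits, mask = mask, xshift = xshift,
            base_shift = spare_bits - op_bits - mask)
end

pcg_xsh_rs_constants(UInt64)
# (rshift_count = 61, mask = 7, xshift = 22, base_shift = 22)
```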

However, for a lot of other code I’ve seen, that’s not the case. Grepping for the string >> in the base source and looking for functions that do non-masked shifting with a variable, I arbitrarily chose gcd in intfuncs.jl:

# binary GCD (aka Stein's) algorithm
# about 1.7x (2.1x) faster for random Int64s (Int128s)
function gcd(a::T, b::T) where T<:Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128}
    @noinline throw1(a, b) = throw(OverflowError("gcd($a, $b) overflows"))
    a == 0 && return abs(b)
    b == 0 && return abs(a)
    za = trailing_zeros(a)
    zb = trailing_zeros(b)
    k = min(za, zb)
    u = unsigned(abs(a >> za))
    v = unsigned(abs(b >> zb))
    while u != v
        if u > v
            u, v = v, u
        end
        v -= u
        v >>= trailing_zeros(v)
    end
    r = u << k
    # T(r) would throw InexactError; we want OverflowError instead
    r > typemax(T) && throw1(a, b)
    r % T
end

Four shifts, whose counts are all guaranteed to be less than the bit width of the type, but none of them masked. Inspecting with code_native, there indeed seems to be a lot of unnecessary range checking. So what happens if we simply mask each shift?

function gcd_masked(a::T, b::T) where T<:Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Int128,UInt128}
    @noinline throw1(a, b) = throw(OverflowError("gcd($a, $b) overflows"))
    a == 0 && return abs(b)
    b == 0 && return abs(a)
    za = trailing_zeros(a)
    zb = trailing_zeros(b)
    k = min(za, zb)
    bits = sizeof(T) << 3
    u = unsigned(abs(a >> (za % bits)))
    v = unsigned(abs(b >> (zb % bits)))
    while u != v
        if u > v
            u, v = v, u
        end
        v -= u
        v >>= (trailing_zeros(v) % bits)
    end
    r = u << (k % bits)
    # T(r) would throw InexactError; we want OverflowError instead
    r > typemax(T) && throw1(a, b)
    r % T
end
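
Before benchmarking, it is worth confirming that masking does not change the results. The sketch below is a hypothetical UInt64-only reduction of the function above, not the code that was benchmarked: it drops the abs and overflow handling (not needed for unsigned inputs) and uses & 63 in place of % bits. It is compared against Base's gcd on random inputs.

```julia
# Hypothetical UInt64-only variant of the masked binary GCD.
function gcd_u64_masked(a::UInt64, b::UInt64)
    a == 0 && return b
    b == 0 && return a
    za = trailing_zeros(a)          # a != 0, so za <= 63
    zb = trailing_zeros(b)          # b != 0, so zb <= 63
    k = min(za, zb)
    u = a >> (za & 63)
    v = b >> (zb & 63)
    while u != v
        if u > v
            u, v = v, u
        end
        v -= u                      # u != v, so v stays nonzero
        v >>= (trailing_zeros(v) & 63)
    end
    return u << (k & 63)
end

# Compare against Base.gcd on random pairs.
for _ in 1:10_000
    a, b = rand(UInt64), rand(UInt64)
    @assert gcd_u64_masked(a, b) == gcd(a, b)
end
```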

Let’s compare the performance with some random ints:

julia> A = rand(Int, 100_000);

julia> @btime (s = 0; @inbounds for n = 1:length($A)-1 s += gcd($A[n], $A[n+1]) end; s)
  19.087 ms (0 allocations: 0 bytes)
1946053

julia> @btime (s = 0; @inbounds for n = 1:length($A)-1 s += gcd_masked($A[n], $A[n+1]) end; s)
  9.785 ms (0 allocations: 0 bytes)
1946053

From ~191 ns per call to ~98 ns per call, or a 1.95x improvement – that’s a larger improvement than that of using Stein’s binary algorithm over simple repeated division. So, unless I’m missing something in this benchmark, there seems to be some room for improvement even in base code.

Note: The above benchmark and improvement were observed on a 2.9 GHz Skylake CPU. Rerunning the same benchmark on a 2.6 GHz Broadwell CPU, both versions ran in about 28 ms, with almost no improvement for the masked one. I haven’t looked into why.


Wow. This is quite a performance pitfall I was not aware of, thanks!

Can we put that into the docstring of the shift operators? Especially x>>trailing_zeros(x) is an important idiom.

Also, WTF llvm and C. If I read the langref right then a rightshift lshr by >=64 bit is undef; hence, it would be entirely legal to remove the branch that returns 42 in this example, because in clang’s interpretation of C, x>>trailing_zeros(x) of an unsigned int is either nonzero or undefined (llvm poison value). I guess somebody should spec that shifting a zero by any number of bits is always zero before somebody teaches llvm that 0 is the only i64 with 64 trailing zeros.

FYI: I opened a GitHub issue.


Could this specific case (gcd) be handled by LLVM with another optimization pass?

Looking at the output of @code_llvm gcd(1,1), I notice these two lines:

   %28 = call i64 @llvm.cttz.i64(i64 %0, i1 false)
   ...
   %30 = icmp ult i64 %28, 63

Here %28 (i.e. za) could be 64, if %0 (i.e. a) is zero, so LLVM thinks that the icmp ult is necessary. But we know that a is nonzero after a == 0 && return abs(b), so %28 must be 63 or less.
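
The range claim itself can be sanity-checked from Julia. This is only a sketch of the property the optimizer would need to prove; it says nothing about what LLVM currently does.

```julia
for T in (Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128, UInt128)
    bits = 8 * sizeof(T)
    # Only zero has as many trailing zeros as the type has bits.
    @assert trailing_zeros(zero(T)) == bits
    # The nonzero value with the most trailing zeros has only the top bit set.
    @assert trailing_zeros(one(T) << (bits - 1)) == bits - 1
end
```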

I don’t know enough about writing LLVM optimization passes to say whether it is feasible, but if it is, then it seems like the right place to put the fix.