Why isn't minmax branchless?

minmax for two numbers is defined by

minmax(x,y) = isless(y, x) ? (y, x) : (x, y)

My question is: why not make it branchless with

minmax(x, y) = ifelse(isless(x, y), (x, y), (y, x))
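A quick sketch to confirm that the proposed one-liner returns the same values as Base's minmax for integers. The names myminmax and agrees_with_base are illustrative, and this only checks results, not the generated machine code:

```julia
# The branch-free variant proposed above (name is illustrative).
myminmax(x, y) = ifelse(isless(x, y), (x, y), (y, x))

# Compare against Base.minmax on every pair drawn from `vals`.
function agrees_with_base(vals)
    for x in vals, y in vals
        myminmax(x, y) == minmax(x, y) || return false
    end
    return true
end

agrees_with_base(-3:3)   # true
myminmax(5, 2)           # (2, 5)
```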

Do you have a specific example of the difference?

Edit: I always trust experimental facts over my memory. :slight_smile:

We know this will prevent SIMD vectorization (even without an example), right? (I forget whether ifelse can currently be vectorized.)

I recently heard the compiler can remove simple branches. Maybe this is one of those cases?


The generated machine code is indeed different when the function is used on its own (not inlined).
With a branch (the version in Base):

julia> @code_native minmax(1,2)
	.text
; ┌ @ promotion.jl:423 within `minmax'
	movq	%rdi, %rax
; │┌ @ int.jl:83 within `<'
	cmpq	%rsi, %rdx
; │└
	jge	L16
	movq	%rdx, (%rax)
	movq	%rsi, 8(%rax)
	retq
L16:
	movq	%rsi, (%rax)
	movq	%rdx, 8(%rax)
	retq
	nopl	(%rax,%rax)
; └

Without branch:

julia> myminmax(x, y) = ifelse(isless(x, y), (x, y), (y, x))
myminmax (generic function with 1 method)

julia> @code_native myminmax(1,2)
	.text
; ┌ @ REPL[2]:1 within `myminmax'
	movq	%rdi, %rax
; │┌ @ operators.jl:357 within `isless'
; ││┌ @ int.jl:83 within `<'
	cmpq	%rdx, %rsi
; │└└
	movq	%rsi, %rcx
	cmovlq	%rdx, %rcx
	cmovlq	%rsi, %rdx
	movq	%rdx, (%rdi)
	movq	%rcx, 8(%rdi)
	retq
	nopl	(%rax)
; └

So the version without a branch has more move instructions.
But the two can compile to the same code once they are inlined into other functions:

julia> function swap(x::Ref)
           x[] = minmax(x[]...)
           nothing
       end
julia> @code_native swap(Ref((1,2)))
	.text
; ┌ @ REPL[8]:1 within `swap'
	movq	%rsi, -8(%rsp)
	movq	(%rsi), %rax
; │ @ REPL[8]:2 within `swap'
; │┌ @ refvalue.jl:56 within `getindex'
; ││┌ @ Base.jl:33 within `getproperty'
	movq	(%rax), %rcx
	movq	8(%rax), %rdx
; │└└
; │┌ @ promotion.jl:423 within `minmax'
; ││┌ @ int.jl:83 within `<'
	cmpq	%rcx, %rdx
; ││└
	movq	%rcx, %rsi
	cmovlq	%rdx, %rsi
	cmovlq	%rcx, %rdx
; │└
; │┌ @ refvalue.jl:57 within `setindex!'
; ││┌ @ Base.jl:34 within `setproperty!'
	movq	%rsi, (%rax)
	movq	%rdx, 8(%rax)
	movabsq	$jl_system_image_data, %rax
; │└└
; │ @ REPL[8]:3 within `swap'
	retq
; └

Some even more interesting observations here:

function minmax1(x,y) 
    l = isless(y, x) ? (y, x) : (x, y)
    return l
end

minmax1 is essentially equivalent to minmax; it just adds an explicit return statement. You get:

julia> @code_native minmax1(1,2)
	.text
; ┌ @ REPL[22]:1 within `minmax1'
	movq	%rdi, %rax
; │ @ REPL[22]:2 within `minmax1'
; │┌ @ operators.jl:357 within `isless'
; ││┌ @ int.jl:83 within `<'
	cmpq	%rsi, %rdx
; │└└
	movq	%rsi, %rcx
	cmovlq	%rdx, %rcx
	cmovlq	%rsi, %rdx
; │ @ REPL[22]:3 within `minmax1'
	movq	%rcx, (%rdi)
	movq	%rdx, 8(%rdi)
	retq
	nopl	(%rax)
; └

That is the same as the branchless version. What happens here is that minmax(x,y) = isless(y, x) ? (y, x) : (x, y) has two return statements (one for each branch), while minmax1 has only one. I guess it’s easier for the compiler to optimize the latter case.
In summary, minmax with or without a branch compiles to almost the same code. The compiler can replace the branch with a select (which is then lowered to a conditional move). I think it’s better to just let the compiler make the choice here…
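If you want to check for yourself whether the compiler turned the branch into a select, one option is to capture the optimized LLVM IR as a string and search it. This is a sketch: the exact IR, and whether a select or a conditional br shows up, depends on the Julia/LLVM version and the argument types, so it only reports what it finds rather than asserting a particular outcome.

```julia
using InteractiveUtils  # code_llvm; already loaded in the REPL

# The single-return variant from above.
function minmax1(x, y)
    l = isless(y, x) ? (y, x) : (x, y)
    return l
end

# Optimized LLVM IR for the (Int, Int) method, captured as a string.
ir = sprint(code_llvm, minmax1, Tuple{Int, Int})

# Rough textual checks; print or read `ir` for the full picture.
has_select = occursin("select", ir)
has_cond_branch = occursin("br i1", ir)
println("select: ", has_select, ", conditional branch: ", has_cond_branch)
```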


The compiler seems to be failing.
Also, you should define a branchless minmax like this:

 @inline minmax_nb(x,y) = (xlty = x < y; (ifelse(xlty, x, y), ifelse(xlty, y, x)))

(Optionally, add using IfElse: ifelse.)

This is much easier for the compiler to understand than

@inline minmax_nb2(x,y) = (xlty = x < y; ifelse(xlty, (x,y), (y,x)))

Here is the associated Julia issue.
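Both definitions compute the same values; they differ only in whether the select is applied to each scalar or to the whole tuple, and the latter is what the optimizer struggles with. A small sketch confirming that the two agree with each other and with Base.minmax on integers (it says nothing about speed):

```julia
# Two scalar selects vs. one select over a tuple (definitions from above).
@inline minmax_nb(x, y)  = (xlty = x < y; (ifelse(xlty, x, y), ifelse(xlty, y, x)))
@inline minmax_nb2(x, y) = (xlty = x < y; ifelse(xlty, (x, y), (y, x)))

vals = -5:5
# Chained comparison: both variants must equal each other and Base.minmax.
same = all(minmax_nb(x, y) == minmax_nb2(x, y) == minmax(x, y) for x in vals, y in vals)
```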

Testing how well each of the versions performs:

julia> function vminmax_map!(f::F, mi, ma, x, y) where {F}
           @inbounds for i ∈ eachindex(mi,ma,x,y)
               mi[i], ma[i] = f(x[i], y[i])
           end
       end
vminmax_map! (generic function with 1 method)

julia> x = rand(400); y = rand(400); mi = similar(x); ma = similar(x);

julia> x = rand(Int,400); y = rand(Int,400); mi = similar(x); ma = similar(x);

julia> @benchmark vminmax_map!(minmax, $mi, $ma, $x, $y)
BechmarkTools.Trial: 10000 samples with 994 evaluations.
 Range (min … max):  29.058 ns … 122.268 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     29.180 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   29.570 ns ±   1.331 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▅█▆       ▄▆▄ ▁                                         ▃▂   ▂
  ███▇▁▁▁▃▁▃██████▇▆▅▄▅▁▃▆██▆▆▆▅▆▅▅▅▄▁▄▄▄▄▄▅▄▅▅▁▃▃▁▁▄▄▄▄▁▇██▅▄ █
  29.1 ns       Histogram: log(frequency) by time      32.9 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark vminmax_map!(minmax_nb2, $mi, $ma, $x, $y)
BechmarkTools.Trial: 10000 samples with 585 evaluations.
 Range (min … max):  202.720 ns …  1.593 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     206.496 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   209.327 ns ± 35.243 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

     ▅▇▇█▆▅▃                                             ▁▁▁   ▂
  ▆▆█████████▇▅▆▅▄▄▅▅▅▅▅▄▅▆▆▆▆▅▄▅▅▅▅▅▆▅▅▆▅▄▅▂▅▄▅▄▄▂▄▃▄▄▅▆████▇ █
  203 ns        Histogram: log(frequency) by time       238 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @benchmark vminmax_map!(minmax_nb, $mi, $ma, $x, $y)
BechmarkTools.Trial: 10000 samples with 994 evaluations.
 Range (min … max):  29.075 ns … 60.479 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     29.767 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   29.849 ns ±  1.466 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▇     ▇█▂▁    ▂▁▂                      ▄▃                   ▂
  ██▇▁▁▃▁████▇▆▇▄████▇▇▆▆▅▅▅▆▅▅▅▅▃▅▆▆▄▄▄▆▆██▇▅▆▆▄▄▆▄▆▄▆▆▅▅▅▄▅ █
  29.1 ns      Histogram: log(frequency) by time      34.3 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

So the compiler fails to optimize unless we’re careful with our implementation: minmax_nb (two scalar ifelse calls) was about 7x faster than minmax_nb2 (a single ifelse on a tuple).
The branch version (Base minmax) is as fast as the version with two ifelses.


BTW, in the case of floating point numbers, I don’t think the comparison is fair.

julia> minmax_nb(0.0, NaN)
(NaN, 0.0)

julia> minmax_nb(-0.0, 0.0)
(0.0, -0.0)
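For comparison, Base's floating-point minmax propagates NaN to both outputs and orders -0.0 before 0.0. A small sketch of the difference; === is used for the zeros because -0.0 == 0.0 is true, so == cannot tell them apart:

```julia
# Plain `<`-based variant from above.
@inline minmax_nb(x, y) = (xlty = x < y; (ifelse(xlty, x, y), ifelse(xlty, y, x)))

# Base.minmax: NaN in, NaN in both slots; signed zeros are ordered.
base_nan  = minmax(0.0, NaN)    # (NaN, NaN)
base_zero = minmax(0.0, -0.0)   # (-0.0, 0.0)

# `<` is false for NaN and for -0.0 < 0.0, so the naive version just swaps.
nb_nan  = minmax_nb(0.0, NaN)   # (NaN, 0.0)
nb_zero = minmax_nb(-0.0, 0.0)  # (0.0, -0.0)

base_zero === (-0.0, 0.0)       # true
nb_zero   === (-0.0, 0.0)       # false
```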

Post an issue?

Oops, you’re right. I edited my above comment.
This would be the correct floating point definition:

 function minmax_nb(x::T, y::T) where {T<:AbstractFloat}
    isnanx = isnan(x)
    isnany = isnan(y)
    ygx = y > x
    sbxgsby = signbit(x) > signbit(y)
    l = ifelse(isnanx | isnany, ifelse(isnanx, x, y), ifelse(ygx | sbxgsby, x, y))
    u = ifelse(isnanx | isnany, ifelse(isnanx, x, y), ifelse(ygx | sbxgsby, y, x))
    l, u
end

But since the Base floating-point definition uses ifelse instead of branches, I’ll have to correct my post above. I should’ve checked @less minmax(1.0,2.0) to confirm it was actually using branches (it isn’t).

I’ll try integers, where LLVM should hopefully be able to convert branches to select…
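A sketch for checking the corrected floating-point definition against Base.minmax on the awkward values (signed zeros, infinities, NaN). isequal is used because it treats any NaN as equal to NaN and distinguishes -0.0 from 0.0, which is exactly what the comparison needs:

```julia
# Corrected floating-point definition from above.
function minmax_nb(x::T, y::T) where {T<:AbstractFloat}
    isnanx = isnan(x)
    isnany = isnan(y)
    ygx = y > x
    sbxgsby = signbit(x) > signbit(y)   # x has its sign bit set and y does not
    l = ifelse(isnanx | isnany, ifelse(isnanx, x, y), ifelse(ygx | sbxgsby, x, y))
    u = ifelse(isnanx | isnany, ifelse(isnanx, x, y), ifelse(ygx | sbxgsby, y, x))
    return l, u
end

vals = [-Inf, -1.5, -0.0, 0.0, 2.0, Inf, NaN]

# isequal: NaN matches NaN, and -0.0 is distinct from 0.0.
agrees = all(isequal(minmax_nb(x, y), minmax(x, y)) for x in vals, y in vals)
```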


So I guess the conclusion here is that a branch is better than ifelse? LLVM can already perform if-conversion, so using a branch should not be a problem (at least in simple cases). In contrast, Julia’s frontend can’t correctly optimize away uses of ifelse.
Actually, if we take the code shown by @less minmax(1.0,2.0) and replace every ifelse with a branch, the loop vectorizes:

#myminmax is the same as minmax, except that ifelse is replaced by a branch. 
@inline myifelse(b,x,y) = b ? x : y
myminmax(x::T, y::T) where {T<:AbstractFloat} =
    myifelse(isnan(x) | isnan(y), myifelse(isnan(x), (x,x), (y,y)),
           myifelse((y > x) | (signbit(x) > signbit(y)), (x,y), (y,x)))

We get:

vector.body:
    %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
    %24 = getelementptr inbounds double, double* %20, i64 %index
    %25 = bitcast double* %24 to <4 x double>*
    %wide.load = load <4 x double>, <4 x double>* %25, align 8
    %26 = getelementptr inbounds double, double* %21, i64 %index
    %27 = bitcast double* %26 to <4 x double>*
    %wide.load40 = load <4 x double>, <4 x double>* %27, align 8
    %28 = fcmp ord <4 x double> %wide.load, %wide.load40
    %29 = fcmp ord <4 x double> %wide.load, zeroinitializer
    %30 = select <4 x i1> %29, <4 x double> %wide.load40, <4 x double> %wide.load
    %31 = fcmp olt <4 x double> %wide.load, %wide.load40
    %32 = bitcast <4 x double> %wide.load to <4 x i64>
    %33 = bitcast <4 x double> %wide.load40 to <4 x i64>
    %34 = icmp sgt <4 x i64> %33, <i64 -1, i64 -1, i64 -1, i64 -1>
    %35 = icmp slt <4 x i64> %32, zeroinitializer
    %36 = and <4 x i1> %35, %34
    %37 = or <4 x i1> %31, %36
    %38 = select <4 x i1> %37, <4 x double> %wide.load, <4 x double> %wide.load40
    %39 = select <4 x i1> %37, <4 x double> %wide.load40, <4 x double> %wide.load
    %40 = select <4 x i1> %28, <4 x double> %38, <4 x double> %30
    %41 = select <4 x i1> %28, <4 x double> %39, <4 x double> %30
    %42 = getelementptr inbounds double, double* %22, i64 %index
    %43 = bitcast double* %42 to <4 x double>*
    store <4 x double> %40, <4 x double>* %43, align 8
    %44 = getelementptr inbounds double, double* %23, i64 %index
    %45 = bitcast double* %44 to <4 x double>*
    store <4 x double> %41, <4 x double>* %45, align 8
    %index.next = add i64 %index, 4
    %46 = icmp eq i64 %index.next, %n.vec
    br i1 %46, label %middle.block, label %vector.body

That is almost the same as what minmax_nb produces.
Maybe ifelse should be retired?


I checked the performance of your version with @Elrod’s benchmark. Indeed, the branching version is much faster. I made the numbers normally distributed and added in some NaN values as well to make the benchmark also test the signbit and isnan parts.

Here are the two versions

minmax_base(x::T, y::T) where {T<:AbstractFloat} =
    ifelse(isnan(x) | isnan(y), ifelse(isnan(x), (x,x), (y,y)),
           ifelse((y > x) | (signbit(x) > signbit(y)), (x,y), (y,x)))

function minmax_branch(x::T, y::T) where {T<:AbstractFloat}
    if isnan(x) | isnan(y)
        isnan(x) ? (x, x) : (y, y)
    else
        (y > x) | (signbit(x) > signbit(y)) ? (x, y) : (y, x)
    end
end

I also tested a short-circuiting version:

function minmax_short_circuit(x::T, y::T) where {T<:AbstractFloat}
    if isnan(x) || isnan(y)
        isnan(x) ? (x, x) : (y, y)
    else
        (y > x) || (signbit(x) > signbit(y)) ? (x, y) : (y, x)
    end
end

And here are the benchmarks

using BenchmarkTools

function vminmax_map!(f::F, mi, ma, x, y) where {F}
    @inbounds for i in eachindex(mi, ma, x, y)
        mi[i], ma[i] = f(x[i], y[i])
    end
end

let x = randn(400), y = randn(400), mi = similar(x), ma = similar(x)
    x[abs2.(x) .> 2] .= NaN
    y[abs2.(y) .> 2] .= NaN

    @show count(isnan, x)
    @show count(isnan, y)

    for f in [minmax_base, minmax_branch, minmax_short_circuit]
        println(f, "\n")
        b = @benchmark vminmax_map!($f, $mi, $ma, $x, $y)
        display(b)
        println("\n")
    end
end

with the following results on my machine

count(isnan, x) = 62
count(isnan, y) = 48
minmax_base

BechmarkTools.Trial: 10000 samples with 68 evaluations.
 Range (min … max):  846.029 ns …   3.366 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     874.662 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   912.077 ns ± 148.089 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▂▄█▃▁                                                        ▁
  ███████▇█▇█▇█▆▇█▆▅▆▅▆▆▆▆▅▅▆▅▅▆▆▅▆▆▅▅▅▅▆▅▅▆▅▅▅▅▅▄▃▄▄▅▄▃▅▃▃▄▂▅▅ █
  846 ns        Histogram: log(frequency) by time       1.63 ΞΌs <

 Memory estimate: 0 bytes, allocs estimate: 0.


minmax_branch

BechmarkTools.Trial: 10000 samples with 425 evaluations.
 Range (min … max):  230.165 ns … 657.605 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     237.214 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   241.901 ns ±  25.652 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▄  █▄▂▂▁▁                                                    ▁
  █▇▅██████▇▇▆█▇█▆▆▇▇▆▆▇▇▇▇▆▆▆▅▆▅▆▅▆▅▄▆▆▅▆▆▅▅▆▅▅▅▄▅▅▄▅▅▅▄▄▁▄▃▄▄ █
  230 ns        Histogram: log(frequency) by time        350 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.


minmax_short_circuit

BechmarkTools.Trial: 10000 samples with 352 evaluations.
 Range (min … max):  256.074 ns … 621.827 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     263.514 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   268.433 ns ±  24.659 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▆  █▂▁▁                                                      ▁
  █▇▄█████▇▇▇▆▅▇▆▅▆▆▆▆▆▇▆▆▇▆▇▅▅▆▆▆▅▆▆▅▆▅▅▅▅▄▄▃▆▆▆▅▅▅▃▅▅▄▄▃▄▅▃▃▄ █
  256 ns        Histogram: log(frequency) by time        393 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

I think your benchmark is skewed by the fact that you’re always using the same data. The branch predictor in your CPU is probably picking up the patterns in your data and messing with your result. I’d create a new array for each benchmark run using the setup keyword of @benchmark. It should still even out statistically, but it’ll make it much harder for the branch predictor to mess with things.


Thanks, though it doesn’t seem to make a real difference in this case, at least when implemented as follows.

function bench_setup(N, nan_cut)
    x = randn(N)
    y = randn(N)

    x[abs2.(x) .> nan_cut] .= NaN
    y[abs2.(y) .> nan_cut] .= NaN

    return x, y, similar(x), similar(y)
end

for f in [minmax_base, minmax_branch, minmax_short_circuit]
    println(f, "\n")
    b = @benchmark vminmax_map!($f, mi, ma, x, y) setup=begin
        x, y, mi, ma = bench_setup(400, 2)
    end

    display(b)
    println("\n")
end

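Correction! I should have set evals per sample to 1 (setup runs once per sample, so with several evaluations per sample the later evaluations reuse the same arrays) and increased the array length so I still get accurate timings. When I do that, I get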

minmax_base

BechmarkTools.Trial: 10000 samples with 1 evaluations.
 Range (min … max):  39.834 μs … 342.509 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     72.480 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   75.199 μs ±  12.760 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

   ▂                  ▃▆██▇▆▆▆▅▄▄▃▃▃▂▃▂▂▁▁▁▁▁                  ▃
  ██▇▆▄▄▅▃▄▄▅▄▄▅▅▃▁▄▅▇████████████████████████▇▇███▇▇▇▆▇▆▅▆▇▆▅ █
  39.8 ΞΌs       Histogram: log(frequency) by time       121 ΞΌs <

 Memory estimate: 0 bytes, allocs estimate: 0.


minmax_branch

BechmarkTools.Trial: 10000 samples with 1 evaluations.
 Range (min … max):   4.293 μs … 278.540 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     33.580 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   35.297 μs ±   9.275 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                            ▄█▇▂
  ▃▂▂▂▂▂▁▁▁▁▁▁▁▂▂▁▂▂▂▂▁▂▂▂▃▆████▇▅▅▅▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▃
  4.29 ΞΌs         Histogram: frequency by time         65.8 ΞΌs <

 Memory estimate: 0 bytes, allocs estimate: 0.


minmax_short_circuit

BechmarkTools.Trial: 10000 samples with 1 evaluations.
 Range (min … max):   4.202 μs … 269.420 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     33.655 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   35.367 μs ±   9.159 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

                            ▂██▄
  ▃▂▂▂▂▂▂▁▂▁▁▁▂▁▂▂▂▂▂▂▂▂▂▂▃▅█████▆▆▅▄▄▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂ ▃
  4.2 ΞΌs          Histogram: frequency by time         65.3 ΞΌs <

 Memory estimate: 0 bytes, allocs estimate: 0.

with

for f in [minmax_base, minmax_branch, minmax_short_circuit]
    println(f, "\n")
    b = @benchmark vminmax_map!($f, mi, ma, x, y) setup=begin
        x, y, mi, ma = bench_setup(5_000, 2)
    end evals=1

    display(b)
    println("\n")
end

TL;DR: branching is faster. Short-circuiting may or may not be slower in a real-world use case.
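The same "fresh data for every timed call" idea can be sketched with only the standard library, which may help if BenchmarkTools isn't at hand. fresh_inputs and time_fresh are hypothetical helper names mirroring bench_setup, and @elapsed is much coarser than BenchmarkTools, so treat the numbers as rough:

```julia
using Random

function vminmax_map!(f::F, mi, ma, x, y) where {F}
    @inbounds for i in eachindex(mi, ma, x, y)
        mi[i], ma[i] = f(x[i], y[i])
    end
end

# Hypothetical helper mirroring bench_setup: normal data with some NaNs.
function fresh_inputs(rng, N, nan_cut)
    x = randn(rng, N)
    y = randn(rng, N)
    x[abs2.(x) .> nan_cut] .= NaN
    y[abs2.(y) .> nan_cut] .= NaN
    return x, y, similar(x), similar(y)
end

# One timed call per sample, with newly generated inputs every sample.
function time_fresh(f, samples; N = 5_000, nan_cut = 2, rng = MersenneTwister(0))
    x, y, mi, ma = fresh_inputs(rng, 8, nan_cut)
    vminmax_map!(f, mi, ma, x, y)            # warm-up so compilation isn't timed
    ts = Vector{Float64}(undef, samples)
    for s in 1:samples
        x, y, mi, ma = fresh_inputs(rng, N, nan_cut)
        ts[s] = @elapsed vminmax_map!(f, mi, ma, x, y)
    end
    sort!(ts)
    return (min = ts[1], median = ts[cld(samples, 2)])   # seconds
end

time_fresh(minmax, 100)
```

Because the arrays change between samples, the branch predictor cannot memorize a fixed pattern, which is the effect setup=... evals=1 is meant to achieve.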


So, I got curious about this because I was writing a function to sort three numbers. Having now checked the native instructions, it turns out the choice of minmax implementation makes no real difference to the final compiled code for this example: all three versions compile to the same branchless sequence of cmp and cmov instructions.

Here are the three sorting functions

function sort3(a, b, c)
    l1, h1 = minmax(a, b)
    lo, h2 = minmax(l1, c)
    md, hi = minmax(h1, h2)
    return lo, md, hi
end

@inline minmax_nb1(x,y) = (xlty = x < y; (ifelse(xlty, x, y), ifelse(xlty, y, x)))

function sort3_nb1(a, b, c)
    l1, h1 = minmax_nb1(a, b)
    lo, h2 = minmax_nb1(l1, c)
    md, hi = minmax_nb1(h1, h2)
    return lo, md, hi
end

@inline minmax_nb2(x, y) = ifelse(isless(x, y), (x, y), (y, x))

function sort3_nb2(a, b, c)
    l1, h1 = minmax_nb2(a, b)
    lo, h2 = minmax_nb2(l1, c)
    md, hi = minmax_nb2(h1, h2)
    return lo, md, hi
end

Here are the native assembly instructions for each

julia> @code_native sort3(3, 2, 1)
        .section        __TEXT,__text,regular,pure_instructions
; ┌ @ sort3.jl:10 within `sort3'
        movq    %rdi, %rax
; │ @ sort3.jl:11 within `sort3'
; │┌ @ promotion.jl:423 within `minmax'
; ││┌ @ int.jl:83 within `<'
        cmpq    %rsi, %rdx
; ││└
        movq    %rsi, %rdi
        cmovlq  %rdx, %rdi
        cmovlq  %rsi, %rdx
; │└
; │ @ sort3.jl:12 within `sort3'
; │┌ @ promotion.jl:423 within `minmax'
; ││┌ @ int.jl:83 within `<'
        cmpq    %rcx, %rdi
; ││└
        movq    %rdi, %rsi
        cmovgq  %rcx, %rsi
        cmovleq %rcx, %rdi
; │└
; │ @ sort3.jl:13 within `sort3'
; │┌ @ promotion.jl:423 within `minmax'
; ││┌ @ int.jl:83 within `<'
        cmpq    %rdx, %rdi
; ││└
        movq    %rdx, %rcx
        cmovlq  %rdi, %rcx
        cmovlq  %rdx, %rdi
; │└
; │ @ sort3.jl:14 within `sort3'
        movq    %rsi, (%rax)
        movq    %rcx, 8(%rax)
        movq    %rdi, 16(%rax)
        retq
        nopl    (%rax)
; └

julia> @code_native sort3_nb1(3, 2, 1)
        .section        __TEXT,__text,regular,pure_instructions
; ┌ @ sort3.jl:19 within `sort3_nb1'
        movq    %rdi, %rax
; │ @ sort3.jl:20 within `sort3_nb1'
; │┌ @ sort3.jl:17 within `minmax_nb1'
; ││┌ @ int.jl:83 within `<'
        cmpq    %rdx, %rsi
; ││└
        movq    %rdx, %rdi
        cmovlq  %rsi, %rdi
        cmovlq  %rdx, %rsi
; │└
; │ @ sort3.jl:21 within `sort3_nb1'
; │┌ @ sort3.jl:17 within `minmax_nb1'
; ││┌ @ int.jl:83 within `<'
        cmpq    %rcx, %rdi
; ││└
        movq    %rcx, %rdx
        cmovlq  %rdi, %rdx
        cmovlq  %rcx, %rdi
; │└
; │ @ sort3.jl:22 within `sort3_nb1'
; │┌ @ sort3.jl:17 within `minmax_nb1'
; ││┌ @ int.jl:83 within `<'
        cmpq    %rdi, %rsi
; ││└
        movq    %rdi, %rcx
        cmovlq  %rsi, %rcx
        cmovgeq %rsi, %rdi
; │└
; │ @ sort3.jl:23 within `sort3_nb1'
        movq    %rdx, (%rax)
        movq    %rcx, 8(%rax)
        movq    %rdi, 16(%rax)
        retq
        nopl    (%rax)
; └

julia> @code_native sort3_nb2(3, 2, 1)
        .section        __TEXT,__text,regular,pure_instructions
; ┌ @ sort3.jl:28 within `sort3_nb2'
        movq    %rdi, %rax
; │ @ sort3.jl:29 within `sort3_nb2'
; │┌ @ sort3.jl:26 within `minmax_nb2'
; ││┌ @ operators.jl:357 within `isless'
; │││┌ @ int.jl:83 within `<'
        cmpq    %rdx, %rsi
; ││└└
        movq    %rsi, %rdi
        cmovlq  %rdx, %rdi
        cmovlq  %rsi, %rdx
; │└
; │ @ sort3.jl:30 within `sort3_nb2'
; │┌ @ sort3.jl:26 within `minmax_nb2'
; ││┌ @ operators.jl:357 within `isless'
; │││┌ @ int.jl:83 within `<'
        cmpq    %rcx, %rdx
; ││└└
        movq    %rcx, %rsi
        cmovlq  %rdx, %rsi
        cmovlq  %rcx, %rdx
; │└
; │ @ sort3.jl:31 within `sort3_nb2'
; │┌ @ sort3.jl:26 within `minmax_nb2'
; ││┌ @ operators.jl:357 within `isless'
; │││┌ @ int.jl:83 within `<'
        cmpq    %rdx, %rdi
; ││└└
        movq    %rdi, %rcx
        cmovlq  %rdx, %rcx
        cmovlq  %rdi, %rdx
; │└
; │ @ sort3.jl:32 within `sort3_nb2'
        movq    %rsi, (%rax)
        movq    %rdx, 8(%rax)
        movq    %rcx, 16(%rax)
        retq
        nopl    (%rax)
; └
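Whichever minmax is used, the three calls in sort3 form a 3-element sorting network. A small sketch that checks all three minmax variants against Base.sort on every triple from 1:3, duplicates included. sort3_with is a hypothetical helper that takes the minmax function as an argument instead of hard-coding it:

```julia
@inline minmax_nb1(x, y) = (xlty = x < y; (ifelse(xlty, x, y), ifelse(xlty, y, x)))
@inline minmax_nb2(x, y) = ifelse(isless(x, y), (x, y), (y, x))

# Same network as sort3 above, with the minmax implementation passed in.
function sort3_with(mm, a, b, c)
    l1, h1 = mm(a, b)      # order a and b
    lo, h2 = mm(l1, c)     # the overall minimum ends up in lo
    md, hi = mm(h1, h2)    # the remaining two give the middle and the maximum
    return lo, md, hi
end

# Exhaustive check over 1:3 (27 triples, duplicates included).
function check_sort3(mm)
    for a in 1:3, b in 1:3, c in 1:3
        sort3_with(mm, a, b, c) == Tuple(sort([a, b, c])) || return false
    end
    return true
end

all(check_sort3, (minmax, minmax_nb1, minmax_nb2))   # true
```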

There are plenty of use cases for an ifelse function, because you can dispatch on it.
Except you can’t add methods to Base.ifelse. IfElse.jl provides one you can extend, and it is used by packages like Symbolics.jl and LoopVectorization.jl.
So I’m not a particularly big fan of Base.ifelse in its current form.

@Luapulu @Sukera the branch version is faster not because branches are faster, but because it is easier for the compiler to understand than ifelse used the way the Base method uses it.
The compiler gets rid of the branches in optimizing it.
Note that

 function minmax_nb(x::T, y::T) where {T<:AbstractFloat}
    isnanx = isnan(x)
    isnany = isnan(y)
    ygx = y > x
    sbxgsby = signbit(x) > signbit(y)
    l = ifelse(isnanx | isnany, ifelse(isnanx, x, y), ifelse(ygx | sbxgsby, x, y))
    u = ifelse(isnanx | isnany, ifelse(isnanx, x, y), ifelse(ygx | sbxgsby, y, x))
    l, u
end

is just as fast as the branch versions, and much faster than base.
The compiler doesn’t really understand ifelse applied to tuples, like ifelse(b, (x,y), (y,x)), so the way the Base method is written causes the optimizer to fail.
The branch version is understood, and so is the above one without tuples in ifelse, so both of these are optimized in the loop, and neither of them actually have branches post-optimization (the branches get converted into ifelse statements like the above).

For reference, this is the loop body on my computer:

vector.body:                                      ; preds = %vector.body, %vector.ph
  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
  %108 = getelementptr inbounds double, double addrspace(13)* %98, i64 %index, !dbg !102
  %109 = bitcast double addrspace(13)* %108 to <8 x double> addrspace(13)*, !dbg !102
  %wide.load = load <8 x double>, <8 x double> addrspace(13)* %109, align 8, !dbg !102, !tbaa !103, !alias.scope !105
  %110 = getelementptr inbounds double, double addrspace(13)* %108, i64 8, !dbg !102
  %111 = bitcast double addrspace(13)* %110 to <8 x double> addrspace(13)*, !dbg !102
  %wide.load78 = load <8 x double>, <8 x double> addrspace(13)* %111, align 8, !dbg !102, !tbaa !103, !alias.scope !105
  %112 = getelementptr inbounds double, double addrspace(13)* %101, i64 %index, !dbg !102
  %113 = bitcast double addrspace(13)* %112 to <8 x double> addrspace(13)*, !dbg !102
  %wide.load79 = load <8 x double>, <8 x double> addrspace(13)* %113, align 8, !dbg !102, !tbaa !103, !alias.scope !108
  %114 = getelementptr inbounds double, double addrspace(13)* %112, i64 8, !dbg !102
  %115 = bitcast double addrspace(13)* %114 to <8 x double> addrspace(13)*, !dbg !102
  %wide.load80 = load <8 x double>, <8 x double> addrspace(13)* %115, align 8, !dbg !102, !tbaa !103, !alias.scope !108
  %116 = fcmp ord <8 x double> %wide.load, %wide.load79, !dbg !110
  %117 = fcmp ord <8 x double> %wide.load78, %wide.load80, !dbg !110
  %118 = fcmp ord <8 x double> %wide.load, zeroinitializer, !dbg !112
  %119 = fcmp ord <8 x double> %wide.load78, zeroinitializer, !dbg !112
  %120 = select <8 x i1> %118, <8 x double> %wide.load79, <8 x double> %wide.load, !dbg !90
  %121 = select <8 x i1> %119, <8 x double> %wide.load80, <8 x double> %wide.load78, !dbg !90
  %122 = fcmp olt <8 x double> %wide.load, %wide.load79, !dbg !118
  %123 = fcmp olt <8 x double> %wide.load78, %wide.load80, !dbg !118
  %124 = bitcast <8 x double> %wide.load to <8 x i64>, !dbg !122
  %125 = bitcast <8 x double> %wide.load78 to <8 x i64>, !dbg !122
  %126 = bitcast <8 x double> %wide.load79 to <8 x i64>, !dbg !122
  %127 = bitcast <8 x double> %wide.load80 to <8 x i64>, !dbg !122
  %128 = icmp sgt <8 x i64> %126, <i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1>, !dbg !125
  %129 = icmp sgt <8 x i64> %127, <i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1>, !dbg !125
  %130 = icmp slt <8 x i64> %124, zeroinitializer, !dbg !129
  %131 = icmp slt <8 x i64> %125, zeroinitializer, !dbg !129
  %132 = and <8 x i1> %130, %128, !dbg !129
  %133 = and <8 x i1> %131, %129, !dbg !129
  %134 = or <8 x i1> %122, %132, !dbg !130
  %135 = or <8 x i1> %123, %133, !dbg !130
  %136 = select <8 x i1> %134, <8 x double> %wide.load, <8 x double> %wide.load79, !dbg !90
  %137 = select <8 x i1> %135, <8 x double> %wide.load78, <8 x double> %wide.load80, !dbg !90
  %138 = select <8 x i1> %134, <8 x double> %wide.load79, <8 x double> %wide.load, !dbg !90
  %139 = select <8 x i1> %135, <8 x double> %wide.load80, <8 x double> %wide.load78, !dbg !90
  %predphi = select <8 x i1> %116, <8 x double> %136, <8 x double> %120
  %predphi81 = select <8 x i1> %117, <8 x double> %137, <8 x double> %121
  %predphi82 = select <8 x i1> %116, <8 x double> %138, <8 x double> %120
  %predphi83 = select <8 x i1> %117, <8 x double> %139, <8 x double> %121
  %140 = getelementptr inbounds double, double addrspace(13)* %104, i64 %index, !dbg !131
  %141 = bitcast double addrspace(13)* %140 to <8 x double> addrspace(13)*, !dbg !131
  store <8 x double> %predphi, <8 x double> addrspace(13)* %141, align 8, !dbg !131, !tbaa !103, !alias.scope !132, !noalias !134
  %142 = getelementptr inbounds double, double addrspace(13)* %140, i64 8, !dbg !131
  %143 = bitcast double addrspace(13)* %142 to <8 x double> addrspace(13)*, !dbg !131
  store <8 x double> %predphi81, <8 x double> addrspace(13)* %143, align 8, !dbg !131, !tbaa !103, !alias.scope !132, !noalias !134
  %144 = getelementptr inbounds double, double addrspace(13)* %107, i64 %index, !dbg !131
  %145 = bitcast double addrspace(13)* %144 to <8 x double> addrspace(13)*, !dbg !131
  store <8 x double> %predphi82, <8 x double> addrspace(13)* %145, align 8, !dbg !131, !tbaa !103, !alias.scope !136, !noalias !137
  %146 = getelementptr inbounds double, double addrspace(13)* %144, i64 8, !dbg !131
  %147 = bitcast double addrspace(13)* %146 to <8 x double> addrspace(13)*, !dbg !131
  store <8 x double> %predphi83, <8 x double> addrspace(13)* %147, align 8, !dbg !131, !tbaa !103, !alias.scope !136, !noalias !137
  %index.next = add i64 %index, 16
  %148 = icmp eq i64 %index.next, %n.vec
  br i1 %148, label %middle.block, label %vector.body, !llvm.loop !138

Note the lack of branches (it is one solid block of instructions), and all the select statements.
This was @anon56330260’s point: the compiler understands branches, and thus writing your code with branches is a good choice to let the compiler decide whether or not to actually use a branch.

If you’re careful to only use ifelse with primitive types like Float64 (instead of Tuple{Float64,Float64}), you’ll be fine. But then multiple ifelse calls are more verbose than a single ? : expression.


So, what should the future of Base.ifelse look like? Ideally, it would be possible to tell the compiler more explicitly that something ought to be branchless, rather than relying on it to optimise away branches, no? Or would the plan be to get rid of ifelse and rely on the compiler to optimise away branches if needed?

And btw, why can’t you add methods to Base.ifelse?

Because it is a builtin function.

julia> Base.ifelse(b::Bool, x::Tuple{Vararg{Number,K}}, y::Tuple{Vararg{Number,K}}) where {K} = map((x,y) -> ifelse(b, x, y), x, y)
ERROR: cannot add methods to a builtin function
Stacktrace:
 [1] top-level scope
   @ REPL[5]:1

julia> using IfElse

julia> IfElse.ifelse(b::Bool, x::Tuple{Vararg{Number,K}}, y::Tuple{Vararg{Number,K}}) where {K} = map((x,y) -> Base.ifelse(b, x, y), x, y)

The future I’d like to see is one where we (a) make it generic, and (b) add a method like the above to work around the performance issue discussed in this thread and in the issue I linked to earlier.
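Since Base.ifelse is a builtin, the elementwise-tuple workaround can be sketched today with your own generic function instead. select2 is a hypothetical name (not part of Base or of IfElse.jl's API); its tuple method issues one scalar ifelse per element, which is the form the thread found the optimizer handles well:

```julia
# A generic select that can be extended with methods (unlike the builtin Base.ifelse).
select2(c::Bool, x, y) = ifelse(c, x, y)

# Tuples of equal length: select element by element.
select2(c::Bool, x::NTuple{N,Any}, y::NTuple{N,Any}) where {N} =
    map((xi, yi) -> ifelse(c, xi, yi), x, y)

minmax_sel(x, y) = select2(x < y, (x, y), (y, x))

minmax_sel(3, 1)                     # (1, 3)
select2(false, (1, 2.0), (3, 4.0))   # (3, 4.0)
```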


Yes, that’s why we have SIMD libraries. You can hand-write SIMD code yourself, but that is definitely not a good idea in general, which is why the compiler has a vectorization pass to do this automatically. You might ask whether it’s possible to hint to the compiler that the code should be branchless. The problem is that there are many ways to be branchless, and you might pick a suboptimal one, just like the minmax example (it’s branchless, but not vectorized).

Instead, I think we should define Base.ifelse only on a constrained set of types (Int, Float64, NTuple{4,Int}, …). Base.ifelse acts like vifelse in the SIMD libraries: it provides no functionality beyond ? : and is used only for optimization. It’s not generic by design (it lowers to select, and not all values can be operands of select). Misusing it can hurt performance, since Julia’s frontend compiler isn’t aware of this function (and so does no special optimization on it), and lowering to select early might block further vectorization, as seen in the minmax example.

Looking at your benchmark results, I would conclude that branch and short circuit are the same, and one would need a different reason to prefer one over the other.
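The only source-level difference between minmax_branch and minmax_short_circuit is | versus ||. On Bool values, | evaluates both operands and combines them without any control flow, while || skips the right-hand side when the left is true, which introduces a branch the compiler may or may not remove. A small sketch that counts evaluations with a hypothetical instrumented predicate:

```julia
const calls = Ref(0)

# Hypothetical instrumented predicate: counts how often it runs.
counted_isnan(x) = (calls[] += 1; isnan(x))

calls[] = 0
counted_isnan(NaN) | counted_isnan(1.0)    # `|`: both sides evaluated
eager = calls[]                             # 2

calls[] = 0
counted_isnan(NaN) || counted_isnan(1.0)   # `||`: left is true, right side skipped
lazy = calls[]                              # 1
```

For cheap, side-effect-free operands like isnan, the compiler is free to evaluate both sides anyway, which would be consistent with the two versions benchmarking the same.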
