minmax for two numbers is defined by
minmax(x, y) = isless(y, x) ? (y, x) : (x, y)
My question is: why not make it branchless with
minmax(x, y) = ifelse(isless(x, y), (x, y), (y, x))
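For context: Base.ifelse evaluates both of its arguments eagerly and then selects one, so there is no branch at the source level (unlike cond ? a : b). A quick illustration:
julia> ifelse(isless(2, 1), (2, 1), (1, 2))
(1, 2)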
Do you have a specific example of the difference?
Edit: I always trust experimental facts over my memory.
We know this will stop SIMD from happening (even without an example), right? (I forget whether ifelse can currently be SIMDed or not.)
I recently heard the compiler can remove simple branches. Maybe this is one of those cases?
The generated machine code is indeed different if the function is used on its own.
With a branch (the version in Base):
@code_native minmax(1,2)
.text
; ┌ @ promotion.jl:423 within `minmax'
movq %rdi, %rax
; │┌ @ int.jl:83 within `<'
cmpq %rsi, %rdx
; │└
jge L16
movq %rdx, (%rax)
movq %rsi, 8(%rax)
retq
L16:
movq %rsi, (%rax)
movq %rdx, 8(%rax)
retq
nopl (%rax,%rax)
; └
Without a branch:
julia> myminmax(x, y) = ifelse(isless(x, y), (x, y), (y, x))
myminmax (generic function with 1 method)
julia> @code_native myminmax(1,2)
.text
; ┌ @ REPL[2]:1 within `myminmax'
movq %rdi, %rax
; │┌ @ operators.jl:357 within `isless'
; ││┌ @ int.jl:83 within `<'
cmpq %rdx, %rsi
; │└└
movq %rsi, %rcx
cmovlq %rdx, %rcx
cmovlq %rsi, %rdx
movq %rdx, (%rdi)
movq %rcx, 8(%rdi)
retq
nopl (%rax)
; └
So the version without a branch has more move instructions.
But they can be the same if they get inlined into other functions.
julia> function swap(x::Ref)
           x[] = minmax(x[]...)
           nothing
       end
julia> @code_native swap(Ref((1,2)))
.text
; ┌ @ REPL[8]:1 within `swap'
movq %rsi, -8(%rsp)
movq (%rsi), %rax
; │ @ REPL[8]:2 within `swap'
; │┌ @ refvalue.jl:56 within `getindex'
; ││┌ @ Base.jl:33 within `getproperty'
movq (%rax), %rcx
movq 8(%rax), %rdx
; │└└
; │┌ @ promotion.jl:423 within `minmax'
; ││┌ @ int.jl:83 within `<'
cmpq %rcx, %rdx
; ││└
movq %rcx, %rsi
cmovlq %rdx, %rsi
cmovlq %rcx, %rdx
; │└
; │┌ @ refvalue.jl:57 within `setindex!'
; ││┌ @ Base.jl:34 within `setproperty!'
movq %rsi, (%rax)
movq %rdx, 8(%rax)
movabsq $jl_system_image_data, %rax
; │└└
; │ @ REPL[8]:3 within `swap'
retq
; └
Some even more interesting observations here:
function minmax1(x, y)
    l = isless(y, x) ? (y, x) : (x, y)
    return l
end
minmax1 is essentially equivalent to minmax; it just adds an explicit return statement. You get:
julia> @code_native minmax1(1,2)
.text
; ┌ @ REPL[22]:1 within `minmax1'
movq %rdi, %rax
; │ @ REPL[22]:2 within `minmax1'
; │┌ @ operators.jl:357 within `isless'
; ││┌ @ int.jl:83 within `<'
cmpq %rsi, %rdx
; │└└
movq %rsi, %rcx
cmovlq %rdx, %rcx
cmovlq %rsi, %rdx
; │ @ REPL[22]:3 within `minmax1'
movq %rcx, (%rdi)
movq %rdx, 8(%rdi)
retq
nopl (%rax)
; └
This is the same as the branchless version. What happens here is that minmax(x,y) = isless(y, x) ? (y, x) : (x, y) has two return statements (one for each branch) and minmax1 has only one. I guess it's easier for the compiler to optimize the latter case.
In summary, minmax with or without a branch is almost the same. The compiler can replace the branch with a select (which then lowers to a conditional move). I think it's better to just let the compiler make the choice here…
The compiler seems to be failing.
Also, you should define a branchless minmax
like this:
@inline minmax_nb(x,y) = (xlty = x < y; (ifelse(xlty, x, y), ifelse(xlty, y, x)))
(optionally, add a using IfElse: ifelse
).
This is much easier for the compiler to understand than
@inline minmax_nb2(x,y) = (xlty = x < y; ifelse(xlty, (x,y), (y,x)))
Here is the associated Julia issue.
Testing how well each of the versions performs:
julia> function vminmax_map!(f::F, mi, ma, x, y) where {F}
           @inbounds for i ∈ eachindex(mi, ma, x, y)
               mi[i], ma[i] = f(x[i], y[i])
           end
       end
vminmax_map! (generic function with 1 method)
julia> x = rand(400); y = rand(400); mi = similar(x); ma = similar(x);
julia> x = rand(Int,400); y = rand(Int,400); mi = similar(x); ma = similar(x);
julia> @benchmark vminmax_map!(minmax, $mi, $ma, $x, $y)
BenchmarkTools.Trial: 10000 samples with 994 evaluations.
 Range (min … max):  29.058 ns … 122.268 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     29.180 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   29.570 ns ±   1.331 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

 [histogram omitted]
  29.1 ns         Histogram: log(frequency) by time         32.9 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.
julia> @benchmark vminmax_map!(minmax_nb2, $mi, $ma, $x, $y)
BenchmarkTools.Trial: 10000 samples with 585 evaluations.
 Range (min … max):  202.720 ns … 1.593 μs   ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     206.496 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   209.327 ns ± 35.243 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

 [histogram omitted]
  203 ns          Histogram: log(frequency) by time          238 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.
julia> @benchmark vminmax_map!(minmax_nb, $mi, $ma, $x, $y)
BenchmarkTools.Trial: 10000 samples with 994 evaluations.
 Range (min … max):  29.075 ns … 60.479 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     29.767 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   29.849 ns ±  1.466 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

 [histogram omitted]
  29.1 ns         Histogram: log(frequency) by time         34.3 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.
So the compiler fails to optimize unless we're careful with our implementation. minmax_nb (with two ifelse statements) was 5x faster than the single-ifelse version.
The branch version is just as fast as the two-ifelse version.
BTW, in the case of floating point numbers, I don't think the comparison is fair.
julia> minmax_nb(0.0, NaN)
(NaN, 0.0)
julia> minmax_nb(-0.0, 0.0)
(0.0, -0.0)
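The underlying reason, for reference: with the `<`-based definition, every ordered comparison involving NaN is false, and IEEE `<` considers -0.0 and 0.0 equal, so the pair is never swapped in either case:
julia> 0.0 < NaN
false

julia> -0.0 < 0.0
false

julia> signbit(-0.0), signbit(0.0)   # signbit is what distinguishes the zeros
(true, false)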
Post an issue?
Oops, you're right. I edited my above comment.
This would be the correct floating point definition:
function minmax_nb(x::T, y::T) where {T<:AbstractFloat}
    isnanx = isnan(x)
    isnany = isnan(y)
    ygx = y > x
    sbxgsby = signbit(x) > signbit(y)
    l = ifelse(isnanx | isnany, ifelse(isnanx, x, y), ifelse(ygx | sbxgsby, x, y))
    u = ifelse(isnanx | isnany, ifelse(isnanx, x, y), ifelse(ygx | sbxgsby, y, x))
    l, u
end
but as the Base floating-point definition uses ifelse instead of branches, I'll have to correct my above post. I should've checked @less minmax(1.0,2.0) to confirm it's actually using branches (it isn't).
I'll try integers, where LLVM should hopefully be able to convert branches to select…
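For completeness, a quick check of the corrected definition above on the two problematic inputs (NaN propagates to both slots, and the zeros are ordered by sign bit):
julia> minmax_nb(0.0, NaN)
(NaN, NaN)

julia> minmax_nb(-0.0, 0.0)
(-0.0, 0.0)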
So I guess the conclusion here is that a branch is better than ifelse? LLVM can already perform if-conversion, so using a branch should not be problematic (at least in simple cases). In contrast, Julia's frontend can't correctly optimize away uses of ifelse.
Actually, if we take the code shown by @less minmax(1.0,2.0) and replace every ifelse with a branch, we get vectorized code:
# myminmax is the same as minmax, except that ifelse is replaced by a branch.
@inline myifelse(b, x, y) = b ? x : y
myminmax(x::T, y::T) where {T<:AbstractFloat} =
    myifelse(isnan(x) | isnan(y), myifelse(isnan(x), (x,x), (y,y)),
             myifelse((y > x) | (signbit(x) > signbit(y)), (x,y), (y,x)))
We get:
vector.body:
%index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
%24 = getelementptr inbounds double, double* %20, i64 %index
%25 = bitcast double* %24 to <4 x double>*
%wide.load = load <4 x double>, <4 x double>* %25, align 8
%26 = getelementptr inbounds double, double* %21, i64 %index
%27 = bitcast double* %26 to <4 x double>*
%wide.load40 = load <4 x double>, <4 x double>* %27, align 8
%28 = fcmp ord <4 x double> %wide.load, %wide.load40
%29 = fcmp ord <4 x double> %wide.load, zeroinitializer
%30 = select <4 x i1> %29, <4 x double> %wide.load40, <4 x double> %wide.load
%31 = fcmp olt <4 x double> %wide.load, %wide.load40
%32 = bitcast <4 x double> %wide.load to <4 x i64>
%33 = bitcast <4 x double> %wide.load40 to <4 x i64>
%34 = icmp sgt <4 x i64> %33, <i64 -1, i64 -1, i64 -1, i64 -1>
%35 = icmp slt <4 x i64> %32, zeroinitializer
%36 = and <4 x i1> %35, %34
%37 = or <4 x i1> %31, %36
%38 = select <4 x i1> %37, <4 x double> %wide.load, <4 x double> %wide.load40
%39 = select <4 x i1> %37, <4 x double> %wide.load40, <4 x double> %wide.load
%40 = select <4 x i1> %28, <4 x double> %38, <4 x double> %30
%41 = select <4 x i1> %28, <4 x double> %39, <4 x double> %30
%42 = getelementptr inbounds double, double* %22, i64 %index
%43 = bitcast double* %42 to <4 x double>*
store <4 x double> %40, <4 x double>* %43, align 8
%44 = getelementptr inbounds double, double* %23, i64 %index
%45 = bitcast double* %44 to <4 x double>*
store <4 x double> %41, <4 x double>* %45, align 8
%index.next = add i64 %index, 4
%46 = icmp eq i64 %index.next, %n.vec
br i1 %46, label %middle.block, label %vector.body
This is almost the same as the code generated for minmax_nb.
Maybe ifelse
should retire?
I checked the performance of your version with @Elrod's benchmark. Indeed, the branching version is much faster. I made the numbers normally distributed and added in some NaN
values as well to make the benchmark also test the signbit
and isnan
parts.
Here are the two versions
minmax_base(x::T, y::T) where {T<:AbstractFloat} =
    ifelse(isnan(x) | isnan(y), ifelse(isnan(x), (x,x), (y,y)),
           ifelse((y > x) | (signbit(x) > signbit(y)), (x,y), (y,x)))
function minmax_branch(x::T, y::T) where {T<:AbstractFloat}
    if isnan(x) | isnan(y)
        isnan(x) ? (x, x) : (y, y)
    else
        (y > x) | (signbit(x) > signbit(y)) ? (x, y) : (y, x)
    end
end
I also tested a short circuiting version
function minmax_short_circuit(x::T, y::T) where {T<:AbstractFloat}
    if isnan(x) || isnan(y)
        isnan(x) ? (x, x) : (y, y)
    else
        (y > x) || (signbit(x) > signbit(y)) ? (x, y) : (y, x)
    end
end
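For context, the only difference the short-circuiting version introduces is that || skips its right operand (i.e. it adds a branch), whereas | always evaluates both sides. A minimal illustration (the helper f below is just for demonstration):
julia> f(x) = (println("evaluated"); x)
f (generic function with 1 method)

julia> true | f(false)   # non-short-circuiting: both operands evaluated
evaluated
true

julia> true || f(false)  # short-circuiting: right operand skipped
true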
And here are the benchmarks
function vminmax_map!(f::F, mi, ma, x, y) where {F}
    @inbounds for i in eachindex(mi, ma, x, y)
        mi[i], ma[i] = f(x[i], y[i])
    end
end
let x = randn(400), y = randn(400), mi = similar(x), ma = similar(x)
    x[abs2.(x) .> 2] .= NaN
    y[abs2.(y) .> 2] .= NaN
    @show count(isnan, x)
    @show count(isnan, y)
    for f in [minmax_base, minmax_branch, minmax_short_circuit]
        println(f, "\n")
        b = @benchmark vminmax_map!($f, $mi, $ma, $x, $y)
        display(b)
        println("\n")
    end
end
with the following results on my machine
count(isnan, x) = 62
count(isnan, y) = 48
minmax_base
BenchmarkTools.Trial: 10000 samples with 68 evaluations.
 Range (min … max):  846.029 ns … 3.366 μs    ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     874.662 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   912.077 ns ± 148.089 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

 [histogram omitted]
  846 ns          Histogram: log(frequency) by time          1.63 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
minmax_branch
BenchmarkTools.Trial: 10000 samples with 425 evaluations.
 Range (min … max):  230.165 ns … 657.605 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     237.214 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   241.901 ns ±  25.652 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

 [histogram omitted]
  230 ns          Histogram: log(frequency) by time          350 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.
minmax_short_circuit
BenchmarkTools.Trial: 10000 samples with 352 evaluations.
 Range (min … max):  256.074 ns … 621.827 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     263.514 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   268.433 ns ±  24.659 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

 [histogram omitted]
  256 ns          Histogram: log(frequency) by time          393 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.
I think your benchmark is skewed by the fact that you're always using the same data. The branch predictor in your CPU is probably picking up the patterns in your data and messing with your results. I'd create a new array for each benchmark run using the setup keyword of @benchmark. It should still even out statistically, but it'll make it much harder for the branch predictor to mess with things.
Thanks, though it doesn't seem to make a real difference in this case, at least when implemented as follows.
function bench_setup(N, nan_cut)
    x = randn(N)
    y = randn(N)
    x[abs2.(x) .> nan_cut] .= NaN
    y[abs2.(y) .> nan_cut] .= NaN
    return x, y, similar(x), similar(y)
end
for f in [minmax_base, minmax_branch, minmax_short_circuit]
    println(f, "\n")
    b = @benchmark vminmax_map!($f, mi, ma, x, y) setup=begin
        x, y, mi, ma = bench_setup(400, 2)
    end
    display(b)
    println("\n")
end
Correction! I should have set evals per sample to 1 and increased the length so I still get accurate timing. When I do that, I get
minmax_base
BenchmarkTools.Trial: 10000 samples with 1 evaluations.
 Range (min … max):  39.834 μs … 342.509 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     72.480 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   75.199 μs ±  12.760 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

 [histogram omitted]
  39.8 μs         Histogram: log(frequency) by time         121 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
minmax_branch
BenchmarkTools.Trial: 10000 samples with 1 evaluations.
 Range (min … max):   4.293 μs … 278.540 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     33.580 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   35.297 μs ±   9.275 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

 [histogram omitted]
  4.29 μs           Histogram: frequency by time           65.8 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
minmax_short_circuit
BenchmarkTools.Trial: 10000 samples with 1 evaluations.
 Range (min … max):   4.202 μs … 269.420 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     33.655 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   35.367 μs ±   9.159 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

 [histogram omitted]
  4.2 μs            Histogram: frequency by time           65.3 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.
with
for f in [minmax_base, minmax_branch, minmax_short_circuit]
    println(f, "\n")
    b = @benchmark vminmax_map!($f, mi, ma, x, y) setup=begin
        x, y, mi, ma = bench_setup(5_000, 2)
    end evals=1
    display(b)
    println("\n")
end
TL;DR: branching is faster. Short-circuiting may or may not be slower in a real-world use case?
So, I got curious about this because I was writing a function to sort three numbers. Having now tested what native instructions I get, it turns out the choice of minmax
implementation has no effect whatsoever on the final compiled code for this example.
Here are the three sorting functions
function sort3(a, b, c)
    l1, h1 = minmax(a, b)
    lo, h2 = minmax(l1, c)
    md, hi = minmax(h1, h2)
    return lo, md, hi
end
@inline minmax_nb1(x,y) = (xlty = x < y; (ifelse(xlty, x, y), ifelse(xlty, y, x)))
function sort3_nb1(a, b, c)
    l1, h1 = minmax_nb1(a, b)
    lo, h2 = minmax_nb1(l1, c)
    md, hi = minmax_nb1(h1, h2)
    return lo, md, hi
end
@inline minmax_nb2(x, y) = ifelse(isless(x, y), (x, y), (y, x))
function sort3_nb2(a, b, c)
    l1, h1 = minmax_nb2(a, b)
    lo, h2 = minmax_nb2(l1, c)
    md, hi = minmax_nb2(h1, h2)
    return lo, md, hi
end
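As a quick sanity check on one small input, all three variants agree:
julia> sort3(3, 1, 2), sort3_nb1(3, 1, 2), sort3_nb2(3, 1, 2)
((1, 2, 3), (1, 2, 3), (1, 2, 3))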
Here are the native assembly instructions for each
julia> @code_native sort3(3, 2, 1)
.section __TEXT,__text,regular,pure_instructions
; ┌ @ sort3.jl:10 within `sort3'
movq %rdi, %rax
; │ @ sort3.jl:11 within `sort3'
; │┌ @ promotion.jl:423 within `minmax'
; ││┌ @ int.jl:83 within `<'
cmpq %rsi, %rdx
; ││└
movq %rsi, %rdi
cmovlq %rdx, %rdi
cmovlq %rsi, %rdx
; │└
; │ @ sort3.jl:12 within `sort3'
; │┌ @ promotion.jl:423 within `minmax'
; ││┌ @ int.jl:83 within `<'
cmpq %rcx, %rdi
; ││└
movq %rdi, %rsi
cmovgq %rcx, %rsi
cmovleq %rcx, %rdi
; │└
; │ @ sort3.jl:13 within `sort3'
; │┌ @ promotion.jl:423 within `minmax'
; ││┌ @ int.jl:83 within `<'
cmpq %rdx, %rdi
; ││└
movq %rdx, %rcx
cmovlq %rdi, %rcx
cmovlq %rdx, %rdi
; │└
; │ @ sort3.jl:14 within `sort3'
movq %rsi, (%rax)
movq %rcx, 8(%rax)
movq %rdi, 16(%rax)
retq
nopl (%rax)
; └
julia> @code_native sort3_nb1(3, 2, 1)
.section __TEXT,__text,regular,pure_instructions
; ┌ @ sort3.jl:19 within `sort3_nb1'
movq %rdi, %rax
; │ @ sort3.jl:20 within `sort3_nb1'
; │┌ @ sort3.jl:17 within `minmax_nb1'
; ││┌ @ int.jl:83 within `<'
cmpq %rdx, %rsi
; ││└
movq %rdx, %rdi
cmovlq %rsi, %rdi
cmovlq %rdx, %rsi
; │└
; │ @ sort3.jl:21 within `sort3_nb1'
; │┌ @ sort3.jl:17 within `minmax_nb1'
; ││┌ @ int.jl:83 within `<'
cmpq %rcx, %rdi
; ││└
movq %rcx, %rdx
cmovlq %rdi, %rdx
cmovlq %rcx, %rdi
; │└
; │ @ sort3.jl:22 within `sort3_nb1'
; │┌ @ sort3.jl:17 within `minmax_nb1'
; ││┌ @ int.jl:83 within `<'
cmpq %rdi, %rsi
; ││└
movq %rdi, %rcx
cmovlq %rsi, %rcx
cmovgeq %rsi, %rdi
; │└
; │ @ sort3.jl:23 within `sort3_nb1'
movq %rdx, (%rax)
movq %rcx, 8(%rax)
movq %rdi, 16(%rax)
retq
nopl (%rax)
; └
julia> @code_native sort3_nb2(3, 2, 1)
.section __TEXT,__text,regular,pure_instructions
; ┌ @ sort3.jl:28 within `sort3_nb2'
movq %rdi, %rax
; │ @ sort3.jl:29 within `sort3_nb2'
; │┌ @ sort3.jl:26 within `minmax_nb2'
; ││┌ @ operators.jl:357 within `isless'
; │││┌ @ int.jl:83 within `<'
cmpq %rdx, %rsi
; ││└└
movq %rsi, %rdi
cmovlq %rdx, %rdi
cmovlq %rsi, %rdx
; │└
; │ @ sort3.jl:30 within `sort3_nb2'
; │┌ @ sort3.jl:26 within `minmax_nb2'
; ││┌ @ operators.jl:357 within `isless'
; │││┌ @ int.jl:83 within `<'
cmpq %rcx, %rdx
; ││└└
movq %rcx, %rsi
cmovlq %rdx, %rsi
cmovlq %rcx, %rdx
; │└
; │ @ sort3.jl:31 within `sort3_nb2'
; │┌ @ sort3.jl:26 within `minmax_nb2'
; ││┌ @ operators.jl:357 within `isless'
; │││┌ @ int.jl:83 within `<'
cmpq %rdx, %rdi
; ││└└
movq %rdi, %rcx
cmovlq %rdx, %rcx
cmovlq %rdi, %rdx
; │└
; │ @ sort3.jl:32 within `sort3_nb2'
movq %rsi, (%rax)
movq %rdx, 8(%rax)
movq %rcx, 16(%rax)
retq
nopl (%rax)
; └
There are plenty of use cases for an ifelse function, because you can dispatch on it.
Except you can't add methods to Base.ifelse. But IfElse.jl is used by packages like Symbolics.jl and LoopVectorization.jl.
So I'm not a particularly big fan of Base.ifelse in its current form.
@Luapulu @Sukera the branch version is faster not because branches are faster, but because it is easier for the compiler to understand than the way the Base function uses ifelse.
The compiler gets rid of the branches while optimizing it.
Note that
function minmax_nb(x::T, y::T) where {T<:AbstractFloat}
    isnanx = isnan(x)
    isnany = isnan(y)
    ygx = y > x
    sbxgsby = signbit(x) > signbit(y)
    l = ifelse(isnanx | isnany, ifelse(isnanx, x, y), ifelse(ygx | sbxgsby, x, y))
    u = ifelse(isnanx | isnany, ifelse(isnanx, x, y), ifelse(ygx | sbxgsby, y, x))
    l, u
end
is just as fast as the branch versions, and much faster than base.
The compiler doesn't really understand ifelse on tuples, like ifelse(b, (x,y), (y,x)), so the way the Base method is written causes the optimizer to fail.
The branch version is understood, and so is the version above that avoids tuples in ifelse, so both of these are optimized in the loop, and neither of them actually has branches post-optimization (the branches get converted into ifelse statements like the above).
For reference, this is the loop body on my computer:
vector.body: ; preds = %vector.body, %vector.ph
%index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
%108 = getelementptr inbounds double, double addrspace(13)* %98, i64 %index, !dbg !102
%109 = bitcast double addrspace(13)* %108 to <8 x double> addrspace(13)*, !dbg !102
%wide.load = load <8 x double>, <8 x double> addrspace(13)* %109, align 8, !dbg !102, !tbaa !103, !alias.scope !105
%110 = getelementptr inbounds double, double addrspace(13)* %108, i64 8, !dbg !102
%111 = bitcast double addrspace(13)* %110 to <8 x double> addrspace(13)*, !dbg !102
%wide.load78 = load <8 x double>, <8 x double> addrspace(13)* %111, align 8, !dbg !102, !tbaa !103, !alias.scope !105
%112 = getelementptr inbounds double, double addrspace(13)* %101, i64 %index, !dbg !102
%113 = bitcast double addrspace(13)* %112 to <8 x double> addrspace(13)*, !dbg !102
%wide.load79 = load <8 x double>, <8 x double> addrspace(13)* %113, align 8, !dbg !102, !tbaa !103, !alias.scope !108
%114 = getelementptr inbounds double, double addrspace(13)* %112, i64 8, !dbg !102
%115 = bitcast double addrspace(13)* %114 to <8 x double> addrspace(13)*, !dbg !102
%wide.load80 = load <8 x double>, <8 x double> addrspace(13)* %115, align 8, !dbg !102, !tbaa !103, !alias.scope !108
%116 = fcmp ord <8 x double> %wide.load, %wide.load79, !dbg !110
%117 = fcmp ord <8 x double> %wide.load78, %wide.load80, !dbg !110
%118 = fcmp ord <8 x double> %wide.load, zeroinitializer, !dbg !112
%119 = fcmp ord <8 x double> %wide.load78, zeroinitializer, !dbg !112
%120 = select <8 x i1> %118, <8 x double> %wide.load79, <8 x double> %wide.load, !dbg !90
%121 = select <8 x i1> %119, <8 x double> %wide.load80, <8 x double> %wide.load78, !dbg !90
%122 = fcmp olt <8 x double> %wide.load, %wide.load79, !dbg !118
%123 = fcmp olt <8 x double> %wide.load78, %wide.load80, !dbg !118
%124 = bitcast <8 x double> %wide.load to <8 x i64>, !dbg !122
%125 = bitcast <8 x double> %wide.load78 to <8 x i64>, !dbg !122
%126 = bitcast <8 x double> %wide.load79 to <8 x i64>, !dbg !122
%127 = bitcast <8 x double> %wide.load80 to <8 x i64>, !dbg !122
%128 = icmp sgt <8 x i64> %126, <i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1>, !dbg !125
%129 = icmp sgt <8 x i64> %127, <i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1, i64 -1>, !dbg !125
%130 = icmp slt <8 x i64> %124, zeroinitializer, !dbg !129
%131 = icmp slt <8 x i64> %125, zeroinitializer, !dbg !129
%132 = and <8 x i1> %130, %128, !dbg !129
%133 = and <8 x i1> %131, %129, !dbg !129
%134 = or <8 x i1> %122, %132, !dbg !130
%135 = or <8 x i1> %123, %133, !dbg !130
%136 = select <8 x i1> %134, <8 x double> %wide.load, <8 x double> %wide.load79, !dbg !90
%137 = select <8 x i1> %135, <8 x double> %wide.load78, <8 x double> %wide.load80, !dbg !90
%138 = select <8 x i1> %134, <8 x double> %wide.load79, <8 x double> %wide.load, !dbg !90
%139 = select <8 x i1> %135, <8 x double> %wide.load80, <8 x double> %wide.load78, !dbg !90
%predphi = select <8 x i1> %116, <8 x double> %136, <8 x double> %120
%predphi81 = select <8 x i1> %117, <8 x double> %137, <8 x double> %121
%predphi82 = select <8 x i1> %116, <8 x double> %138, <8 x double> %120
%predphi83 = select <8 x i1> %117, <8 x double> %139, <8 x double> %121
%140 = getelementptr inbounds double, double addrspace(13)* %104, i64 %index, !dbg !131
%141 = bitcast double addrspace(13)* %140 to <8 x double> addrspace(13)*, !dbg !131
store <8 x double> %predphi, <8 x double> addrspace(13)* %141, align 8, !dbg !131, !tbaa !103, !alias.scope !132, !noalias !134
%142 = getelementptr inbounds double, double addrspace(13)* %140, i64 8, !dbg !131
%143 = bitcast double addrspace(13)* %142 to <8 x double> addrspace(13)*, !dbg !131
store <8 x double> %predphi81, <8 x double> addrspace(13)* %143, align 8, !dbg !131, !tbaa !103, !alias.scope !132, !noalias !134
%144 = getelementptr inbounds double, double addrspace(13)* %107, i64 %index, !dbg !131
%145 = bitcast double addrspace(13)* %144 to <8 x double> addrspace(13)*, !dbg !131
store <8 x double> %predphi82, <8 x double> addrspace(13)* %145, align 8, !dbg !131, !tbaa !103, !alias.scope !136, !noalias !137
%146 = getelementptr inbounds double, double addrspace(13)* %144, i64 8, !dbg !131
%147 = bitcast double addrspace(13)* %146 to <8 x double> addrspace(13)*, !dbg !131
store <8 x double> %predphi83, <8 x double> addrspace(13)* %147, align 8, !dbg !131, !tbaa !103, !alias.scope !136, !noalias !137
%index.next = add i64 %index, 16
%148 = icmp eq i64 %index.next, %n.vec
br i1 %148, label %middle.block, label %vector.body, !llvm.loop !138
Note the lack of branches (it is one solid block of instructions), and all the select
statements.
This was @anon56330260's point: the compiler understands branches, and thus writing your code with branches is a good choice to let the compiler decide whether or not to actually use a branch.
If you're careful to only use ifelse with primitive types like Float64 (instead of Tuple{Float64,Float64}), you'll be fine with it. But then multiple ifelse statements are more verbose than ? :.
So, what should the future of Base.ifelse
look like? Ideally, it would be possible to tell the compiler more explicitly that something ought to be branchless, rather than relying on it to optimise away branches, no? Or would the plan be to get rid of ifelse
and rely on the compiler to optimise away branches if needed?
And btw, why can't you add methods to Base.ifelse?
Because it is a builtin
function.
julia> Base.ifelse(b::Bool, x::Tuple{Vararg{Number,K}}, y::Tuple{Vararg{Number,K}}) where {K} = map((x,y) -> ifelse(b, x, y), x, y)
ERROR: cannot add methods to a builtin function
Stacktrace:
[1] top-level scope
@ REPL[5]:1
julia> using IfElse
julia> IfElse.ifelse(b::Bool, x::Tuple{Vararg{Number,K}}, y::Tuple{Vararg{Number,K}}) where {K} = map((x,y) -> Base.ifelse(b, x, y), x, y)
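With that method defined (this assumes IfElse.jl is installed, as above), the tuple case then works:
julia> IfElse.ifelse(true, (1, 2), (2, 1))
(1, 2)

julia> IfElse.ifelse(false, (1, 2), (2, 1))
(2, 1)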
The future I'd like to see is (a) make it generic, and (b) add a method like the above to work around the performance issue discussed in this thread and that I linked to earlier.
Yes, that's why we have the SIMD library. You can hand-write SIMD code yourself, but that's definitely not a good idea; that's why we have a vectorization pass in the compiler to do this automatically. You might ask whether it's possible to give the compiler a hint that the code should be branchless. But the problem is that there are many ways to be branchless, and you might indicate a suboptimal solution, just like the minmax example (it's branchless, but not vectorized).
Instead, I think we should just define Base.ifelse only on a constrained set of types (Int, Float64, NTuple{4,Int}, …). Base.ifelse acts like vifelse in the SIMD libraries: it provides no functionality beyond ? : and is used only for optimization. It's not generic by design (it lowers to select, and not all values can be operands of select). Misuse of this function can hurt performance, since Julia's frontend compiler isn't aware of this function (and thus does no special optimization), and lowering to select early might block further vectorization, as seen in the minmax example.
Looking at your benchmark results, I would conclude that the branch and short-circuit versions perform the same, and one would need a different reason to prefer one over the other.