Fast logsumexp

We recently held a competition among the students in a Julia course we are teaching. The goal was to profile and optimize a provided particle-filter simulation. After all the obvious things were taken care of, most of the time was spent calculating `log(sum(exp.(w)))` of a weight vector `w`. Below are four attempts at this function. They all operate in place and update both the weight vector `w` and the exponentiated weight vector `we`; on return, `w` holds the normalized log-weights and `we` the normalized weights. The `offset` (the maximum of `w`) is subtracted before exponentiating to avoid numerical overflow/underflow. No matter how hard I tried, I couldn’t beat the library Yeppp. Not even my hand-written SIMD loops (using either SIMD.jl or SIMDPirates.jl) operating on `Float32` would beat Yeppp on `Float64`.
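
To see why the offset matters, here is a minimal sketch using only Base. It relies on the identity log Σ exp(wᵢ) = m + log Σ exp(wᵢ − m) with m = maximum(w); the names `naive_logsumexp` and `shifted_logsumexp` are made up for this illustration and are not part of the benchmarked code.

```julia
# Naive evaluation: exp overflows to Inf once a weight exceeds roughly 709.
naive_logsumexp(w) = log(sum(exp.(w)))

# Shifting by the maximum keeps every exponent <= 0, so exp cannot overflow,
# and the largest term contributes exactly exp(0) = 1 to the sum.
function shifted_logsumexp(w)
    offset = maximum(w)
    return offset + log(sum(exp.(w .- offset)))
end

w_big = [1000.0, 1000.0]
naive_logsumexp(w_big)     # Inf, because exp(1000.0) overflows
shifted_logsumexp(w_big)   # approximately 1000 + log(2)
```

The same shift also protects against underflow: for very negative weights the naive version returns `-Inf`, while the shifted one stays finite.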

My four implementations are below. No threading is used, and I do not include an implementation using SLEEF, as it could not beat Yeppp either. Can someone explain why Yeppp does so well in this example?

``````
using Test
N = 1000;

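function logsumexp!(w,we)
    offset = maximum(w)        # shift by the maximum so exp cannot overflow
    we .= exp.(w .- offset)
    s = sum(we)
    w .-= log(s) + offset      # w now holds normalized log-weights
    we .*= 1/s                 # we now holds normalized weights
end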

w = randn(Float64, N); we = similar(w);
logsumexp!(w,we);
@test sum(we) ≈ 1
@test sum(exp.(w)) ≈ 1

# ==================================================
using Yeppp
function logsumexp_yeppp!(w,we)
    offset = maximum(w)
    w    .-= offset
    Yeppp.exp!(we,w)           # vectorized exp, written into we
    s      = sum(we)
    w    .-= log(s)            # offset not needed here since we subtracted it above
    we   .*= 1/s
end

w = randn(Float64, N); we = similar(w);
logsumexp_yeppp!(w,we);
@test sum(we) ≈ 1
@test sum(exp.(w)) ≈ 1

# ==================================================
using SIMD
function logsumexp_simd!(w::Vector{T},we) where T
    offset = maximum(w)
    N      = length(w)
    s      = zero(T)
    @inbounds for i = 1:4:N    # assumes N is a multiple of 4
        l     = VecRange{4}(i)
        wel   = exp(w[l]-offset)
        we[l] = wel
        s    += sum(wel)
    end
    w .-= log(s) + offset
    we .*= 1/s
end

w = randn(Float64, N); we = similar(w);
logsumexp_simd!(w,we);
@test sum(we) ≈ 1
@test sum(exp.(w)) ≈ 1

# ==================================================
using SIMDPirates
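function logsumexp_simdpirates!(w::Vector{T},we) where T
    offset = maximum(w)
    N      = length(w)
    sl     = SIMDPirates.Vec{4,T}((0.,0.,0.,0.))
    @inbounds @simd for i = 1:4:N    # assumes N is a multiple of 4
        # Load 4 elements of w starting at i. `wl` was undefined in the snippet
        # as posted; the `vload` signature is assumed to mirror `vstore!` below.
        wl = SIMDPirates.vload(SIMDPirates.Vec{4,T}, w, i)
        @pirate wel = exp(wl-offset)
        @pirate sl += wel
        SIMDPirates.vstore!(we,wel,i)
    end
    s    = vsum(sl)
    w  .-= log(s) + offset
    we .*= 1/s
end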

w = randn(Float64, N); we = similar(w);
logsumexp_simdpirates!(w,we);
@test sum(we) ≈ 1
@test sum(exp.(w)) ≈ 1

# ==================================================
using BenchmarkTools
w = randn(Float64, N); we = similar(w);
@btime logsumexp!($w,$we);
@btime logsumexp_yeppp!($w,$we);
@btime logsumexp_simd!($w,$we);
@btime logsumexp_simdpirates!($w,$we);

w = randn(Float32, N); we = similar(w);
@btime logsumexp!($w,$we);
# @btime logsumexp_yeppp!($w,$we); # Yeppp does not handle Float32
@btime logsumexp_simd!($w,$we);
@btime logsumexp_simdpirates!($w,$we);

``````

Benchmark results:

``````
julia> w = randn(Float64, N); we = similar(w);
julia> @btime logsumexp!($w,$we);
15.947 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_yeppp!($w,$we);
4.580 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_simd!($w,$we);
21.399 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_simdpirates!($w,$we);
21.384 μs (0 allocations: 0 bytes)
julia> w = randn(Float32, N); we = similar(w);
julia> @btime logsumexp!($w,$we);
12.603 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_simd!($w,$we);
8.965 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_simdpirates!($w,$we);
8.779 μs (0 allocations: 0 bytes)
``````

From what I can see, even though SIMD.jl emits the LLVM intrinsic `@llvm.exp.v4f64`, LLVM still ends up generating four separate calls, one per vector element:

``````
julia> @code_native exp(a)
.text
...
vzeroupper
callq   *%rdi  <---------------
vmovapd %xmm0, %xmm7
vpermilpd       $1, %xmm6, %xmm0 # xmm0 = xmm6[1,0]
callq   *%rdi  <---------------
vunpcklpd       %xmm0, %xmm7, %xmm7 # xmm7 = xmm7[0],xmm0[0]
vmovaps 32(%rsp), %ymm0
vzeroupper
callq   *%rdi  <---------------
vmovapd %xmm0, %xmm6
vpermilpd       $1, 32(%rsp), %xmm0 # xmm0 = mem[1,0]
callq   *%rdi  <---------------

...
``````

whereas Yeppp has optimized kernels that compute `exp` on a whole SIMD vector at once.

Unrelated, but cool to see a Julia course in Sweden!


Haha, so much for my manual SIMD efforts. Thanks for the explanation; I would not have figured that out. Do you have a feeling for where the issue lies? Is it SIMD.jl, Julia or LLVM that is to blame? @Elrod, thanks for your efforts on SIMDPirates. Do you perhaps have any insight into this?

As for the course, we first gave it back in 2015, in the v0.3/0.4 days, but many new PhD students who need to transition away from MATLAB have started since. Looking back at the course material, it’s nice to see how Julia has developed both in terms of performance and design.

LLVM. GCC will vectorize log/exp/sin/etc. given the appropriate optimization flags, but LLVM needs, in addition to those flags, `-fveclib=SVML` or some other vector library.

``````
using SIMDPirates, SLEEFPirates, LoopVectorization
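function logsumexp_simdpirates!(w::Vector{T},we) where T
    offset = maximum(w)
    N      = length(w)
    sl     = SIMDPirates.Vec{4,T}((0.,0.,0.,0.))
    @inbounds @simd for i = 1:4:N    # assumes N is a multiple of 4
        # Load 4 elements of w starting at i. `wl` was undefined in the snippet
        # as posted; the `vload` signature is assumed to mirror `vstore!` below.
        wl = SIMDPirates.vload(SIMDPirates.Vec{4,T}, w, i)
        @pirate wel = SLEEFPirates.exp(wl-offset)
        @pirate sl += wel
        SIMDPirates.vstore!(we,wel,i)
    end
    s    = vsum(sl)
    w  .-= log(s) + offset
    we .*= 1/s
end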

function logsumexp_loopvectorization!(w::Vector{T},we) where T
    offset = maximum(w)
    N      = length(w)
    s = zero(T)
    @vectorize for i = 1:N
        wl = w[i]
        wel = exp(wl-offset)
        we[i] = wel
        s += wel
    end
    w  .-= log(s) + offset
    we .*= 1/s
end

function logsumexp_sleefpirates!(w::Vector{T},we) where T
    offset = maximum(w)
    N      = length(w)
    s = zero(T)
    @inbounds @simd for i = 1:N
        wl = w[i]
        wel = SLEEFPirates.exp(wl-offset)
        we[i] = wel
        s += wel
    end
    w  .-= log(s) + offset
    we .*= 1/s
end
``````

A few results:

``````
julia> @btime logsumexp!($w,$we);
7.844 μs (0 allocations: 0 bytes)

julia> @btime logsumexp_yeppp!($w,$we);
9.393 μs (0 allocations: 0 bytes)

julia> @btime logsumexp_simdpirates!($w,$we);
2.576 μs (0 allocations: 0 bytes)

julia> @btime logsumexp_loopvectorization!($w,$we);
1.825 μs (0 allocations: 0 bytes)

julia> @btime logsumexp_sleefpirates!($w,$we);
1.705 μs (0 allocations: 0 bytes)

julia> @code_native logsumexp_sleefpirates!(w,we)
.text
; ┌ @ REPL[43]:2 within `logsumexp_sleefpirates!'
pushq	%r15
pushq	%r14
pushq	%rbx
subq	$224, %rsp
movq	%rsi, 56(%rsp)
movq	(%rsi), %r15
movq	8(%rsi), %r14
; │┌ @ reducedim.jl:652 within `maximum'
; ││┌ @ reducedim.jl:652 within `#maximum#562'
; │││┌ @ reducedim.jl:656 within `_maximum' @ reducedim.jl:657
; ││││┌ @ reducedim.jl:307 within `mapreduce'
; │││││┌ @ reducedim.jl:307 within `#mapreduce#555'
; ││││││┌ @ reducedim.jl:312 within `_mapreduce_dim'
movabsq	$"size;", %rax
movq	%r15, %rdi
callq	*%rax
; │└└└└└└
; │ @ REPL[43]:3 within `logsumexp_sleefpirates!'
; │┌ @ array.jl:200 within `length'
movq	8(%r15), %rax
; │└
; │ @ REPL[43]:5 within `logsumexp_sleefpirates!'
; │┌ @ simdloop.jl:69 within `macro expansion'
; ││┌ @ range.jl:5 within `Colon'
; │││┌ @ range.jl:275 within `Type'
; ││││┌ @ range.jl:280 within `unitrange_last'
movq	%rax, %rcx
sarq	$63, %rcx
andnq	%rax, %rcx, %r8
; │└└└└
; │┌ @ checked.jl:194 within `macro expansion'
leaq	-1(%r8), %rsi
; │└
; │┌ @ simdloop.jl:71 within `macro expansion'
; ││┌ @ simdloop.jl:51 within `simd_inner_length'
; │││┌ @ range.jl:541 within `length'
; ││││┌ @ checked.jl:165 within `checked_add'
; │││││┌ @ checked.jl:132 within `add_with_overflow'
movq	%rsi, %r9
incq	%r9
; │││││└
; │││││ @ checked.jl:166 within `checked_add'
jo	L2278
; │└└└└
; │┌ @ int.jl:49 within `macro expansion'
testq	%r9, %r9
vmovapd	%xmm0, 32(%rsp)
; └└
; ┌ @ simdloop.jl:72 within `logsumexp_sleefpirates!'
jle	L107
movq	(%r15), %rcx
movq	(%r14), %rdx
vxorpd	%xmm19, %xmm19, %xmm19
; │ @ simdloop.jl:75 within `logsumexp_sleefpirates!'
cmpq	$32, %r8
jae	L118
xorl	%esi, %esi
jmp	L1345
L107:
vxorpd	%xmm19, %xmm19, %xmm19
jmp	L1757
; │ @ simdloop.jl:75 within `logsumexp_sleefpirates!'
L118:
leaq	(%rcx,%r8,8), %rsi
cmpq	%rsi, %rdx
jae	L143
leaq	(%rdx,%r8,8), %rsi
; │ @ simdloop.jl:75 within `logsumexp_sleefpirates!'
cmpq	%rsi, %rcx
jae	L143
xorl	%esi, %esi
jmp	L1345
; │ @ simdloop.jl:75 within `logsumexp_sleefpirates!'
L143:
movabsq	$9223372036854775776, %rsi # imm = 0x7FFFFFFFFFFFFFE0
andq	%r8, %rsi
vmovupd	%zmm0, 64(%rsp)
vxorpd	%xmm0, %xmm0, %xmm0
xorl	%edi, %edi
movabsq	$139756740332128, %rbx  # imm = 0x7F1BA6DCCA60
vmovups	%zmm1, 128(%rsp)
movabsq	$139756740332136, %rbx  # imm = 0x7F1BA6DCCA68
movabsq	$139756740332144, %rbx  # imm = 0x7F1BA6DCCA70
movabsq	$139756740332152, %rbx  # imm = 0x7F1BA6DCCA78
movabsq	$139756740332160, %rbx  # imm = 0x7F1BA6DCCA80
movabsq	$139756740332168, %rbx  # imm = 0x7F1BA6DCCA88
movabsq	$139756740332176, %rbx  # imm = 0x7F1BA6DCCA90
movabsq	$139756740332184, %rbx  # imm = 0x7F1BA6DCCA98
movabsq	$139756740332192, %rbx  # imm = 0x7F1BA6DCCAA0
movabsq	$139756740332200, %rbx  # imm = 0x7F1BA6DCCAA8
movabsq	$139756740332208, %rbx  # imm = 0x7F1BA6DCCAB0
movabsq	$139756740332216, %rbx  # imm = 0x7F1BA6DCCAB8
movabsq	$139756740332224, %rbx  # imm = 0x7F1BA6DCCAC0
movabsq	$139756740332232, %rbx  # imm = 0x7F1BA6DCCAC8
movabsq	$139756740332240, %rbx  # imm = 0x7F1BA6DCCAD0
vxorpd	%xmm18, %xmm18, %xmm18
vxorpd	%xmm19, %xmm19, %xmm19
vxorpd	%xmm20, %xmm20, %xmm20
; └
; ┌ @ REPL[43]:5 within `logsumexp_sleefpirates!'
; │┌ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:6
; ││┌ @ array.jl:728 within `getindex'
L448:
vmovupd	(%rcx,%rdi,8), %zmm5
vmovupd	64(%rcx,%rdi,8), %zmm21
vmovupd	128(%rcx,%rdi,8), %zmm22
vmovupd	192(%rcx,%rdi,8), %zmm23
vmovupd	64(%rsp), %zmm2
; ││└
; ││ @ simdloop.jl:77 within `macro expansion' @ float.jl:397
vsubpd	%zmm2, %zmm5, %zmm5
vsubpd	%zmm2, %zmm21, %zmm29
vsubpd	%zmm2, %zmm22, %zmm30
vsubpd	%zmm2, %zmm23, %zmm31
vmovupd	128(%rsp), %zmm2
; ││ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:7
; ││┌ @ exp.jl:181 within `exp'
; │││┌ @ float.jl:399 within `*'
vmulpd	%zmm2, %zmm5, %zmm21
vmulpd	%zmm2, %zmm29, %zmm22
vmulpd	%zmm2, %zmm30, %zmm23
vmulpd	%zmm2, %zmm31, %zmm24
; ││└└
; ││┌ @ float.jl:370 within `exp'
vrndscalepd	$4, %zmm21, %zmm25
vrndscalepd	$4, %zmm22, %zmm26
vrndscalepd	$4, %zmm23, %zmm27
vrndscalepd	$4, %zmm24, %zmm28
; ││└
; ││┌ @ exp.jl:183 within `exp'
; │││┌ @ float.jl:304 within `unsafe_trunc'
vcvttpd2qq	%zmm25, %zmm21
vcvttpd2qq	%zmm26, %zmm22
vcvttpd2qq	%zmm27, %zmm23
vcvttpd2qq	%zmm28, %zmm24
; ││└└
; ││┌ @ float.jl:404 within `exp'
; ││└
; ││┌ @ exp.jl:186 within `exp'
; │││┌ @ float.jl:404 within `muladd'
; │││└
; │││ @ exp.jl:188 within `exp'
; │││┌ @ exp.jl:161 within `exp_kernel'
; ││││┌ @ math.jl:101 within `macro expansion'
; │││││┌ @ float.jl:404 within `muladd'
vmovapd	%zmm1, %zmm29
vmovapd	%zmm1, %zmm30
vmovapd	%zmm1, %zmm31
vmovapd	%zmm1, %zmm5
; ││└└└└
; ││┌ @ float.jl:399 within `exp'
vmulpd	%zmm25, %zmm25, %zmm2
vmulpd	%zmm29, %zmm2, %zmm2
vmulpd	%zmm26, %zmm26, %zmm29
vmulpd	%zmm30, %zmm29, %zmm29
vmulpd	%zmm27, %zmm27, %zmm30
vmulpd	%zmm31, %zmm30, %zmm30
vmulpd	%zmm28, %zmm28, %zmm31
vmulpd	%zmm5, %zmm31, %zmm5
; ││└
; ││┌ @ exp.jl:189 within `exp'
; │││┌ @ operators.jl:529 within `+' @ float.jl:395
; │││└
; │││┌ @ float.jl:395 within `+'
; │││└
; │││ @ exp.jl:190 within `exp'
; │││┌ @ priv.jl:51 within `ldexp2k'
; ││││┌ @ int.jl:444 within `>>' @ int.jl:437
vpsraq	$1, %zmm21, %zmm27
vpsraq	$1, %zmm22, %zmm28
vpsraq	$1, %zmm23, %zmm29
vpsraq	$1, %zmm24, %zmm30
; ││││└
; ││││ @ priv.jl:52 within `ldexp2k'
; ││││┌ @ utils.jl:51 within `pow2i'
; │││││┌ @ utils.jl:20 within `integer2float'
; ││││││┌ @ int.jl:446 within `<<' @ int.jl:439
vpsllq	$52, %zmm27, %zmm31
; │││└└└└
; │││┌ @ float.jl:399 within `ldexp2k'
vmulpd	%zmm31, %zmm2, %zmm2
; │││└
; │││┌ @ priv.jl:52 within `ldexp2k'
; ││││┌ @ utils.jl:51 within `pow2i'
; │││││┌ @ utils.jl:20 within `integer2float'
; ││││││┌ @ int.jl:446 within `<<' @ int.jl:439
vpsllq	$52, %zmm28, %zmm31
; ││││└└└
; ││││┌ @ floating_point_arithmetic.jl:62 within `evmul'
; │││││┌ @ float.jl:399 within `*'
vmulpd	%zmm31, %zmm25, %zmm25
; ││││└└
; ││││┌ @ utils.jl:51 within `pow2i'
; │││││┌ @ utils.jl:20 within `integer2float'
; ││││││┌ @ int.jl:446 within `<<' @ int.jl:439
vpsllq	$52, %zmm29, %zmm31
; ││││└└└
; ││││┌ @ floating_point_arithmetic.jl:62 within `evmul'
; │││││┌ @ float.jl:399 within `*'
vmulpd	%zmm31, %zmm26, %zmm26
; ││││└└
; ││││┌ @ utils.jl:51 within `pow2i'
; │││││┌ @ utils.jl:20 within `integer2float'
; ││││││┌ @ int.jl:446 within `<<' @ int.jl:439
vpsllq	$52, %zmm30, %zmm31
; ││││└└└
; ││││┌ @ floating_point_arithmetic.jl:62 within `evmul'
; │││││┌ @ float.jl:399 within `*'
vmulpd	%zmm31, %zmm5, %zmm5
; ││││└└
; ││││┌ @ int.jl:52 within `-'
vpsubq	%zmm27, %zmm21, %zmm21
vpsubq	%zmm28, %zmm22, %zmm22
vpsubq	%zmm29, %zmm23, %zmm23
vpsubq	%zmm30, %zmm24, %zmm24
; │││└└
; │││┌ @ int.jl:439 within `ldexp2k'
vpsllq	$52, %zmm21, %zmm21
; │││└
; │││┌ @ priv.jl:52 within `ldexp2k'
; ││││┌ @ floating_point_arithmetic.jl:62 within `evmul'
; │││││┌ @ float.jl:399 within `*'
vmulpd	%zmm21, %zmm2, %zmm2
; ││││└└
; ││││┌ @ utils.jl:51 within `pow2i'
; │││││┌ @ utils.jl:20 within `integer2float'
; ││││││┌ @ int.jl:446 within `<<' @ int.jl:439
vpsllq	$52, %zmm22, %zmm21
; ││││└└└
; ││││┌ @ floating_point_arithmetic.jl:62 within `evmul'
; │││││┌ @ float.jl:399 within `*'
vmulpd	%zmm21, %zmm25, %zmm21
; ││││└└
; ││││┌ @ utils.jl:51 within `pow2i'
; │││││┌ @ utils.jl:20 within `integer2float'
; ││││││┌ @ int.jl:446 within `<<' @ int.jl:439
vpsllq	$52, %zmm23, %zmm22
; ││││└└└
; ││││┌ @ floating_point_arithmetic.jl:62 within `evmul'
; │││││┌ @ float.jl:399 within `*'
vmulpd	%zmm22, %zmm26, %zmm22
; ││││└└
; ││││┌ @ utils.jl:51 within `pow2i'
; │││││┌ @ utils.jl:20 within `integer2float'
; ││││││┌ @ int.jl:446 within `<<' @ int.jl:439
vpsllq	$52, %zmm24, %zmm23
; ││││└└└
; ││││┌ @ floating_point_arithmetic.jl:62 within `evmul'
; │││││┌ @ float.jl:399 within `*'
vmulpd	%zmm23, %zmm5, %zmm5
; ││└└└└
; ││ @ simdloop.jl:77 within `macro expansion' @ array.jl:766
vmovupd	%zmm2, (%rdx,%rdi,8)
; ││ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:9
; ││┌ @ float.jl:395 within `+'
; ││└
; ││ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:8
; ││┌ @ array.jl:766 within `setindex!'
vmovupd	%zmm21, 64(%rdx,%rdi,8)
; ││└
; ││ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:9
; ││┌ @ float.jl:395 within `+'
; ││└
; ││ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:8
; ││┌ @ array.jl:766 within `setindex!'
vmovupd	%zmm22, 128(%rdx,%rdi,8)
; ││└
; ││ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:9
; ││┌ @ float.jl:395 within `+'
; ││└
; ││ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:8
; ││┌ @ array.jl:766 within `setindex!'
vmovupd	%zmm5, 192(%rdx,%rdi,8)
; ││└
; ││ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:9
; ││┌ @ float.jl:395 within `+'
; │└└
; │┌ @ int.jl:53 within `macro expansion'
cmpq	%rdi, %rsi
jne	L448
; │└
; │┌ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:9
; ││┌ @ float.jl:395 within `+'
vextractf64x4	$1, %zmm0, %ymm1
vextractf128	$1, %ymm0, %xmm1
vpermilpd	$1, %xmm0, %xmm1 # xmm1 = xmm0[1,0]
cmpq	%rsi, %r8
vmovapd	32(%rsp), %xmm0
; └└└
; ┌ @ simdloop.jl:75 within `logsumexp_sleefpirates!'
je	L1757
L1345:
movabsq	$4607182418800017408, %rdi # imm = 0x3FF0000000000000
movabsq	$139756740332128, %rbx  # imm = 0x7F1BA6DCCA60
vmovsd	(%rbx), %xmm8           # xmm8 = mem[0],zero
movabsq	$139756740332248, %rbx  # imm = 0x7F1BA6DCCAD8
vmovsd	(%rbx), %xmm9           # xmm9 = mem[0],zero
movabsq	$139756740332256, %rbx  # imm = 0x7F1BA6DCCAE0
vmovsd	(%rbx), %xmm10          # xmm10 = mem[0],zero
movabsq	$139756740332152, %rbx  # imm = 0x7F1BA6DCCA78
vmovsd	(%rbx), %xmm11          # xmm11 = mem[0],zero
movabsq	$139756740332160, %rbx  # imm = 0x7F1BA6DCCA80
vmovsd	(%rbx), %xmm12          # xmm12 = mem[0],zero
movabsq	$139756740332168, %rbx  # imm = 0x7F1BA6DCCA88
vmovsd	(%rbx), %xmm13          # xmm13 = mem[0],zero
movabsq	$139756740332176, %rbx  # imm = 0x7F1BA6DCCA90
vmovsd	(%rbx), %xmm14          # xmm14 = mem[0],zero
movabsq	$139756740332184, %rbx  # imm = 0x7F1BA6DCCA98
vmovsd	(%rbx), %xmm15          # xmm15 = mem[0],zero
movabsq	$139756740332192, %rbx  # imm = 0x7F1BA6DCCAA0
vmovsd	(%rbx), %xmm16          # xmm16 = mem[0],zero
movabsq	$139756740332200, %rbx  # imm = 0x7F1BA6DCCAA8
vmovsd	(%rbx), %xmm17          # xmm17 = mem[0],zero
movabsq	$139756740332208, %rbx  # imm = 0x7F1BA6DCCAB0
vmovsd	(%rbx), %xmm18          # xmm18 = mem[0],zero
movabsq	$139756740332216, %rbx  # imm = 0x7F1BA6DCCAB8
vmovsd	(%rbx), %xmm3           # xmm3 = mem[0],zero
movabsq	$139756740332224, %rbx  # imm = 0x7F1BA6DCCAC0
vmovsd	(%rbx), %xmm4           # xmm4 = mem[0],zero
movabsq	$139756740332232, %rbx  # imm = 0x7F1BA6DCCAC8
vmovsd	(%rbx), %xmm5           # xmm5 = mem[0],zero
movabsq	$139756740332240, %rbx  # imm = 0x7F1BA6DCCAD0
vmovsd	(%rbx), %xmm6           # xmm6 = mem[0],zero
nopw	%cs:(%rax,%rax)
; └
; ┌ @ REPL[43]:5 within `logsumexp_sleefpirates!'
; │┌ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:6
; ││┌ @ array.jl:728 within `getindex'
L1584:
vmovsd	(%rcx,%rsi,8), %xmm7    # xmm7 = mem[0],zero
; ││└
; ││ @ simdloop.jl:77 within `macro expansion' @ float.jl:397
vsubsd	%xmm0, %xmm7, %xmm7
; ││ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:7
; ││┌ @ exp.jl:181 within `exp'
; │││┌ @ float.jl:399 within `*'
vmulsd	%xmm8, %xmm7, %xmm1
; │││└
; │││┌ @ floatfuncs.jl:129 within `round' @ float.jl:370
vrndscalesd	$4, %xmm1, %xmm1, %xmm1
; ││└└
; ││┌ @ float.jl:304 within `exp'
vcvttsd2si	%xmm1, %rax
; ││└
; ││┌ @ exp.jl:185 within `exp'
; │││┌ @ float.jl:404 within `muladd'
; │││└
; │││ @ exp.jl:186 within `exp'
; │││┌ @ float.jl:404 within `muladd'
; │││└
; │││ @ exp.jl:188 within `exp'
; │││┌ @ exp.jl:161 within `exp_kernel'
; ││││┌ @ math.jl:101 within `macro expansion'
; │││││┌ @ float.jl:404 within `muladd'
vmovapd	%xmm11, %xmm2
; ││└└└└
; ││┌ @ float.jl:399 within `exp'
vmulsd	%xmm7, %xmm7, %xmm1
vmulsd	%xmm2, %xmm1, %xmm1
; ││└
; ││┌ @ exp.jl:189 within `exp'
; │││┌ @ operators.jl:529 within `+' @ float.jl:395
; │││└
; │││┌ @ float.jl:395 within `+'
; ││└└
; ││┌ @ int.jl:437 within `exp'
movq	%rax, %rbx
sarq	%rbx
; ││└
; ││┌ @ exp.jl:190 within `exp'
; │││┌ @ int.jl:52 within `ldexp2k'
subl	%ebx, %eax
; │││└
; │││┌ @ priv.jl:52 within `ldexp2k'
; ││││┌ @ utils.jl:51 within `pow2i'
; │││││┌ @ utils.jl:20 within `integer2float'
; ││││││┌ @ int.jl:446 within `<<' @ int.jl:439
shlq	$52, %rbx
; │││││└└
; │││││┌ @ essentials.jl:417 within `integer2float'
vmovq	%rbx, %xmm1
; ││││└└
; ││││┌ @ floating_point_arithmetic.jl:62 within `evmul'
; │││││┌ @ float.jl:399 within `*'
vmulsd	%xmm1, %xmm2, %xmm2
; ││││└└
; ││││┌ @ utils.jl:51 within `pow2i'
; │││││┌ @ utils.jl:20 within `integer2float'
; ││││││┌ @ int.jl:446 within `<<' @ int.jl:439
shlq	$52, %rax
; ││││││└
; ││││││┌ @ essentials.jl:417 within `reinterpret'
vmovq	%rax, %xmm1
; ││││└└└
; ││││┌ @ floating_point_arithmetic.jl:62 within `evmul'
; │││││┌ @ float.jl:399 within `*'
vmulsd	%xmm1, %xmm2, %xmm1
; ││└└└└
; ││ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:8
; ││┌ @ array.jl:766 within `setindex!'
vmovsd	%xmm1, (%rdx,%rsi,8)
; ││└
; ││ @ simdloop.jl:77 within `macro expansion' @ float.jl:395
; ││ @ simdloop.jl:78 within `macro expansion'
; ││┌ @ int.jl:53 within `+'
; ││└
; ││ @ simdloop.jl:75 within `macro expansion'
; ││┌ @ int.jl:49 within `<'
cmpq	%r9, %rsi
; ││└
jb	L1584
; │└
; │ @ REPL[43]:11 within `logsumexp_sleefpirates!'
L1757:
movabsq	$"<;", %rax
vmovupd	%zmm19, 64(%rsp)
vmovapd	%xmm19, %xmm0
vzeroupper
callq	*%rax
; │┌ @ broadcast.jl:801 within `materialize!'
; ││┌ @ abstractarray.jl:75 within `axes'
; │││┌ @ array.jl:155 within `size'
movq	24(%r15), %rdx
; ││└└
; ││┌ @ promotion.jl:414 within `axes'
testq	%rdx, %rdx
; ││└
; │││┌ @ simdloop.jl:72 within `macro expansion'
jle	L2039
; └└└└
; ┌ @ simdloop.jl within `logsumexp_sleefpirates!'
movq	%rdx, %rax
sarq	$63, %rax
andnq	%rdx, %rax, %rax
; └
; ┌ @ REPL[43]:11 within `logsumexp_sleefpirates!'
; │┌ @ float.jl:395 within `+'
movq	(%r15), %rcx
; │└
; │┌ @ broadcast.jl:801 within `materialize!'
; │││┌ @ broadcast.jl:869 within `preprocess'
; ││││┌ @ broadcast.jl:872 within `preprocess_args'
; │││││┌ @ broadcast.jl:870 within `preprocess'
; ││││││┌ @ broadcast.jl:592 within `extrude'
; │││││││┌ @ broadcast.jl:547 within `newindexer'
; ││││││││┌ @ broadcast.jl:548 within `shapeindexer'
; │││││││││┌ @ broadcast.jl:553 within `_newindexer'
; ││││││││││┌ @ operators.jl:193 within `!='
; │││││││││││┌ @ promotion.jl:403 within `=='
cmpq	$1, %rdx
; │││└└└└└└└└└
; │││┌ @ simdloop.jl:75 within `macro expansion'
jne	L1867
xorl	%edx, %edx
nopw	%cs:(%rax,%rax)
; ││││ @ simdloop.jl:77 within `macro expansion' @ broadcast.jl:888
; ││││┌ @ broadcast.jl:558 within `getindex'
; ││││││┌ @ broadcast.jl:621 within `_getindex'
; ││││││││┌ @ array.jl:728 within `getindex'
L1840:
vmovsd	(%rcx), %xmm1           # xmm1 = mem[0],zero
; ││││││└└└
; │││││││┌ @ float.jl:397 within `-'
vsubsd	%xmm0, %xmm1, %xmm1
; ││││└└└└
; ││││┌ @ array.jl:766 within `setindex!'
vmovsd	%xmm1, (%rcx,%rdx,8)
; │││└└
; │││┌ @ int.jl:53 within `macro expansion'
; │││└
; │││┌ @ simdloop.jl:75 within `macro expansion'
; ││││┌ @ int.jl:49 within `<'
cmpq	%rax, %rdx
; ││││└
jb	L1840
jmp	L2039
L1867:
cmpq	$32, %rax
jae	L1880
xorl	%edx, %edx
jmp	L2016
; ││││ @ simdloop.jl:75 within `macro expansion'
L1880:
movabsq	$9223372036854775776, %rdx # imm = 0x7FFFFFFFFFFFFFE0
andq	%rax, %rdx
leaq	192(%rcx), %rsi
; ││││ @ simdloop.jl:78 within `macro expansion'
; ││││┌ @ int.jl:53 within `+'
movq	%rdx, %rdi
nopw	%cs:(%rax,%rax)
; ││││└
; ││││ @ simdloop.jl:77 within `macro expansion' @ broadcast.jl:888
; ││││┌ @ broadcast.jl:558 within `getindex'
; ││││││┌ @ broadcast.jl:621 within `_getindex'
; ││││││││┌ @ array.jl:728 within `getindex'
L1920:
vmovupd	-192(%rsi), %zmm2
vmovupd	-128(%rsi), %zmm3
vmovupd	-64(%rsi), %zmm4
vmovupd	(%rsi), %zmm5
; ││││││└└└
; │││││││┌ @ float.jl:397 within `-'
vsubpd	%zmm1, %zmm2, %zmm2
vsubpd	%zmm1, %zmm3, %zmm3
vsubpd	%zmm1, %zmm4, %zmm4
vsubpd	%zmm1, %zmm5, %zmm5
; ││││└└└└
; ││││ @ simdloop.jl:77 within `macro expansion' @ array.jl:766
vmovupd	%zmm2, -192(%rsi)
vmovupd	%zmm3, -128(%rsi)
vmovupd	%zmm4, -64(%rsi)
vmovupd	%zmm5, (%rsi)
; ││││ @ simdloop.jl:78 within `macro expansion'
; ││││┌ @ int.jl:53 within `+'
addq	$256, %rsi              # imm = 0x100
jne	L1920
; └└└└└
; ┌ @ int.jl within `logsumexp_sleefpirates!'
cmpq	%rdx, %rax
; └
; ┌ @ simdloop.jl:75 within `logsumexp_sleefpirates!'
je	L2039
; └
; ┌ @ REPL[43]:11 within `logsumexp_sleefpirates!'
; │┌ @ broadcast.jl:801 within `materialize!'
; │││┌ @ simdloop.jl:77 within `macro expansion' @ broadcast.jl:888
; ││││┌ @ broadcast.jl:558 within `getindex'
; ││││││┌ @ broadcast.jl:621 within `_getindex'
; ││││││││┌ @ array.jl:728 within `getindex'
L2016:
vmovsd	(%rcx,%rdx,8), %xmm1    # xmm1 = mem[0],zero
; │││││└└└└
; │││││┌ @ float.jl:397 within `_broadcast_getindex'
vsubsd	%xmm0, %xmm1, %xmm1
; ││││└└
; ││││┌ @ array.jl:766 within `setindex!'
vmovsd	%xmm1, (%rcx,%rdx,8)
; ││││└
; ││││ @ simdloop.jl:78 within `macro expansion'
; ││││┌ @ int.jl:53 within `+'
; ││││└
; ││││ @ simdloop.jl:75 within `macro expansion'
; ││││┌ @ int.jl:49 within `<'
cmpq	%rax, %rdx
jb	L2016
; │└└└└
; │ @ REPL[43]:12 within `logsumexp_sleefpirates!'
; │┌ @ broadcast.jl:801 within `materialize!'
; ││┌ @ abstractarray.jl:75 within `axes'
; │││┌ @ array.jl:155 within `size'
L2039:
movq	24(%r14), %rdx
; ││└└
; ││┌ @ promotion.jl:414 within `axes'
testq	%rdx, %rdx
; ││└
; │││┌ @ simdloop.jl:72 within `macro expansion'
jle	L2259
; └└└└
; ┌ @ simdloop.jl within `logsumexp_sleefpirates!'
movq	%rdx, %rax
sarq	$63, %rax
andnq	%rdx, %rax, %rax
movabsq	$139756740332240, %rcx  # imm = 0x7F1BA6DCCAD0
; └
; ┌ @ REPL[43]:12 within `logsumexp_sleefpirates!'
; │┌ @ promotion.jl:316 within `/' @ float.jl:401
vmovsd	(%rcx), %xmm0           # xmm0 = mem[0],zero
vdivsd	64(%rsp), %xmm0, %xmm0
movq	(%r14), %rcx
; │└
; │┌ @ broadcast.jl:801 within `materialize!'
; │││┌ @ broadcast.jl:869 within `preprocess'
; ││││┌ @ broadcast.jl:872 within `preprocess_args'
; │││││┌ @ broadcast.jl:870 within `preprocess'
; ││││││┌ @ broadcast.jl:592 within `extrude'
; │││││││┌ @ broadcast.jl:547 within `newindexer'
; ││││││││┌ @ broadcast.jl:548 within `shapeindexer'
; │││││││││┌ @ broadcast.jl:553 within `_newindexer'
; ││││││││││┌ @ operators.jl:193 within `!='
; │││││││││││┌ @ promotion.jl:403 within `=='
cmpq	$1, %rdx
; │││└└└└└└└└└
; │││┌ @ simdloop.jl:75 within `macro expansion'
jne	L2119
xorl	%edx, %edx
nop
; ││││ @ simdloop.jl:77 within `macro expansion' @ broadcast.jl:888
; ││││┌ @ broadcast.jl:558 within `getindex'
; │││││││┌ @ float.jl:399 within `*'
L2096:
vmulsd	(%rcx), %xmm0, %xmm1
; ││││└└└└
; ││││┌ @ array.jl:766 within `setindex!'
vmovsd	%xmm1, (%rcx,%rdx,8)
; │││└└
; │││┌ @ int.jl:53 within `macro expansion'
; │││└
; │││┌ @ simdloop.jl:75 within `macro expansion'
; ││││┌ @ int.jl:49 within `<'
cmpq	%rax, %rdx
; ││││└
jb	L2096
jmp	L2259
L2119:
cmpq	$32, %rax
jae	L2129
xorl	%edx, %edx
jmp	L2240
; ││││ @ simdloop.jl:75 within `macro expansion'
L2129:
movabsq	$9223372036854775776, %rdx # imm = 0x7FFFFFFFFFFFFFE0
andq	%rax, %rdx
leaq	192(%rcx), %rsi
; ││││ @ simdloop.jl:78 within `macro expansion'
; ││││┌ @ int.jl:53 within `+'
movq	%rdx, %rdi
nop
; ││││└
; ││││ @ simdloop.jl:77 within `macro expansion' @ broadcast.jl:888
; ││││┌ @ broadcast.jl:558 within `getindex'
; │││││││┌ @ float.jl:399 within `*'
L2160:
vmulpd	-192(%rsi), %zmm1, %zmm2
vmulpd	-128(%rsi), %zmm1, %zmm3
vmulpd	-64(%rsi), %zmm1, %zmm4
vmulpd	(%rsi), %zmm1, %zmm5
; ││││└└└└
; ││││ @ simdloop.jl:77 within `macro expansion' @ array.jl:766
vmovupd	%zmm2, -192(%rsi)
vmovupd	%zmm3, -128(%rsi)
vmovupd	%zmm4, -64(%rsi)
vmovupd	%zmm5, (%rsi)
; ││││ @ simdloop.jl:78 within `macro expansion'
; ││││┌ @ int.jl:53 within `+'
addq	$256, %rsi              # imm = 0x100
jne	L2160
; └└└└└
; ┌ @ int.jl within `logsumexp_sleefpirates!'
cmpq	%rdx, %rax
; └
; ┌ @ simdloop.jl:75 within `logsumexp_sleefpirates!'
je	L2259
nopl	(%rax,%rax)
; └
; ┌ @ REPL[43]:12 within `logsumexp_sleefpirates!'
; │┌ @ broadcast.jl:801 within `materialize!'
; │││┌ @ simdloop.jl:77 within `macro expansion' @ broadcast.jl:888
; ││││┌ @ broadcast.jl:558 within `getindex'
; │││││││┌ @ float.jl:399 within `*'
L2240:
vmulsd	(%rcx,%rdx,8), %xmm0, %xmm1
; ││││└└└└
; ││││ @ simdloop.jl:77 within `macro expansion' @ array.jl:766
vmovsd	%xmm1, (%rcx,%rdx,8)
; ││││ @ simdloop.jl:78 within `macro expansion'
; ││││┌ @ int.jl:53 within `+'
; ││││└
; ││││ @ simdloop.jl:75 within `macro expansion'
; ││││┌ @ int.jl:49 within `<'
cmpq	%rax, %rdx
jb	L2240
; │└└└└
L2259:
movq	%r14, %rax
popq	%rbx
popq	%r14
popq	%r15
vzeroupper
retq
; │ @ REPL[43]:5 within `logsumexp_sleefpirates!'
; │┌ @ simdloop.jl:71 within `macro expansion'
; ││┌ @ simdloop.jl:51 within `simd_inner_length'
; │││┌ @ range.jl:541 within `length'
; ││││┌ @ checked.jl:166 within `checked_add'
L2278:
movabsq	$throw_overflowerr_binaryop, %rax
movabsq	$139757138773328, %rdi  # imm = 0x7F1BBE9C8550
movl	$1, %edx
callq	*%rax
ud2
nopw	%cs:(%rax,%rax)
; └└└└└
``````

As the `zmm` registers throughout the main loop show, the SLEEFPirates `exp` is vectorized.


Thanks for this! On my machines, I still cannot beat Yeppp on `Float64`, but with SLEEFPirates and LoopVectorization I get significantly closer.

``````
julia> w = randn(Float64, N); we = similar(w);
julia> @btime logsumexp!($w,$we);
10.358 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_yeppp!($w,$we);
1.996 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_sleefpirates!($w,$we);
3.157 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_loopvec!($w,$we);
3.575 μs (0 allocations: 0 bytes)
julia> w = randn(Float32, N); we = similar(w);
julia> @btime logsumexp!($w,$we);
7.773 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_sleefpirates!($w,$we);
1.580 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_loopvec!($w,$we);
ERROR: MethodError: no method matching vload(::Type{SVec{4,Float64}}, ::VectorizationBase.vpointer{Float32})
Closest candidates are:
``````
``````
@generated function logsumexp_loopvectorization!(w::Vector{T},we) where T
    quote
        offset = maximum(w)
        N = length(w)
        s = zero(T)
        @vectorize $T for i = 1:N
            wl = w[i]
            wel = exp(wl-offset)
            we[i] = wel
            s += wel
        end
        w  .-= log(s) + offset
        we .*= 1/s
    end
end
``````

`@vectorize` accepts a type argument, but it has to be the type itself; it can’t be a symbol such as `T`. Using a generated function to splice the concrete type in works.
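
The splicing mechanism can be seen in isolation with a toy example that uses only Base. This is a minimal sketch: `zero_like_eltype` is a hypothetical name chosen for illustration and has nothing to do with LoopVectorization itself.

```julia
# The body of a @generated function runs once per concrete argument signature.
# At that point `T` is bound to the actual element type (for example Float32).
@generated function zero_like_eltype(x::Vector{T}) where T
    quote
        # `$T` is replaced by the concrete type while this expression is built,
        # so the generated code contains e.g. `zero(Float32)`, not the symbol `T`.
        zero($T)
    end
end

zero_like_eltype(Float32[1, 2, 3])   # zero of type Float32
zero_like_eltype([1.0, 2.0])         # zero of type Float64
```

A macro placed inside the `quote` sees the spliced type in the same way, which is what lets it receive `Float32` or `Float64` instead of `T`.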

``````
julia> @btime logsumexp_sleefpirates!($w,$we);
1.474 μs (0 allocations: 0 bytes)

julia> @btime logsumexp_loopvectorization!($w,$we);
1.325 μs (0 allocations: 0 bytes)
``````

Yeppp’s performance is unfortunately inconsistent across architectures. It was even slower than the totally unvectorized version on my computer.
I’m guessing it’s a precompiled binary that dispatches based on recognizing the host CPU.
Before OpenBLAS had AVX-512 support, it just used its AVX2 kernels on such CPUs. While slower, that’s much faster than just falling back to the most generic code!


I should comment that the offset algorithm that you are using for logsumexp is potentially problematic: you can easily contrive a case where it is off by a factor of 2. In particular, try `[1e-20, log(1e-20)]` with your logsumexp algorithm:

``````
function f(x)
    X = maximum(x)
    return X + log(sum(exp.(x .- X)))
end
``````

You get:

``````
julia> x = [1e-20, log(1e-20)];

julia> f(x)                  # inaccurate!
1.0e-20

julia> Float64(f(big.(x)))   # accurate
1.9999999999999993e-20
``````

A possible fix is to pull the maximum `x` term out of the sum and use the `log1p` function. I actually assigned this as an exam question recently, so you can see the explanation in my problem 3 solutions.
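
One way to read that suggestion, as a minimal sketch: write the sum as 1 plus the remaining terms, so that log(Σ exp(xᵢ)) = X + log1p(Σ over i ≠ imax of exp(xᵢ − X)), where X is the maximum and imax its index. The function name `logsumexp_log1p` is hypothetical, and this is not necessarily the exam solution referred to above.

```julia
# logsumexp with the maximum term pulled out of the sum.
# Because the leading 1 is never added explicitly, tiny remaining terms
# are not rounded away before the logarithm is taken.
function logsumexp_log1p(x::AbstractVector{<:Real})
    X, imax = findmax(x)
    isfinite(X) || return float(X)   # all -Inf, or an Inf/NaN entry
    s = zero(float(X))
    for i in eachindex(x)
        i == imax && continue        # skip the maximum term (it is the "1")
        s += exp(x[i] - X)
    end
    return X + log1p(s)
end

x = [1e-20, log(1e-20)]
logsumexp_log1p(x)   # close to 2e-20 instead of 1e-20
```

Skipping the maximum by index rather than by value means that ties are handled correctly: only one copy of the maximum is treated as the leading 1.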

Not sure if this matters for the machine-learning application, however, since there you are adding the logsumexp to a posterior probability and so errors in tiny values like this may get rounded away in your final result.


Thanks for the link! I had not considered it so carefully before, but I will surely make use of this trick in the future.

For future reference, here is my implementation


It looks like you’re mostly using `logsumexp!` for sampling from a categorical, is that right? Have you tried using the Gumbel max trick for this? It’s pretty great for sampling a categorical with unnormalized log-probability weights.
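
For readers who have not seen it, the Gumbel max trick can be sketched as follows: add independent Gumbel(0, 1) noise to each unnormalized log-weight and take the argmax. This is a minimal illustration using only the `Random` standard library; `gumbel_max_sample` is a hypothetical name, not a function from any package mentioned in this thread.

```julia
using Random

# Draw index i with probability proportional to exp(logw[i]),
# without ever normalizing logw.
function gumbel_max_sample(rng::AbstractRNG, logw::AbstractVector{<:Real})
    best_i, best_v = firstindex(logw), -Inf
    for i in eachindex(logw)
        g = -log(-log(rand(rng)))   # Gumbel(0, 1) noise via the inverse CDF
        v = logw[i] + g
        if v > best_v
            best_i, best_v = i, v
        end
    end
    return best_i
end

rng  = MersenneTwister(0)
logw = log.([0.1, 0.2, 0.7]) .+ 5.0   # unnormalized log-weights
gumbel_max_sample(rng, logw)          # an index in 1:3
```

Each call consumes one fresh uniform draw per weight, so no values are reused between samples.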


That trick does not return the log-likelihood like the calculation above, right? Also, they mention that all the `log(log(U))` values can be precalculated; doesn’t this cause some bias?

How could one link LLVM with SVML as you suggest above?

“Log-likelihood” doesn’t make sense unless there’s context of some distribution. Do you need the log-likelihood of a Dirichlet? Then no, I don’t think the Gumbel trick would help.

I’ve seen logsumexp in a few different contexts, and I’m not sure which one you’re working in. It had seemed you were assigning a log-likelihood to each particle, and were then concerned about normalizing these in order to draw a categorical. In that case my point was that the Gumbel trick will let you skip the normalization. [“returning the log-likelihood” would be trivial in this case, so I’m guessing this isn’t what you meant]

If you re-use values, yes. But I don’t think anyone’s suggesting that. As I see it, the benefit is to do an equivalent calculation with fewer dependencies. This means you could compute the Gumbels in parallel, or have a thread dedicated to keeping a supply available, or lots of other tricks.

I’ve also wondered sometimes about things like `log(log(U))`. We have fast algorithms for lots of things. If `doublelog` is critical for performance, shouldn’t we be computing it more directly, maybe a series expansion focused on U, rather than calling `log` twice?

Finally, I think you’re doing systematic sampling, right? This seems convenient for this approach, since quantiles for a uniform are dead-simple. It could also work well to pull in something like Sobol.jl if QMC is a possibility.
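
For reference, systematic resampling from a single uniform draw can be sketched like this. It assumes `we` already holds normalized weights (as produced by `logsumexp!` above); `systematic_resample` is a hypothetical helper written for this illustration, not code from the course.

```julia
# Systematic resampling: N evenly spaced points (i - 1 + u) / N with one
# shared offset u in [0, 1) are matched against the cumulative weights.
function systematic_resample(we::AbstractVector{<:Real}, u::Real)
    N   = length(we)
    idx = Vector{Int}(undef, N)
    j   = 1
    c   = we[1]                      # running cumulative weight
    for i in 1:N
        p = (i - 1 + u) / N
        while c < p && j < N         # j < N guards against round-off in the sum
            j += 1
            c += we[j]
        end
        idx[i] = j
    end
    return idx
end

systematic_resample([0.1, 0.2, 0.7], 0.5)   # [2, 3, 3]
```

In practice `u` would be a fresh `rand()` per resampling step; it is passed in explicitly here so that the example is deterministic.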

Thanks for this thread. I typically use a lot of `logsumexp` in my code, but never took the time to optimize it. I will try your variants. I expect to see some performance gains!

This is exactly what I do, but I also use the value `sum(exp.(w))` to calculate the likelihood of the data given the model parameters, so it must be calculated anyway. If I were not interested in this number, the trick might very well have sped up the calculations.
