# Fast logsumexp

We recently held a competition among the students in a Julia course we are teaching. The goal was to profile and optimize a provided particle-filter simulation. After all the obvious things had been taken care of, most of the time was spent computing `log(sum(exp(w)))` of a weight vector `w`. I wrote four versions of this function; they all operate in place and update both the weight vector `w` and the exponentiated weight vector `we`. The `offset` (the maximum of `w`) is subtracted before exponentiating to avoid numerical overflow/underflow. No matter how hard I tried, I couldn't beat the library Yeppp. Not even my handwritten SIMD loops (using either SIMD.jl or SIMDPirates.jl) operating on `Float32` would beat Yeppp on `Float64`.
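A quick illustration of why the offset matters (a sketch, not code from the course; `naive_lse` and `shifted_lse` are hypothetical names). Subtracting the maximum makes every argument to `exp` non-positive, so the largest term is exactly 1 and nothing overflows, and the sum can never underflow to zero.

```julia
# Naive log-sum-exp: exp overflows (or underflows) for large |w|
naive_lse(w) = log(sum(exp.(w)))

# Shifted log-sum-exp: subtract the maximum first, add it back afterwards
function shifted_lse(w)
    m = maximum(w)
    return m + log(sum(exp.(w .- m)))  # all exp arguments are <= 0
end

wbig   = [1000.0, 1000.0]
wsmall = [-1000.0, -1000.0]
naive_lse(wbig)     # Inf, because exp(1000.0) overflows Float64
shifted_lse(wbig)   # 1000 + log(2)
naive_lse(wsmall)   # -Inf, because exp(-1000.0) underflows to 0.0
shifted_lse(wsmall) # -1000 + log(2)
```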

My four implementations are below. No threading is used, and I have not included an implementation using SLEEF, since it could not beat Yeppp either. Can anyone explain how Yeppp manages to be so fast in this example?

```julia
using Test
N = 1000;

function logsumexp!(w,we)
    offset = maximum(w)
    we .= exp.(w .- offset)
    s = sum(we)
    w .-= log(s) + offset
    we .*= 1/s
end

w = randn(Float64, N); we = similar(w);
logsumexp!(w,we);
@test sum(we) ≈ 1
@test sum(exp.(w)) ≈ 1

# ==================================================
using Yeppp
function logsumexp_yeppp!(w,we)
    offset = maximum(w)
    w    .-= offset
    Yeppp.exp!(we,w)
    s      = sum(we)
    w    .-= log(s) # offset not needed since we subtracted it above
    we   .*= 1/s
end

w = randn(Float64, N); we = similar(w);
logsumexp_yeppp!(w,we);
@test sum(we) ≈ 1
@test sum(exp.(w)) ≈ 1

# ==================================================
using SIMD
function logsumexp_simd!(w::Vector{T},we) where T
    offset = maximum(w)
    N      = length(w)
    s      = zero(T)
    # Processes 4 elements per iteration; assumes length(w) is a multiple of 4
    @inbounds for i = 1:4:N
        l     = VecRange{4}(i)
        wel   = exp(w[l]-offset)
        we[l] = wel
        s    += sum(wel)
    end
    w .-= log(s) + offset
    we .*= 1/s
end

w = randn(Float64, N); we = similar(w);
logsumexp_simd!(w,we);
@test sum(we) ≈ 1
@test sum(exp.(w)) ≈ 1

# ==================================================
# ] add https://github.com/chriselrod/VectorizationBase.jl
# ] add https://github.com/chriselrod/SIMDPirates.jl
using SIMDPirates
function logsumexp_simdpirates!(w::Vector{T},we) where T
    offset = maximum(w)
    N      = length(w)
    sl     = SIMDPirates.Vec{4,T}((0.,0.,0.,0.))
    # Processes 4 elements per iteration; assumes length(w) is a multiple of 4
    @inbounds @simd for i = 1:4:N
        wl = SIMDPirates.vload(SIMDPirates.Vec{4,T}, w, i)
        @pirate wel = exp(wl-offset)
        @pirate sl += wel
        SIMDPirates.vstore!(we,wel,i)
    end
    s    = vsum(sl)
    w  .-= log(s) + offset
    we .*= 1/s
end

w = randn(Float64, N); we = similar(w);
logsumexp_simdpirates!(w,we);
@test sum(we) ≈ 1
@test sum(exp.(w)) ≈ 1

# ==================================================
using BenchmarkTools
w = randn(Float64, N); we = similar(w);
@btime logsumexp!($w,$we);
@btime logsumexp_yeppp!($w,$we);
@btime logsumexp_simd!($w,$we);
@btime logsumexp_simdpirates!($w,$we);

w = randn(Float32, N); we = similar(w);
@btime logsumexp!($w,$we);
# @btime logsumexp_yeppp!($w,$we); # Yeppp does not handle Float32
@btime logsumexp_simd!($w,$we);
@btime logsumexp_simdpirates!($w,$we);

``````

Benchmark results

```
julia> w = randn(Float64, N); we = similar(w);
julia> @btime logsumexp!($w,$we);
15.947 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_yeppp!($w,$we);
4.580 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_simd!($w,$we);
21.399 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_simdpirates!($w,$we);
21.384 μs (0 allocations: 0 bytes)
julia> w = randn(Float32, N); we = similar(w);
julia> @btime logsumexp!($w,$we);
12.603 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_simd!($w,$we);
8.965 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_simdpirates!($w,$we);
8.779 μs (0 allocations: 0 bytes)
```

From what I can see, even though SIMD.jl emits the LLVM vector intrinsic `@llvm.exp.v4f64`, LLVM in the end still lowers it to four separate scalar calls:

```
julia> @code_native exp(a)
.text
...
vzeroupper
callq   *%rdi  <---------------
vmovapd %xmm0, %xmm7
vpermilpd       $1, %xmm6, %xmm0 # xmm0 = xmm6[1,0]
callq   *%rdi  <---------------
vunpcklpd       %xmm0, %xmm7, %xmm7 # xmm7 = xmm7[0],xmm0[0]
vmovaps 32(%rsp), %ymm0
vzeroupper
callq   *%rdi  <---------------
vmovapd %xmm0, %xmm6
vpermilpd       $1, 32(%rsp), %xmm0 # xmm0 = mem[1,0]
callq   *%rdi    <---------------

...
``````

whereas Yeppp has optimized routines that compute `exp` on a whole SIMD vector at once.
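One way to check this on your own machine is to count the call instructions in the native code. The helper below is a hypothetical sketch (it is not part of SIMD.jl or Base) built on `InteractiveUtils.code_native`; the exact assembly depends on the CPU and on the Julia, LLVM and package versions, and the regex only recognizes x86-64 style `call`/`callq` mnemonics.

```julia
using InteractiveUtils  # provides code_native

# Hypothetical helper: count the lines of native code for `f` with argument
# types `types` that contain a `call` instruction (e.g. `callq` on x86-64).
function count_calls(f, types::Type{<:Tuple})
    asm = sprint(io -> code_native(io, f, types; debuginfo=:none))
    return count(line -> occursin(r"\bcall", line), split(asm, '\n'))
end

# A plain floating-point add should not need any calls.
count_calls(x -> x + 1.0, Tuple{Float64})
```

With SIMD.jl loaded, `count_calls(exp, Tuple{Vec{4,Float64}})` would be expected to report roughly one call per lane on a machine that behaves like the one above, and zero for an `exp` that is fully inlined.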

Unrelated, but cool to see a Julia course in Sweden!


Haha, so much for my manual SIMD efforts! Thanks for the explanation; I would not have figured that out. Do you have a feeling for where the issue lies? Is SIMD.jl, Julia or LLVM to blame? @Elrod, thanks for your efforts on SIMDPirates; do you perhaps have any insight into this?

As for the course, we first gave it back in 2015, in the v0.3/0.4 era, but many new PhD students who need to transition away from MATLAB have started since. Looking back at the course material, it's nice to see how Julia has developed in terms of both performance and design.

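LLVM. GCC will vectorize `log`/`exp`/`sin`/etc. with the appropriate optimization flags, but LLVM needs those flags plus `-fveclib=SVML` or some other vector library. The versions below replace Base `exp` with the pure-Julia `SLEEFPirates.exp`, which LLVM can inline and vectorize: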

```julia
using SIMDPirates, SLEEFPirates, LoopVectorization
function logsumexp_simdpirates!(w::Vector{T},we) where T
    offset = maximum(w)
    N      = length(w)
    sl     = SIMDPirates.Vec{4,T}((0.,0.,0.,0.))
    # Processes 4 elements per iteration; assumes length(w) is a multiple of 4
    @inbounds @simd for i = 1:4:N
        wl = SIMDPirates.vload(SIMDPirates.Vec{4,T}, w, i)
        @pirate wel = SLEEFPirates.exp(wl-offset)
        @pirate sl += wel
        SIMDPirates.vstore!(we,wel,i)
    end
    s    = vsum(sl)
    w  .-= log(s) + offset
    we .*= 1/s
end

function logsumexp_loopvectorization!(w::Vector{T},we) where T
    offset = maximum(w)
    N      = length(w)
    s = zero(T)
    @vectorize for i = 1:N
        wl = w[i]
        wel = exp(wl-offset)
        we[i] = wel
        s += wel
    end
    w  .-= log(s) + offset
    we .*= 1/s
end

function logsumexp_sleefpirates!(w::Vector{T},we) where T
    offset = maximum(w)
    N      = length(w)
    s = zero(T)
    @inbounds @simd for i = 1:N
        wl = w[i]
        wel = SLEEFPirates.exp(wl-offset)
        we[i] = wel
        s += wel
    end
    w  .-= log(s) + offset
    we .*= 1/s
end
``````
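To illustrate why a pure-Julia `exp` helps, here is a minimal branch-free `exp` sketch in the same spirit as the inlined SLEEFPirates code visible in the `@code_native` output of `logsumexp_sleefpirates!` (range reduction by `log(2)`, a polynomial, then scaling by `2^k` through the exponent bits). It is not SLEEFPirates' actual implementation: `exp_poly` and `logsumexp_poly!` are hypothetical names, the polynomial is a plain degree-13 Taylor series rather than a minimax fit, the `log(2)` hi/lo split constants are the ones used in fdlibm, inputs below -708 are clamped, and there is no handling of `NaN`, `Inf` or `x > 709`. It needs Julia 1.4 or later for `evalpoly`. Because the loop body contains no function calls after inlining, `@simd` should be able to vectorize it; check with `@code_native` on your machine.

```julia
# Hypothetical branch-free exp for Float64, meant for -708 <= x <= 709.
const LOG2E  = 1.4426950408889634         # 1/log(2)
const LN2_HI = 6.93147180369123816490e-01 # high part of log(2) (fdlibm split)
const LN2_LO = 1.90821492927058770002e-10 # low part of log(2)
const EXP_COEFFS = ntuple(n -> 1.0 / factorial(n - 1), 14)  # 1/n!, n = 0..13

@inline function exp_poly(x::Float64)
    x = max(x, -708.0)             # clamp so that 2^k stays a normal number
    k = round(x * LOG2E)           # nearest integer to x/log(2)
    r = muladd(-k, LN2_HI, x)      # r = x - k*log(2), done in two steps
    r = muladd(-k, LN2_LO, r)      # so that |r| is at most about log(2)/2
    p = evalpoly(r, EXP_COEFFS)    # exp(r) by Horner's rule
    ki = unsafe_trunc(Int64, k)
    two_k = reinterpret(Float64, (ki + 1023) << 52)  # 2^k from exponent bits
    return p * two_k
end

function logsumexp_poly!(w::Vector{Float64}, we::Vector{Float64})
    offset = maximum(w)
    s = 0.0
    @inbounds @simd for i in eachindex(w, we)
        e = exp_poly(w[i] - offset)  # argument is <= 0, so no overflow
        we[i] = e
        s += e
    end
    w .-= log(s) + offset
    we .*= 1 / s
    return w
end
```

The clamp means very negative arguments return about `exp(-708)` instead of 0, which is harmless here since the largest term in the sum is exactly 1.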

A few results:

```
julia> @btime logsumexp!($w,$we);
7.844 μs (0 allocations: 0 bytes)

julia> @btime logsumexp_yeppp!($w,$we);
9.393 μs (0 allocations: 0 bytes)

julia> @btime logsumexp_simdpirates!($w,$we);
2.576 μs (0 allocations: 0 bytes)

julia> @btime logsumexp_loopvectorization!($w,$we);
1.825 μs (0 allocations: 0 bytes)

julia> @btime logsumexp_sleefpirates!($w,$we);
1.705 μs (0 allocations: 0 bytes)

julia> @code_native logsumexp_sleefpirates!(w,we)
.text
; β @ REPL[43]:2 within `logsumexp_sleefpirates!'
pushq	%r15
pushq	%r14
pushq	%rbx
subq	$224, %rsp
movq	%rsi, 56(%rsp)
movq	(%rsi), %r15
movq	8(%rsi), %r14
; ββ @ reducedim.jl:652 within `maximum'
; βββ @ reducedim.jl:652 within `#maximum#562'
; ββββ @ reducedim.jl:656 within `_maximum' @ reducedim.jl:657
; βββββ @ reducedim.jl:307 within `mapreduce'
; ββββββ @ reducedim.jl:307 within `#mapreduce#555'
; βββββββ @ reducedim.jl:312 within `_mapreduce_dim'
movabsq	$"size;", %rax
movq	%r15, %rdi
callq	*%rax
; βββββββ
; β @ REPL[43]:3 within `logsumexp_sleefpirates!'
; ββ @ array.jl:200 within `length'
movq	8(%r15), %rax
; ββ
; β @ REPL[43]:5 within `logsumexp_sleefpirates!'
; ββ @ simdloop.jl:69 within `macro expansion'
; βββ @ range.jl:5 within `Colon'
; ββββ @ range.jl:275 within `Type'
; βββββ @ range.jl:280 within `unitrange_last'
movq	%rax, %rcx
sarq	$63, %rcx
andnq	%rax, %rcx, %r8
; βββββ
; ββ @ checked.jl:194 within `macro expansion'
leaq	-1(%r8), %rsi
; ββ
; ββ @ simdloop.jl:71 within `macro expansion'
; βββ @ simdloop.jl:51 within `simd_inner_length'
; ββββ @ range.jl:541 within `length'
; βββββ @ checked.jl:165 within `checked_add'
; ββββββ @ checked.jl:132 within `add_with_overflow'
movq	%rsi, %r9
incq	%r9
; ββββββ
; βββββ @ checked.jl:166 within `checked_add'
jo	L2278
; βββββ
; ββ @ int.jl:49 within `macro expansion'
testq	%r9, %r9
vmovapd	%xmm0, 32(%rsp)
; ββ
; β @ simdloop.jl:72 within `logsumexp_sleefpirates!'
jle	L107
movq	(%r15), %rcx
movq	(%r14), %rdx
vxorpd	%xmm19, %xmm19, %xmm19
; β @ simdloop.jl:75 within `logsumexp_sleefpirates!'
cmpq	$32, %r8
jae	L118
xorl	%esi, %esi
jmp	L1345
L107:
vxorpd	%xmm19, %xmm19, %xmm19
jmp	L1757
; β @ simdloop.jl:75 within `logsumexp_sleefpirates!'
L118:
leaq	(%rcx,%r8,8), %rsi
cmpq	%rsi, %rdx
jae	L143
leaq	(%rdx,%r8,8), %rsi
; β @ simdloop.jl:75 within `logsumexp_sleefpirates!'
cmpq	%rsi, %rcx
jae	L143
xorl	%esi, %esi
jmp	L1345
; β @ simdloop.jl:75 within `logsumexp_sleefpirates!'
L143:
movabsq	$9223372036854775776, %rsi # imm = 0x7FFFFFFFFFFFFFE0
andq	%r8, %rsi
vmovupd	%zmm0, 64(%rsp)
vxorpd	%xmm0, %xmm0, %xmm0
xorl	%edi, %edi
movabsq	$139756740332128, %rbx  # imm = 0x7F1BA6DCCA60
vmovups	%zmm1, 128(%rsp)
movabsq	$139756740332136, %rbx  # imm = 0x7F1BA6DCCA68
movabsq	$139756740332144, %rbx  # imm = 0x7F1BA6DCCA70
movabsq	$139756740332152, %rbx  # imm = 0x7F1BA6DCCA78
movabsq	$139756740332160, %rbx  # imm = 0x7F1BA6DCCA80
movabsq	$139756740332168, %rbx  # imm = 0x7F1BA6DCCA88
movabsq	$139756740332176, %rbx  # imm = 0x7F1BA6DCCA90
movabsq	$139756740332184, %rbx  # imm = 0x7F1BA6DCCA98
movabsq	$139756740332192, %rbx  # imm = 0x7F1BA6DCCAA0
movabsq	$139756740332200, %rbx  # imm = 0x7F1BA6DCCAA8
movabsq	$139756740332208, %rbx  # imm = 0x7F1BA6DCCAB0
movabsq	$139756740332216, %rbx  # imm = 0x7F1BA6DCCAB8
movabsq	$139756740332224, %rbx  # imm = 0x7F1BA6DCCAC0
movabsq	$139756740332232, %rbx  # imm = 0x7F1BA6DCCAC8
movabsq	$139756740332240, %rbx  # imm = 0x7F1BA6DCCAD0
vxorpd	%xmm18, %xmm18, %xmm18
vxorpd	%xmm19, %xmm19, %xmm19
vxorpd	%xmm20, %xmm20, %xmm20
; β
; β @ REPL[43]:5 within `logsumexp_sleefpirates!'
; ββ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:6
; βββ @ array.jl:728 within `getindex'
L448:
vmovupd	(%rcx,%rdi,8), %zmm5
vmovupd	64(%rcx,%rdi,8), %zmm21
vmovupd	128(%rcx,%rdi,8), %zmm22
vmovupd	192(%rcx,%rdi,8), %zmm23
vmovupd	64(%rsp), %zmm2
; βββ
; ββ @ simdloop.jl:77 within `macro expansion' @ float.jl:397
vsubpd	%zmm2, %zmm5, %zmm5
vsubpd	%zmm2, %zmm21, %zmm29
vsubpd	%zmm2, %zmm22, %zmm30
vsubpd	%zmm2, %zmm23, %zmm31
vmovupd	128(%rsp), %zmm2
; ββ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:7
; βββ @ exp.jl:181 within `exp'
; ββββ @ float.jl:399 within `*'
vmulpd	%zmm2, %zmm5, %zmm21
vmulpd	%zmm2, %zmm29, %zmm22
vmulpd	%zmm2, %zmm30, %zmm23
vmulpd	%zmm2, %zmm31, %zmm24
; ββββ
; βββ @ float.jl:370 within `exp'
vrndscalepd	$4, %zmm21, %zmm25
vrndscalepd	$4, %zmm22, %zmm26
vrndscalepd	$4, %zmm23, %zmm27
vrndscalepd	$4, %zmm24, %zmm28
; βββ
; βββ @ exp.jl:183 within `exp'
; ββββ @ float.jl:304 within `unsafe_trunc'
vcvttpd2qq	%zmm25, %zmm21
vcvttpd2qq	%zmm26, %zmm22
vcvttpd2qq	%zmm27, %zmm23
vcvttpd2qq	%zmm28, %zmm24
; ββββ
; βββ @ float.jl:404 within `exp'
vfnmadd231pd	%zmm3, %zmm25, %zmm5
vfnmadd231pd	%zmm3, %zmm26, %zmm29
vfnmadd231pd	%zmm3, %zmm27, %zmm30
vfnmadd231pd	%zmm3, %zmm28, %zmm31
; βββ
; βββ @ exp.jl:186 within `exp'
; ββββ @ float.jl:404 within `muladd'
vfnmadd213pd	%zmm5, %zmm4, %zmm25
vfnmadd213pd	%zmm29, %zmm4, %zmm26
vfnmadd213pd	%zmm30, %zmm4, %zmm27
vfnmadd213pd	%zmm31, %zmm4, %zmm28
; ββββ
; βββ @ exp.jl:188 within `exp'
; ββββ @ exp.jl:161 within `exp_kernel'
; βββββ @ math.jl:101 within `macro expansion'
; ββββββ @ float.jl:404 within `muladd'
vmovapd	%zmm1, %zmm29
vfmadd213pd	%zmm6, %zmm25, %zmm29
vmovapd	%zmm1, %zmm30
vfmadd213pd	%zmm6, %zmm26, %zmm30
vmovapd	%zmm1, %zmm31
vfmadd213pd	%zmm6, %zmm27, %zmm31
vmovapd	%zmm1, %zmm5
vfmadd213pd	%zmm6, %zmm28, %zmm5
vfmadd213pd	%zmm7, %zmm25, %zmm29
vfmadd213pd	%zmm7, %zmm26, %zmm30
vfmadd213pd	%zmm7, %zmm27, %zmm31
vfmadd213pd	%zmm7, %zmm28, %zmm5
vfmadd213pd	%zmm8, %zmm25, %zmm29
vfmadd213pd	%zmm8, %zmm26, %zmm30
vfmadd213pd	%zmm8, %zmm27, %zmm31
vfmadd213pd	%zmm8, %zmm28, %zmm5
vfmadd213pd	%zmm9, %zmm25, %zmm29
vfmadd213pd	%zmm9, %zmm26, %zmm30
vfmadd213pd	%zmm9, %zmm27, %zmm31
vfmadd213pd	%zmm9, %zmm28, %zmm5
vfmadd213pd	%zmm10, %zmm25, %zmm29
vfmadd213pd	%zmm10, %zmm26, %zmm30
vfmadd213pd	%zmm10, %zmm27, %zmm31
vfmadd213pd	%zmm10, %zmm28, %zmm5
vfmadd213pd	%zmm11, %zmm25, %zmm29
vfmadd213pd	%zmm11, %zmm26, %zmm30
vfmadd213pd	%zmm11, %zmm27, %zmm31
vfmadd213pd	%zmm11, %zmm28, %zmm5
vfmadd213pd	%zmm12, %zmm25, %zmm29
vfmadd213pd	%zmm12, %zmm26, %zmm30
vfmadd213pd	%zmm12, %zmm27, %zmm31
vfmadd213pd	%zmm12, %zmm28, %zmm5
vfmadd213pd	%zmm13, %zmm25, %zmm29
vfmadd213pd	%zmm13, %zmm26, %zmm30
vfmadd213pd	%zmm13, %zmm27, %zmm31
vfmadd213pd	%zmm13, %zmm28, %zmm5
vfmadd213pd	%zmm14, %zmm25, %zmm29
vfmadd213pd	%zmm14, %zmm26, %zmm30
vfmadd213pd	%zmm14, %zmm27, %zmm31
vfmadd213pd	%zmm14, %zmm28, %zmm5
vfmadd213pd	%zmm15, %zmm25, %zmm29
vfmadd213pd	%zmm15, %zmm26, %zmm30
vfmadd213pd	%zmm15, %zmm27, %zmm31
vfmadd213pd	%zmm15, %zmm28, %zmm5
; ββββββ
; βββ @ float.jl:399 within `exp'
vmulpd	%zmm25, %zmm25, %zmm2
vmulpd	%zmm29, %zmm2, %zmm2
vmulpd	%zmm26, %zmm26, %zmm29
vmulpd	%zmm30, %zmm29, %zmm29
vmulpd	%zmm27, %zmm27, %zmm30
vmulpd	%zmm31, %zmm30, %zmm30
vmulpd	%zmm28, %zmm28, %zmm31
vmulpd	%zmm5, %zmm31, %zmm5
; βββ
; βββ @ exp.jl:189 within `exp'
; ββββ @ operators.jl:529 within `+' @ float.jl:395
vaddpd	%zmm2, %zmm25, %zmm2
vaddpd	%zmm29, %zmm26, %zmm25
vaddpd	%zmm30, %zmm27, %zmm26
vaddpd	%zmm5, %zmm28, %zmm5
; ββββ
; ββββ @ float.jl:395 within `+'
vaddpd	%zmm16, %zmm2, %zmm2
vaddpd	%zmm16, %zmm25, %zmm25
vaddpd	%zmm16, %zmm26, %zmm26
vaddpd	%zmm16, %zmm5, %zmm5
; ββββ
; βββ @ exp.jl:190 within `exp'
; ββββ @ priv.jl:51 within `ldexp2k'
; βββββ @ int.jl:444 within `>>' @ int.jl:437
vpsraq	$1, %zmm21, %zmm27
vpsraq	$1, %zmm22, %zmm28
vpsraq	$1, %zmm23, %zmm29
vpsraq	$1, %zmm24, %zmm30
; βββββ
; ββββ @ priv.jl:52 within `ldexp2k'
; βββββ @ utils.jl:51 within `pow2i'
; ββββββ @ utils.jl:20 within `integer2float'
; βββββββ @ int.jl:446 within `<<' @ int.jl:439
vpsllq	$52, %zmm27, %zmm31
vpaddq	%zmm17, %zmm31, %zmm31
; βββββββ
; ββββ @ float.jl:399 within `ldexp2k'
vmulpd	%zmm31, %zmm2, %zmm2
; ββββ
; ββββ @ priv.jl:52 within `ldexp2k'
; βββββ @ utils.jl:51 within `pow2i'
; ββββββ @ utils.jl:20 within `integer2float'
; βββββββ @ int.jl:446 within `<<' @ int.jl:439
vpsllq	$52, %zmm28, %zmm31
vpaddq	%zmm17, %zmm31, %zmm31
; βββββββ
; βββββ @ floating_point_arithmetic.jl:62 within `evmul'
; ββββββ @ float.jl:399 within `*'
vmulpd	%zmm31, %zmm25, %zmm25
; ββββββ
; βββββ @ utils.jl:51 within `pow2i'
; ββββββ @ utils.jl:20 within `integer2float'
; βββββββ @ int.jl:446 within `<<' @ int.jl:439
vpsllq	$52, %zmm29, %zmm31
vpaddq	%zmm17, %zmm31, %zmm31
; βββββββ
; βββββ @ floating_point_arithmetic.jl:62 within `evmul'
; ββββββ @ float.jl:399 within `*'
vmulpd	%zmm31, %zmm26, %zmm26
; ββββββ
; βββββ @ utils.jl:51 within `pow2i'
; ββββββ @ utils.jl:20 within `integer2float'
; βββββββ @ int.jl:446 within `<<' @ int.jl:439
vpsllq	$52, %zmm30, %zmm31
vpaddq	%zmm17, %zmm31, %zmm31
; βββββββ
; βββββ @ floating_point_arithmetic.jl:62 within `evmul'
; ββββββ @ float.jl:399 within `*'
vmulpd	%zmm31, %zmm5, %zmm5
; ββββββ
; βββββ @ int.jl:52 within `-'
vpsubq	%zmm27, %zmm21, %zmm21
vpsubq	%zmm28, %zmm22, %zmm22
vpsubq	%zmm29, %zmm23, %zmm23
vpsubq	%zmm30, %zmm24, %zmm24
; βββββ
; ββββ @ int.jl:439 within `ldexp2k'
vpsllq	$52, %zmm21, %zmm21
vpaddq	%zmm17, %zmm21, %zmm21
; ββββ
; ββββ @ priv.jl:52 within `ldexp2k'
; βββββ @ floating_point_arithmetic.jl:62 within `evmul'
; ββββββ @ float.jl:399 within `*'
vmulpd	%zmm21, %zmm2, %zmm2
; ββββββ
; βββββ @ utils.jl:51 within `pow2i'
; ββββββ @ utils.jl:20 within `integer2float'
; βββββββ @ int.jl:446 within `<<' @ int.jl:439
vpsllq	$52, %zmm22, %zmm21
vpaddq	%zmm17, %zmm21, %zmm21
; βββββββ
; βββββ @ floating_point_arithmetic.jl:62 within `evmul'
; ββββββ @ float.jl:399 within `*'
vmulpd	%zmm21, %zmm25, %zmm21
; ββββββ
; βββββ @ utils.jl:51 within `pow2i'
; ββββββ @ utils.jl:20 within `integer2float'
; βββββββ @ int.jl:446 within `<<' @ int.jl:439
vpsllq	$52, %zmm23, %zmm22
vpaddq	%zmm17, %zmm22, %zmm22
; βββββββ
; βββββ @ floating_point_arithmetic.jl:62 within `evmul'
; ββββββ @ float.jl:399 within `*'
vmulpd	%zmm22, %zmm26, %zmm22
; ββββββ
; βββββ @ utils.jl:51 within `pow2i'
; ββββββ @ utils.jl:20 within `integer2float'
; βββββββ @ int.jl:446 within `<<' @ int.jl:439
vpsllq	$52, %zmm24, %zmm23
vpaddq	%zmm17, %zmm23, %zmm23
; βββββββ
; βββββ @ floating_point_arithmetic.jl:62 within `evmul'
; ββββββ @ float.jl:399 within `*'
vmulpd	%zmm23, %zmm5, %zmm5
; ββββββ
; ββ @ simdloop.jl:77 within `macro expansion' @ array.jl:766
vmovupd	%zmm2, (%rdx,%rdi,8)
; ββ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:9
; βββ @ float.jl:395 within `+'
vaddpd	%zmm2, %zmm0, %zmm0
; βββ
; ββ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:8
; βββ @ array.jl:766 within `setindex!'
vmovupd	%zmm21, 64(%rdx,%rdi,8)
; βββ
; ββ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:9
; βββ @ float.jl:395 within `+'
vaddpd	%zmm21, %zmm18, %zmm18
; βββ
; ββ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:8
; βββ @ array.jl:766 within `setindex!'
vmovupd	%zmm22, 128(%rdx,%rdi,8)
; βββ
; ββ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:9
; βββ @ float.jl:395 within `+'
vaddpd	%zmm22, %zmm19, %zmm19
; βββ
; ββ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:8
; βββ @ array.jl:766 within `setindex!'
vmovupd	%zmm5, 192(%rdx,%rdi,8)
; βββ
; ββ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:9
; βββ @ float.jl:395 within `+'
vaddpd	%zmm5, %zmm20, %zmm20
; βββ
; ββ @ int.jl:53 within `macro expansion'
cmpq	%rdi, %rsi
jne	L448
; ββ
; ββ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:9
; βββ @ float.jl:395 within `+'
vaddpd	%zmm0, %zmm18, %zmm0
vaddpd	%zmm0, %zmm19, %zmm0
vaddpd	%zmm0, %zmm20, %zmm0
vextractf64x4	$1, %zmm0, %ymm1
vaddpd	%zmm1, %zmm0, %zmm0
vextractf128	$1, %ymm0, %xmm1
vaddpd	%zmm1, %zmm0, %zmm0
vpermilpd	$1, %xmm0, %xmm1 # xmm1 = xmm0[1,0]
vaddpd	%zmm1, %zmm0, %zmm19
cmpq	%rsi, %r8
vmovapd	32(%rsp), %xmm0
; βββ
; β @ simdloop.jl:75 within `logsumexp_sleefpirates!'
je	L1757
L1345:
movabsq	$4607182418800017408, %rdi # imm = 0x3FF0000000000000
movabsq	$139756740332128, %rbx  # imm = 0x7F1BA6DCCA60
vmovsd	(%rbx), %xmm8           # xmm8 = mem[0],zero
movabsq	$139756740332248, %rbx  # imm = 0x7F1BA6DCCAD8
vmovsd	(%rbx), %xmm9           # xmm9 = mem[0],zero
movabsq	$139756740332256, %rbx  # imm = 0x7F1BA6DCCAE0
vmovsd	(%rbx), %xmm10          # xmm10 = mem[0],zero
movabsq	$139756740332152, %rbx  # imm = 0x7F1BA6DCCA78
vmovsd	(%rbx), %xmm11          # xmm11 = mem[0],zero
movabsq	$139756740332160, %rbx  # imm = 0x7F1BA6DCCA80
vmovsd	(%rbx), %xmm12          # xmm12 = mem[0],zero
movabsq	$139756740332168, %rbx  # imm = 0x7F1BA6DCCA88
vmovsd	(%rbx), %xmm13          # xmm13 = mem[0],zero
movabsq	$139756740332176, %rbx  # imm = 0x7F1BA6DCCA90
vmovsd	(%rbx), %xmm14          # xmm14 = mem[0],zero
movabsq	$139756740332184, %rbx  # imm = 0x7F1BA6DCCA98
vmovsd	(%rbx), %xmm15          # xmm15 = mem[0],zero
movabsq	$139756740332192, %rbx  # imm = 0x7F1BA6DCCAA0
vmovsd	(%rbx), %xmm16          # xmm16 = mem[0],zero
movabsq	$139756740332200, %rbx  # imm = 0x7F1BA6DCCAA8
vmovsd	(%rbx), %xmm17          # xmm17 = mem[0],zero
movabsq	$139756740332208, %rbx  # imm = 0x7F1BA6DCCAB0
vmovsd	(%rbx), %xmm18          # xmm18 = mem[0],zero
movabsq	$139756740332216, %rbx  # imm = 0x7F1BA6DCCAB8
vmovsd	(%rbx), %xmm3           # xmm3 = mem[0],zero
movabsq	$139756740332224, %rbx  # imm = 0x7F1BA6DCCAC0
vmovsd	(%rbx), %xmm4           # xmm4 = mem[0],zero
movabsq	$139756740332232, %rbx  # imm = 0x7F1BA6DCCAC8
vmovsd	(%rbx), %xmm5           # xmm5 = mem[0],zero
movabsq	$139756740332240, %rbx  # imm = 0x7F1BA6DCCAD0
vmovsd	(%rbx), %xmm6           # xmm6 = mem[0],zero
nopw	%cs:(%rax,%rax)
; β
; β @ REPL[43]:5 within `logsumexp_sleefpirates!'
; ββ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:6
; βββ @ array.jl:728 within `getindex'
L1584:
vmovsd	(%rcx,%rsi,8), %xmm7    # xmm7 = mem[0],zero
; βββ
; ββ @ simdloop.jl:77 within `macro expansion' @ float.jl:397
vsubsd	%xmm0, %xmm7, %xmm7
; ββ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:7
; βββ @ exp.jl:181 within `exp'
; ββββ @ float.jl:399 within `*'
vmulsd	%xmm8, %xmm7, %xmm1
; ββββ
; ββββ @ floatfuncs.jl:129 within `round' @ float.jl:370
vrndscalesd	$4, %xmm1, %xmm1, %xmm1
; ββββ
; βββ @ float.jl:304 within `exp'
vcvttsd2si	%xmm1, %rax
; βββ
; βββ @ exp.jl:185 within `exp'
; ββββ @ float.jl:404 within `muladd'
vfmadd231sd	%xmm9, %xmm1, %xmm7
; ββββ
; βββ @ exp.jl:186 within `exp'
; ββββ @ float.jl:404 within `muladd'
vfmadd231sd	%xmm10, %xmm1, %xmm7
; ββββ
; βββ @ exp.jl:188 within `exp'
; ββββ @ exp.jl:161 within `exp_kernel'
; βββββ @ math.jl:101 within `macro expansion'
; ββββββ @ float.jl:404 within `muladd'
vmovapd	%xmm11, %xmm2
vfmadd213sd	%xmm12, %xmm7, %xmm2
vfmadd213sd	%xmm13, %xmm7, %xmm2
vfmadd213sd	%xmm14, %xmm7, %xmm2
vfmadd213sd	%xmm15, %xmm7, %xmm2
vfmadd213sd	%xmm16, %xmm7, %xmm2
vfmadd213sd	%xmm17, %xmm7, %xmm2
vfmadd213sd	%xmm18, %xmm7, %xmm2
vfmadd213sd	%xmm3, %xmm7, %xmm2
vfmadd213sd	%xmm4, %xmm7, %xmm2
vfmadd213sd	%xmm5, %xmm7, %xmm2
; ββββββ
; βββ @ float.jl:399 within `exp'
vmulsd	%xmm7, %xmm7, %xmm1
vmulsd	%xmm2, %xmm1, %xmm1
; βββ
; βββ @ exp.jl:189 within `exp'
; ββββ @ operators.jl:529 within `+' @ float.jl:395
vaddsd	%xmm1, %xmm7, %xmm1
; ββββ
; ββββ @ float.jl:395 within `+'
vaddsd	%xmm6, %xmm1, %xmm2
; ββββ
; βββ @ int.jl:437 within `exp'
movq	%rax, %rbx
sarq	%rbx
; βββ
; βββ @ exp.jl:190 within `exp'
; ββββ @ int.jl:52 within `ldexp2k'
subl	%ebx, %eax
; ββββ
; ββββ @ priv.jl:52 within `ldexp2k'
; βββββ @ utils.jl:51 within `pow2i'
; ββββββ @ utils.jl:20 within `integer2float'
; βββββββ @ int.jl:446 within `<<' @ int.jl:439
shlq	$52, %rbx
; βββββββ
; ββββββ @ essentials.jl:417 within `integer2float'
vmovq	%rbx, %xmm1
; ββββββ
; βββββ @ floating_point_arithmetic.jl:62 within `evmul'
; ββββββ @ float.jl:399 within `*'
vmulsd	%xmm1, %xmm2, %xmm2
; ββββββ
; βββββ @ utils.jl:51 within `pow2i'
; ββββββ @ utils.jl:20 within `integer2float'
; βββββββ @ int.jl:446 within `<<' @ int.jl:439
shlq	$52, %rax
; βββββββ
; βββββββ @ essentials.jl:417 within `reinterpret'
vmovq	%rax, %xmm1
; βββββββ
; βββββ @ floating_point_arithmetic.jl:62 within `evmul'
; ββββββ @ float.jl:399 within `*'
vmulsd	%xmm1, %xmm2, %xmm1
; ββββββ
; ββ @ simdloop.jl:77 within `macro expansion' @ REPL[43]:8
; βββ @ array.jl:766 within `setindex!'
vmovsd	%xmm1, (%rdx,%rsi,8)
; βββ
; ββ @ simdloop.jl:77 within `macro expansion' @ float.jl:395
vaddsd	%xmm1, %xmm19, %xmm19
; ββ @ simdloop.jl:78 within `macro expansion'
; βββ @ int.jl:53 within `+'
; βββ
; ββ @ simdloop.jl:75 within `macro expansion'
; βββ @ int.jl:49 within `<'
cmpq	%r9, %rsi
; βββ
jb	L1584
; ββ
; β @ REPL[43]:11 within `logsumexp_sleefpirates!'
L1757:
movabsq	$"<;", %rax
vmovupd	%zmm19, 64(%rsp)
vmovapd	%xmm19, %xmm0
vzeroupper
callq	*%rax
; ββ @ broadcast.jl:801 within `materialize!'
; βββ @ abstractarray.jl:75 within `axes'
; ββββ @ array.jl:155 within `size'
movq	24(%r15), %rdx
; ββββ
; βββ @ promotion.jl:414 within `axes'
testq	%rdx, %rdx
; βββ
; βββ @ broadcast.jl:842 within `copyto!' @ broadcast.jl:887
; ββββ @ simdloop.jl:72 within `macro expansion'
jle	L2039
; ββββ
; β @ simdloop.jl within `logsumexp_sleefpirates!'
movq	%rdx, %rax
sarq	$63, %rax
andnq	%rdx, %rax, %rax
; β
; β @ REPL[43]:11 within `logsumexp_sleefpirates!'
; ββ @ float.jl:395 within `+'
vaddsd	32(%rsp), %xmm0, %xmm0
movq	(%r15), %rcx
; ββ
; ββ @ broadcast.jl:801 within `materialize!'
; βββ @ broadcast.jl:842 within `copyto!' @ broadcast.jl:886
; ββββ @ broadcast.jl:869 within `preprocess'
; βββββ @ broadcast.jl:872 within `preprocess_args'
; ββββββ @ broadcast.jl:870 within `preprocess'
; βββββββ @ broadcast.jl:592 within `extrude'
; ββββββββ @ broadcast.jl:547 within `newindexer'
; βββββββββ @ broadcast.jl:548 within `shapeindexer'
; ββββββββββ @ broadcast.jl:553 within `_newindexer'
; βββββββββββ @ operators.jl:193 within `!='
; ββββββββββββ @ promotion.jl:403 within `=='
cmpq	$1, %rdx
; ββββββββββββ
; βββ @ broadcast.jl:842 within `copyto!' @ broadcast.jl:887
; ββββ @ simdloop.jl:75 within `macro expansion'
jne	L1867
xorl	%edx, %edx
nopw	%cs:(%rax,%rax)
; ββββ @ simdloop.jl:77 within `macro expansion' @ broadcast.jl:888
; βββββ @ broadcast.jl:558 within `getindex'
; βββββββ @ broadcast.jl:621 within `_getindex'
; βββββββββ @ array.jl:728 within `getindex'
L1840:
vmovsd	(%rcx), %xmm1           # xmm1 = mem[0],zero
; βββββββββ
; ββββββββ @ float.jl:397 within `-'
vsubsd	%xmm0, %xmm1, %xmm1
; ββββββββ
; βββββ @ array.jl:766 within `setindex!'
vmovsd	%xmm1, (%rcx,%rdx,8)
; βββββ
; ββββ @ int.jl:53 within `macro expansion'
; ββββ
; ββββ @ simdloop.jl:75 within `macro expansion'
; βββββ @ int.jl:49 within `<'
cmpq	%rax, %rdx
; βββββ
jb	L1840
jmp	L2039
L1867:
cmpq	$32, %rax
jae	L1880
xorl	%edx, %edx
jmp	L2016
; ββββ @ simdloop.jl:75 within `macro expansion'
L1880:
movabsq	$9223372036854775776, %rdx # imm = 0x7FFFFFFFFFFFFFE0
andq	%rax, %rdx
leaq	192(%rcx), %rsi
; ββββ @ simdloop.jl:78 within `macro expansion'
; βββββ @ int.jl:53 within `+'
movq	%rdx, %rdi
nopw	%cs:(%rax,%rax)
; βββββ
; ββββ @ simdloop.jl:77 within `macro expansion' @ broadcast.jl:888
; βββββ @ broadcast.jl:558 within `getindex'
; βββββββ @ broadcast.jl:621 within `_getindex'
; βββββββββ @ array.jl:728 within `getindex'
L1920:
vmovupd	-192(%rsi), %zmm2
vmovupd	-128(%rsi), %zmm3
vmovupd	-64(%rsi), %zmm4
vmovupd	(%rsi), %zmm5
; βββββββββ
; ββββββββ @ float.jl:397 within `-'
vsubpd	%zmm1, %zmm2, %zmm2
vsubpd	%zmm1, %zmm3, %zmm3
vsubpd	%zmm1, %zmm4, %zmm4
vsubpd	%zmm1, %zmm5, %zmm5
; ββββββββ
; ββββ @ simdloop.jl:77 within `macro expansion' @ array.jl:766
vmovupd	%zmm2, -192(%rsi)
vmovupd	%zmm3, -128(%rsi)
vmovupd	%zmm4, -64(%rsi)
vmovupd	%zmm5, (%rsi)
; ββββ @ simdloop.jl:78 within `macro expansion'
; βββββ @ int.jl:53 within `+'
addq	$256, %rsi              # imm = 0x100
jne	L1920
; βββββ
; β @ int.jl within `logsumexp_sleefpirates!'
cmpq	%rdx, %rax
; β
; β @ simdloop.jl:75 within `logsumexp_sleefpirates!'
je	L2039
; β
; β @ REPL[43]:11 within `logsumexp_sleefpirates!'
; ββ @ broadcast.jl:801 within `materialize!'
; βββ @ broadcast.jl:842 within `copyto!' @ broadcast.jl:887
; ββββ @ simdloop.jl:77 within `macro expansion' @ broadcast.jl:888
; βββββ @ broadcast.jl:558 within `getindex'
; βββββββ @ broadcast.jl:621 within `_getindex'
; βββββββββ @ array.jl:728 within `getindex'
L2016:
vmovsd	(%rcx,%rdx,8), %xmm1    # xmm1 = mem[0],zero
; βββββββββ
; ββββββ @ float.jl:397 within `_broadcast_getindex'
vsubsd	%xmm0, %xmm1, %xmm1
; ββββββ
; βββββ @ array.jl:766 within `setindex!'
vmovsd	%xmm1, (%rcx,%rdx,8)
; βββββ
; ββββ @ simdloop.jl:78 within `macro expansion'
; βββββ @ int.jl:53 within `+'
; βββββ
; ββββ @ simdloop.jl:75 within `macro expansion'
; βββββ @ int.jl:49 within `<'
cmpq	%rax, %rdx
jb	L2016
; βββββ
; β @ REPL[43]:12 within `logsumexp_sleefpirates!'
; ββ @ broadcast.jl:801 within `materialize!'
; βββ @ abstractarray.jl:75 within `axes'
; ββββ @ array.jl:155 within `size'
L2039:
movq	24(%r14), %rdx
; ββββ
; βββ @ promotion.jl:414 within `axes'
testq	%rdx, %rdx
; βββ
; βββ @ broadcast.jl:842 within `copyto!' @ broadcast.jl:887
; ββββ @ simdloop.jl:72 within `macro expansion'
jle	L2259
; ββββ
; β @ simdloop.jl within `logsumexp_sleefpirates!'
movq	%rdx, %rax
sarq	$63, %rax
andnq	%rdx, %rax, %rax
movabsq	$139756740332240, %rcx  # imm = 0x7F1BA6DCCAD0
; β
; β @ REPL[43]:12 within `logsumexp_sleefpirates!'
; ββ @ promotion.jl:316 within `/' @ float.jl:401
vmovsd	(%rcx), %xmm0           # xmm0 = mem[0],zero
vdivsd	64(%rsp), %xmm0, %xmm0
movq	(%r14), %rcx
; ββ
; ββ @ broadcast.jl:801 within `materialize!'
; βββ @ broadcast.jl:842 within `copyto!' @ broadcast.jl:886
; ββββ @ broadcast.jl:869 within `preprocess'
; βββββ @ broadcast.jl:872 within `preprocess_args'
; ββββββ @ broadcast.jl:870 within `preprocess'
; βββββββ @ broadcast.jl:592 within `extrude'
; ββββββββ @ broadcast.jl:547 within `newindexer'
; βββββββββ @ broadcast.jl:548 within `shapeindexer'
; ββββββββββ @ broadcast.jl:553 within `_newindexer'
; βββββββββββ @ operators.jl:193 within `!='
; ββββββββββββ @ promotion.jl:403 within `=='
cmpq	$1, %rdx
; ββββββββββββ
; βββ @ broadcast.jl:842 within `copyto!' @ broadcast.jl:887
; ββββ @ simdloop.jl:75 within `macro expansion'
jne	L2119
xorl	%edx, %edx
nop
; ββββ @ simdloop.jl:77 within `macro expansion' @ broadcast.jl:888
; βββββ @ broadcast.jl:558 within `getindex'
; ββββββββ @ float.jl:399 within `*'
L2096:
vmulsd	(%rcx), %xmm0, %xmm1
; ββββββββ
; βββββ @ array.jl:766 within `setindex!'
vmovsd	%xmm1, (%rcx,%rdx,8)
; βββββ
; ββββ @ int.jl:53 within `macro expansion'
; ββββ
; ββββ @ simdloop.jl:75 within `macro expansion'
; βββββ @ int.jl:49 within `<'
cmpq	%rax, %rdx
; βββββ
jb	L2096
jmp	L2259
L2119:
cmpq	$32, %rax
jae	L2129
xorl	%edx, %edx
jmp	L2240
; ββββ @ simdloop.jl:75 within `macro expansion'
L2129:
movabsq	$9223372036854775776, %rdx # imm = 0x7FFFFFFFFFFFFFE0
andq	%rax, %rdx
leaq	192(%rcx), %rsi
; ββββ @ simdloop.jl:78 within `macro expansion'
; βββββ @ int.jl:53 within `+'
movq	%rdx, %rdi
nop
; βββββ
; ββββ @ simdloop.jl:77 within `macro expansion' @ broadcast.jl:888
; βββββ @ broadcast.jl:558 within `getindex'
; ββββββββ @ float.jl:399 within `*'
L2160:
vmulpd	-192(%rsi), %zmm1, %zmm2
vmulpd	-128(%rsi), %zmm1, %zmm3
vmulpd	-64(%rsi), %zmm1, %zmm4
vmulpd	(%rsi), %zmm1, %zmm5
; ββββββββ
; ββββ @ simdloop.jl:77 within `macro expansion' @ array.jl:766
vmovupd	%zmm2, -192(%rsi)
vmovupd	%zmm3, -128(%rsi)
vmovupd	%zmm4, -64(%rsi)
vmovupd	%zmm5, (%rsi)
; ββββ @ simdloop.jl:78 within `macro expansion'
; βββββ @ int.jl:53 within `+'
addq	\$256, %rsi              # imm = 0x100
jne	L2160
; βββββ
; β @ int.jl within `logsumexp_sleefpirates!'
cmpq	%rdx, %rax
; β
; β @ simdloop.jl:75 within `logsumexp_sleefpirates!'
je	L2259
nopl	(%rax,%rax)
; β
; β @ REPL[43]:12 within `logsumexp_sleefpirates!'
; ββ @ broadcast.jl:801 within `materialize!'
; βββ @ broadcast.jl:842 within `copyto!' @ broadcast.jl:887
; ββββ @ simdloop.jl:77 within `macro expansion' @ broadcast.jl:888
; βββββ @ broadcast.jl:558 within `getindex'
; ββββββββ @ float.jl:399 within `*'
L2240:
vmulsd	(%rcx,%rdx,8), %xmm0, %xmm1
; ββββββββ
; ββββ @ simdloop.jl:77 within `macro expansion' @ array.jl:766
vmovsd	%xmm1, (%rcx,%rdx,8)
; ββββ @ simdloop.jl:78 within `macro expansion'
; βββββ @ int.jl:53 within `+'
; βββββ
; ββββ @ simdloop.jl:75 within `macro expansion'
; βββββ @ int.jl:49 within `<'
cmpq	%rax, %rdx
jb	L2240
; βββββ
L2259:
movq	%r14, %rax
popq	%rbx
popq	%r14
popq	%r15
vzeroupper
retq
; β @ REPL[43]:5 within `logsumexp_sleefpirates!'
; ββ @ simdloop.jl:71 within `macro expansion'
; βββ @ simdloop.jl:51 within `simd_inner_length'
; ββββ @ range.jl:541 within `length'
; βββββ @ checked.jl:166 within `checked_add'
L2278:
movabsq	\$throw_overflowerr_binaryop, %rax
movabsq	\$139757138773328, %rdi  # imm = 0x7F1BBE9C8550
movl	\$1, %edx
callq	*%rax
ud2
nopw	%cs:(%rax,%rax)
; βββββ
``````

SLEEFPirates is vectorized.


Thanks for this! On my machines, I still cannot beat Yeppp on `Float64`, but with SLEEFPirates and LoopVectorization I get significantly closer.

``````
julia> w = randn(Float64, N); we = similar(w);
julia> @btime logsumexp!($w,$we);
10.358 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_yeppp!($w,$we);
1.996 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_sleefpirates!($w,$we);
3.157 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_loopvec!($w,$we);
3.575 μs (0 allocations: 0 bytes)
julia> w = randn(Float32, N); we = similar(w);
julia> @btime logsumexp!($w,$we);
7.773 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_sleefpirates!($w,$we);
1.580 μs (0 allocations: 0 bytes)
julia> @btime logsumexp_loopvec!($w,$we);
ERROR: MethodError: no method matching vload(::Type{SVec{4,Float64}}, ::VectorizationBase.vpointer{Float32})
Closest candidates are:
``````
``````
using LoopVectorization  # assumes an early LoopVectorization release that provides `@vectorize`

@generated function logsumexp_loopvectorization!(w::Vector{T},we) where T
    quote
        offset = maximum(w)
        N = length(w)
        s = zero(T)
        @vectorize $T for i = 1:N
            wl = w[i]
            wel = exp(wl-offset)
            we[i] = wel
            s += wel
        end
        w  .-= log(s) + offset
        we .*= 1/s
    end
end
``````

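`@vectorize` accepts a type argument, but it can't be a symbol like `T`: a macro only sees the syntax it is given, not the type `T` will be bound to when the method is called. Using a generated function to interpolate the concrete type works.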
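A small illustration of that mechanism, using a hypothetical toy macro `@typename_of` as a stand-in for `@vectorize` (it simply reports, as a string, whatever object it received at expansion time):

```julia
# Hypothetical toy macro standing in for a macro that needs a concrete type:
# it returns, as a string, the object it received when it was expanded.
macro typename_of(T)
    return string(T)
end

# Ordinary method: the macro expands once, when the method is defined,
# so it receives the Symbol `:T`, not the type T is later bound to.
plain_version(::Vector{T}) where {T} = @typename_of T

# Generated function: the body runs once per concrete argument type with `T`
# bound to that type, so `$T` splices the actual type object into the macro call.
@generated function generated_version(::Vector{T}) where {T}
    return :(@typename_of $T)
end

plain_version([1.0])            # "T"
generated_version([1.0])        # "Float64"
generated_version(Float32[1])   # "Float32"
```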

``````
julia> @btime logsumexp_sleefpirates!($w,$we);
1.474 μs (0 allocations: 0 bytes)

julia> @btime logsumexp_loopvectorization!($w,$we);
1.325 μs (0 allocations: 0 bytes)
``````

Yeppβs performance is unfortunately inconsistent across architectures. It was even slower than the totally unvectorized version on my computer.
Iβm guessing itβs a precompiled binary that dispatches based on recognizing the host CPU.
Before OpenBLAS had avx512 support, it just used avx2 kernels. While slower, thatβs much faster than just falling back to the most generic code!


I should point out that the offset algorithm you are using for logsumexp is potentially problematic: you can easily contrive a case where it is off by a factor of 2. In particular, try `[1e-20, log(1e-20)]` with your logsumexp algorithm:

``````
function f(x)
    X = maximum(x)
    return X + log(sum(exp.(x .- X)))
end
``````

You get:

``````
julia> x = [1e-20, log(1e-20)];

julia> f(x)                  # inaccurate!
1.0e-20

julia> Float64(f(big.(x)))   # accurate
1.9999999999999993e-20
``````

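A possible fix is to pull the maximum `x` term out of the sum and use the `log1p` function: after subtracting the maximum, that term is exactly 1, and `log1p(s)` computes `log(1 + s)` accurately even when the remaining sum `s` is far below machine epsilon. I actually assigned this as an exam question recently, so you can see the explanation in my solutions to problem 3.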
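A minimal sketch of that fix in Julia, assuming a non-empty vector of floats; the name `logsumexp_log1p` is illustrative and not taken from the linked solutions:

```julia
# Sketch: logsumexp that factors out the maximal term and uses log1p.
#   log(sum(exp.(x))) = X + log1p(sum(exp(x[i] - X) for i != k)),  X = x[k] = maximum(x)
# so a tiny remaining sum is not rounded away by being added to 1.
function logsumexp_log1p(x::AbstractVector{<:AbstractFloat})
    X, k = findmax(x)              # largest value and its index
    s = zero(X)
    for i in eachindex(x)
        i == k && continue         # the maximal term would contribute exactly exp(0) = 1
        s += exp(x[i] - X)
    end
    return X + log1p(s)
end

x = [1e-20, log(1e-20)]
logsumexp_log1p(x)   # ≈ 2.0e-20, in line with the BigFloat result above
```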

Not sure if this matters for the machine-learning application, however, since there you are adding the logsumexp to a posterior probability and so errors in tiny values like this may get rounded away in your final result.


Thanks for the link! I had not considered this so carefully before, but I will surely make use of this trick in the future.

For future reference, here is my implementation


It looks like youβre mostly using `logsumexp!` for sampling from a categorical, is that right? Have you tried using the Gumbel max trick for this? Itβs pretty great for sampling a categorical with unnormalized log-probability weights.


That trick does not return the log-likelihood like the calculation above, right? Also, they mention that all the `log(log(U))` values can be precalculated; doesn't this introduce some bias?

How could one link LLVM with SVML as you suggest above?

βLog-likelihoodβ doesnβt make sense unless thereβs context of some distribution. Do you need the log-likelihood of a Dirichlet? Then no, I donβt think the Gumbel trick would help.

Iβve seen logsumexp in a few different contexts, and Iβm not sure which one youβre working in. It had seemed you were assigning a log-likelihood to each particle, and were then concerned about normalizing these in order to draw a categorical. In that case my point was that the Gumbel trick will let you skip the normalization. [βreturning the log-likelihoodβ would be trivial in this case, so Iβm guessing this isnβt what you meant]

If you reuse values, yes. But I don't think anyone's suggesting that. As I see it, the benefit is doing an equivalent calculation with fewer dependencies. This means you could compute the Gumbels in parallel, have a thread dedicated to keeping a supply available, or use lots of other tricks.
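One way to read "fewer dependencies": in the Gumbel-max formulation the noise never depends on the weights, so it can be produced ahead of time as long as each draw uses fresh noise exactly once. A minimal single-threaded sketch under that assumption (both function names are hypothetical), drawing `M` samples from an `N × M` matrix of noise generated before the weights are used:

```julia
using Random

# Fill G with fresh standard Gumbel noise. This step never reads the weights,
# so it can be done ahead of time, before the weights for this step are known.
function fill_gumbel!(rng::AbstractRNG, G::AbstractMatrix{Float64})
    randexp!(rng, G)      # G[i, j] ~ Exponential(1)
    G .= .-log.(G)        # -log(Exponential(1)) ~ Gumbel(0, 1)
    return G
end

# One categorical draw per column: idx[j] = argmax_i (w[i] + G[i, j]).
# Every column is used exactly once; reusing noise would correlate the draws.
function gumbel_max_batch!(idx::Vector{Int}, w::Vector{Float64}, G::Matrix{Float64})
    size(G) == (length(w), length(idx)) ||
        throw(DimensionMismatch("G must have size (length(w), length(idx))"))
    for j in 1:size(G, 2)
        best_i, best_v = 1, -Inf
        for i in 1:size(G, 1)
            v = w[i] + G[i, j]
            if v > best_v
                best_i, best_v = i, v
            end
        end
        idx[j] = best_i
    end
    return idx
end

# Usage: noise first, weights second.
rng = MersenneTwister(2)
w = log.([0.1, 0.6, 0.3])
M = 50_000
G = fill_gumbel!(rng, Matrix{Float64}(undef, length(w), M))
idx = gumbel_max_batch!(Vector{Int}(undef, M), w, G)
freqs = [count(==(k), idx) for k in 1:3] ./ M   # ≈ [0.1, 0.6, 0.3]
```

Note the cost: resampling `N` particles this way needs `N × N` Gumbel draws per step, so whether it pays off depends on how cheaply that noise can be produced.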

Iβve also wondered sometimes about things like `log(log(U))`. We have fast algorithms for lots of things. If `doublelog` is critical for performance, shouldnβt we be computing it more directly, maybe a series expansion focused on U, rather than calling `log` twice?

Finally, I think youβre doing systematic sampling, right? This seems convenient for this approach, since quantiles for a uniform are dead-simple. It could also work well to pull in something like Sobol.jl if QMC is a possibility.

Thanks for this thread. I typically use a lot of `logsumexp` in my code, but never took the time to optimize it. I will try your variants. I expect to see some performance gains!

This is exactly what I do, but I also use the value `sum(exp.(w))` to calculate the likelihood of the data given the model parameters, so it must be computed anyway. If I were not interested in this number, the trick might well have sped up the calculations.
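A minimal sketch of how that sum typically enters the likelihood, assuming a bootstrap particle filter whose weights were reset to uniform (1/N) by resampling at the previous step, so that the mean of `exp.(w)` estimates the likelihood increment `p(y_t | y_1:t-1)`. The helper name is hypothetical; it stays in the log domain with the same maximum offset used in the `logsumexp!` variants:

```julia
# Hypothetical helper: log of the particle-filter likelihood increment,
#   log((1/N) * sum(exp.(w))) = offset + log(sum(exp.(w .- offset))) - log(N),
# assuming uniform weights after the previous resampling step (bootstrap filter).
function loglik_increment(w::AbstractVector{<:AbstractFloat})
    offset = maximum(w)
    s = sum(x -> exp(x - offset), w)   # the same `s` the logsumexp! variants compute
    return offset + log(s) - log(length(w))
end

# The total log-likelihood is the sum of the increments over time steps, e.g.
# loglik += loglik_increment(w) before normalizing and resampling.
loglik_increment([0.0, log(3.0)])   # ≈ log(2), since mean([1, 3]) == 2
```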
