Fast LogSumExp over 4th dimension

Should be fixed by this commit. I’ll tag a release after tests pass.


Seems to work with latest update :+1:

One other thing I am wondering about. I made different versions of function 8 for Float32 and Float16, as follows:

using Tullio, LoopVectorization, BenchmarkTools  # BenchmarkTools provides @btime

# Max over l, then log of the shifted sum of exponentials, plus the max back.
function logsumexp4D8(Arr4d::Array{Float64, 4})
    @tullio (max) max_[i,j,k] := Arr4d[i,j,k,l]
    @tullio _[i,j,k] := exp(Arr4d[i,j,k,l] - max_[i,j,k]) |> log(_) + max_[i,j,k]
end

function logsumexp4D8b(Arr4d::Array{Float32, 4})
    @tullio (max) max_[i,j,k] := Arr4d[i,j,k,l]
    @tullio _[i,j,k] := exp(Arr4d[i,j,k,l] - max_[i,j,k]) |> log(_) + max_[i,j,k]
end

function logsumexp4D8c(Arr4d::Array{Float16, 4})
    @tullio (max) max_[i,j,k] := Arr4d[i,j,k,l]
    @tullio _[i,j,k] := exp(Arr4d[i,j,k,l] - max_[i,j,k]) |> log(_) + max_[i,j,k]
end
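For checking the output of any of these versions, a plain-Julia reference without Tullio may help. This is a sketch (the name `logsumexp4D_ref` is hypothetical) that uses the same max-shift trick as the functions above: log(sum(exp(x))) = m + log(sum(exp(x - m))) with m = max(x), which keeps every exponent at or below zero so nothing overflows.

```julia
# Reference logsumexp over the 4th dimension using Base broadcasting only.
function logsumexp4D_ref(A::AbstractArray{<:Real,4})
    m = maximum(A; dims = 4)                  # size (I, J, K, 1)
    s = sum(exp.(A .- m); dims = 4)           # each term lies in (0, 1]
    return dropdims(log.(s) .+ m; dims = 4)   # size (I, J, K)
end

A_small = rand(3, 4, 5, 6)
r = logsumexp4D_ref(A_small)
```

With Tullio loaded, `logsumexp4D_ref(A) ≈ logsumexp4D8(A)` should hold for the Float64 data below (Float16 results will only agree to half precision).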

Make data and get timings:

A = rand(20, 1000, 80, 5)
B = convert(Array{Float32}, A)
C = convert(Array{Float16}, A)

julia> @btime logsumexp4D8($A);
  14.341 ms (234 allocations: 24.42 MiB)

julia> @btime logsumexp4D8b($B);
  7.375 ms (233 allocations: 12.22 MiB)

julia> @btime logsumexp4D8c($C);
  64.285 ms (232 allocations: 6.11 MiB)

Why is the Float16 version so slow? I thought it would be the fastest.

Most CPUs don’t have native Float16 support, meaning they have to convert to another data type (e.g. Float32) to perform the actual operations.
LoopVectorization also doesn’t support Float16, but I should probably add support for some 16-bit types, since CPUs that do support 16-bit floating point are becoming more common.


Intel processors have supported FP16 as a storage format since Ivy Bridge, via the F16C extension.

bfloat16 has only just started to appear: it is currently available only on Xeon Cooper Lake processors.


Cool, I didn’t know that.
But I was aware of bfloat16 in Cooper Lake and the upcoming Sapphire Rapids.
The A64FX also supports half precision, and the Neoverse V1 and N2 will have bfloat16.

So I have a recent Intel-based Mac with a Core i9. Is there something I need to do to take advantage of this?

I could add support for unpacking Float16 into Float32.
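As a rough illustration of what "unpacking Float16 into Float32" means when Float16 is used purely as a storage format, here is a minimal plain-Julia sketch. It is not LoopVectorization's implementation, and `sum_f16_storage` is a hypothetical name: data stays in half precision in memory, but each element is widened to Float32 before any arithmetic, and the accumulator is Float32.

```julia
# Keep data in Float16 (half the bytes of Float32), compute in Float32.
function sum_f16_storage(v::AbstractVector{Float16})
    acc = 0.0f0
    @inbounds @simd for i in eachindex(v)
        # Widen on load; with F16C this conversion is typically a vcvtph2ps.
        acc += Float32(v[i])
    end
    return acc
end

s = sum_f16_storage(Float16.(1:100))
```

Accumulating in Float32 also avoids the rounding a Float16 accumulator would suffer once the running sum exceeds 2048, where Float16 can no longer represent every integer.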


Ah, I see: to support FP16 you need to change the package itself? Well, I wouldn’t ask you to do major work for this. I just wanted to try it out, and I’m honestly not sure whether I could use FP32/FP16 in my code; that needs some thought.

Can’t agree more. My PhD advisor is one of the world’s biggest experts on R. But I did my PhD in Julia. Did not even install R on my new laptop after coursework was done with! R is a nightmare that the stats community kept investing in.

Locally, I tried making a few changes to get the code to run.
It required changes to both VectorizationBase.jl (adding limited Float16 support through conversion to Float32) and Tullio.jl (use LoopVectorization.NativeTypes instead of Union{Base.HWReal,Bool}).
The Float16 version was slower than the Float64 version. Maybe some more changes can improve that situation, e.g. checking for suboptimal codegen. I don’t think performance should be as bad as it is.


We currently have an LLVM pass that for every elementary operation promotes Float16s to Float32, does the arithmetic in regular single precision and truncates back to Float16 again. I believe this was done because LLVM otherwise didn’t always truncate in between operations on architectures only supporting single precision, leading to (more accurate but) wrong results. This is probably where this overhead over Float32 is coming from. It would likely be possible to disable that pass by using a custom AbstractInterpreter. (GPUCompiler.jl might already be doing that.)
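To see why skipping the truncation between operations changes results, here is a small Julia sketch with values chosen so the difference is visible. `h` is half a unit in the last place of 1.0 in Float16 (2^-11), so `1 + h` is an exact tie and rounds to even (back to 1.0) under Float16 semantics. Keeping the intermediate in Float32 and rounding once gives a different, more accurate, answer.

```julia
x = Float16(1)
h = eps(Float16) / 2   # 2^-11, exactly representable in Float16

# Float16 semantics: each + is evaluated (in Float32 on most CPUs) and then
# rounded back to Float16, so both additions round to 1.0.
per_op = (x + h) + h

# No intermediate truncation: stay in Float32 and round to Float16 once.
fused = Float16((Float32(x) + Float32(h)) + Float32(h))

println(per_op, " vs ", fused)
```

Because Float32 has more than twice Float16's precision, promoting, computing, and truncating after every operation yields the same result as correctly rounded native Float16 arithmetic; it is only the omitted truncations that make results depend on the hardware.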


I was actually intending to take that same approach here (not truncating), which I can do because all computation is done on VectorizationBase.AbstractSIMD types, which define their own promotion rules. I defined:

Base.promote(v1::AbstractSIMD{W,Float16}, v2::AbstractSIMD{W,Float16}) where {W} = (convert(Float32,v1), convert(Float32,v2))

I should thus be able to funnel Float16 arithmetic through there by relying on the promoting fallback operators and not defining them for half precision.
But I haven’t checked thoroughly that I’m actually promoting.
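The dispatch pattern described here (define arithmetic only for native types, let a promoting fallback catch Float16) can be modeled with a toy type. This is a hypothetical sketch: `ToyVec`, `toy_promote`, and `widen16` are made up for illustration and are not VectorizationBase's actual types or methods.

```julia
# Toy stand-in for a SIMD vector type (NOT VectorizationBase's AbstractSIMD).
struct ToyVec{W,T}
    data::NTuple{W,T}
end

const NativeFloat = Union{Float32,Float64}        # types with "real" arithmetic
const ToyFloat = Union{Float16,Float32,Float64}   # types the fallback accepts

# Widen Float16 to Float32 before looking for a common element type.
widen16(::Type{Float16}) = Float32
widen16(::Type{T}) where {T} = T

function toy_promote(a::ToyVec{W,Ta}, b::ToyVec{W,Tb}) where {W,Ta,Tb}
    T = promote_type(widen16(Ta), widen16(Tb))
    return ToyVec{W,T}(a.data), ToyVec{W,T}(b.data)   # element-wise convert
end

# Arithmetic is defined only for matching native float types...
Base.:+(a::ToyVec{W,T}, b::ToyVec{W,T}) where {W,T<:NativeFloat} = ToyVec{W,T}(a.data .+ b.data)
Base.:*(a::ToyVec{W,T}, b::ToyVec{W,T}) where {W,T<:NativeFloat} = ToyVec{W,T}(a.data .* b.data)

# ...and a promoting fallback handles the rest. Float16 inputs are funneled
# into the Float32 methods, and the result is never truncated back to Float16.
Base.:+(a::ToyVec{W,<:ToyFloat}, b::ToyVec{W,<:ToyFloat}) where {W} = +(toy_promote(a, b)...)
Base.:*(a::ToyVec{W,<:ToyFloat}, b::ToyVec{W,<:ToyFloat}) where {W} = *(toy_promote(a, b)...)

x = ToyVec{4,Float16}(Float16.((1, 2, 3, 4)))
y = x * x + x   # Float16 in, Float32 out
```

Checking the element type of the result (here `y isa ToyVec{4,Float32}`) is one cheap way to confirm that the promotion path is actually taken.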

If I don’t, it seems operations get scalarized:

@inline function my_add(a::NTuple{16,Core.VecElement{Float16}}, b::NTuple{16,Core.VecElement{Float16}}, c::NTuple{16,Core.VecElement{Float16}})
  ccall("llvm.fmuladd.v16f16", llvmcall, NTuple{16,Core.VecElement{Float16}}, (NTuple{16,Core.VecElement{Float16}},NTuple{16,Core.VecElement{Float16}},NTuple{16,Core.VecElement{Float16}}), a, b, c)
end
a = ntuple(_ -> Core.VecElement(rand(Float16)), Val(16));
@code_llvm debuginfo=:none my_add(a,a,a)
@code_native syntax=:intel debuginfo=:none my_add(a,a,a)

LLVM looks fine:

define <16 x half> @julia_my_add_615(<16 x half> %0, <16 x half> %1, <16 x half> %2) #0 {
top:
  %3 = call <16 x half> @llvm.fmuladd.v16f16(<16 x half> %0, <16 x half> %1, <16 x half> %2)
  ret <16 x half> %3
}

But the assembly:

        .text
        mov     rax, rdi
        movzx   edi, word ptr [rsp + 96]
        vmovd   xmm8, edi
        movzx   esi, si
        vmovd   xmm9, esi
        movzx   esi, word ptr [rsp + 224]
        vmovd   xmm11, esi
        movzx   esi, word ptr [rsp + 104]
        vmovd   xmm10, esi
        movzx   edx, dx
        vmovd   xmm12, edx
        movzx   edx, word ptr [rsp + 232]
        vmovd   xmm13, edx
        movzx   edx, word ptr [rsp + 112]
        vmovd   xmm14, edx
        movzx   ecx, cx
        vmovd   xmm17, ecx
        movzx   ecx, word ptr [rsp + 240]
        vmovd   xmm15, ecx
        movzx   ecx, word ptr [rsp + 120]
        vmovd   xmm16, ecx
        movzx   ecx, r8w
        vmovd   xmm18, ecx
        movzx   ecx, word ptr [rsp + 248]
        vmovd   xmm20, ecx
        movzx   ecx, word ptr [rsp + 128]
        vmovd   xmm19, ecx
        movzx   ecx, r9w
        vmovd   xmm21, ecx
        movzx   ecx, word ptr [rsp + 256]
        vmovd   xmm22, ecx
        movzx   ecx, word ptr [rsp + 136]
        vmovd   xmm23, ecx
        movzx   ecx, word ptr [rsp + 8]
        vcvtph2ps       xmm1, xmm8
        vcvtph2ps       xmm7, xmm9
        vmovd   xmm24, ecx
        movzx   ecx, word ptr [rsp + 264]
        vmulss  xmm1, xmm7, xmm1
        vcvtps2ph       xmm1, xmm1, 4
        vcvtph2ps       xmm1, xmm1
        vcvtph2ps       xmm7, xmm11
        vaddss  xmm1, xmm1, xmm7
        vcvtps2ph       xmm8, xmm1, 4
        vmovd   xmm26, ecx
        movzx   ecx, word ptr [rsp + 144]
        vcvtph2ps       xmm1, xmm10
        vcvtph2ps       xmm4, xmm12
        vmovd   xmm25, ecx
        movzx   ecx, word ptr [rsp + 16]
        vmulss  xmm1, xmm4, xmm1
        vcvtps2ph       xmm1, xmm1, 4
        vcvtph2ps       xmm1, xmm1
        vcvtph2ps       xmm4, xmm13
        vaddss  xmm1, xmm1, xmm4
        vcvtps2ph       xmm9, xmm1, 4
        vmovd   xmm13, ecx
        movzx   ecx, word ptr [rsp + 272]
        vcvtph2ps       xmm4, xmm14
        vcvtph2ps       xmm5, xmm17
        vmovd   xmm14, ecx
        movzx   ecx, word ptr [rsp + 152]
        vmulss  xmm4, xmm5, xmm4
        vcvtps2ph       xmm4, xmm4, 4
        vcvtph2ps       xmm4, xmm4
        vcvtph2ps       xmm5, xmm15
        vaddss  xmm4, xmm4, xmm5
        vcvtps2ph       xmm10, xmm4, 4
        vmovd   xmm17, ecx
        movzx   ecx, word ptr [rsp + 24]
        vcvtph2ps       xmm5, xmm16
        vcvtph2ps       xmm0, xmm18
        vmovd   xmm16, ecx
        movzx   ecx, word ptr [rsp + 280]
        vmulss  xmm0, xmm0, xmm5
        vcvtps2ph       xmm0, xmm0, 4
        vcvtph2ps       xmm0, xmm0
        vcvtph2ps       xmm5, xmm20
        vaddss  xmm0, xmm0, xmm5
        vcvtps2ph       xmm11, xmm0, 4
        vmovd   xmm18, ecx
        movzx   ecx, word ptr [rsp + 160]
        vcvtph2ps       xmm5, xmm19
        vcvtph2ps       xmm2, xmm21
        vmovd   xmm19, ecx
        movzx   ecx, word ptr [rsp + 32]
        vmulss  xmm2, xmm2, xmm5
        vcvtps2ph       xmm2, xmm2, 4
        vcvtph2ps       xmm2, xmm2
        vcvtph2ps       xmm5, xmm22
        vaddss  xmm2, xmm2, xmm5
        vcvtps2ph       xmm12, xmm2, 4
        vmovd   xmm2, ecx
        movzx   ecx, word ptr [rsp + 288]
        vcvtph2ps       xmm5, xmm23
        vcvtph2ps       xmm1, xmm24
        vmovd   xmm20, ecx
        movzx   ecx, word ptr [rsp + 168]
        vmulss  xmm1, xmm1, xmm5
        vcvtps2ph       xmm1, xmm1, 4
        vcvtph2ps       xmm1, xmm1
        vcvtph2ps       xmm5, xmm26
        vaddss  xmm1, xmm1, xmm5
        vcvtps2ph       xmm15, xmm1, 4
        vmovd   xmm5, ecx
        movzx   ecx, word ptr [rsp + 40]
        vcvtph2ps       xmm1, xmm25
        vcvtph2ps       xmm4, xmm13
        vmovd   xmm3, ecx
        movzx   ecx, word ptr [rsp + 296]
        vmulss  xmm1, xmm4, xmm1
        vcvtps2ph       xmm1, xmm1, 4
        vcvtph2ps       xmm1, xmm1
        vcvtph2ps       xmm4, xmm14
        vaddss  xmm1, xmm1, xmm4
        vcvtps2ph       xmm13, xmm1, 4
        vmovd   xmm4, ecx
        movzx   ecx, word ptr [rsp + 176]
        vcvtph2ps       xmm1, xmm17
        vcvtph2ps       xmm0, xmm16
        vmovd   xmm7, ecx
        movzx   ecx, word ptr [rsp + 48]
        vmulss  xmm0, xmm0, xmm1
        vcvtps2ph       xmm0, xmm0, 4
        vcvtph2ps       xmm0, xmm0
        vcvtph2ps       xmm1, xmm18
        vaddss  xmm0, xmm0, xmm1
        vcvtps2ph       xmm14, xmm0, 4
        vmovd   xmm1, ecx
        movzx   ecx, word ptr [rsp + 304]
        vcvtph2ps       xmm0, xmm19
        vcvtph2ps       xmm2, xmm2
        vmovd   xmm6, ecx
        movzx   ecx, word ptr [rsp + 184]
        vmulss  xmm0, xmm2, xmm0
        vcvtps2ph       xmm0, xmm0, 4
        vcvtph2ps       xmm0, xmm0
        vcvtph2ps       xmm2, xmm20
        vaddss  xmm0, xmm0, xmm2
        vcvtps2ph       xmm16, xmm0, 4
        vmovd   xmm2, ecx
        movzx   ecx, word ptr [rsp + 56]
        vcvtph2ps       xmm5, xmm5
        vcvtph2ps       xmm3, xmm3
        vmovd   xmm0, ecx
        movzx   ecx, word ptr [rsp + 312]
        vmulss  xmm3, xmm3, xmm5
        vcvtps2ph       xmm3, xmm3, 4
        vcvtph2ps       xmm3, xmm3
        vcvtph2ps       xmm4, xmm4
        vaddss  xmm3, xmm3, xmm4
        vcvtps2ph       xmm3, xmm3, 4
        vmovd   xmm4, ecx
        movzx   ecx, word ptr [rsp + 192]
        vcvtph2ps       xmm5, xmm7
        vcvtph2ps       xmm1, xmm1
        vmovd   xmm7, ecx
        movzx   ecx, word ptr [rsp + 64]
        vmulss  xmm1, xmm1, xmm5
        vcvtps2ph       xmm1, xmm1, 4
        vcvtph2ps       xmm1, xmm1
        vcvtph2ps       xmm5, xmm6
        vaddss  xmm1, xmm1, xmm5
        vcvtps2ph       xmm1, xmm1, 4
        vmovd   xmm5, ecx
        movzx   ecx, word ptr [rsp + 320]
        vcvtph2ps       xmm2, xmm2
        vcvtph2ps       xmm0, xmm0
        vmovd   xmm6, ecx
        movzx   ecx, word ptr [rsp + 200]
        vmulss  xmm0, xmm0, xmm2
        vcvtps2ph       xmm0, xmm0, 4
        vcvtph2ps       xmm0, xmm0
        vcvtph2ps       xmm2, xmm4
        vaddss  xmm0, xmm0, xmm2
        vcvtps2ph       xmm0, xmm0, 4
        vmovd   xmm2, ecx
        movzx   ecx, word ptr [rsp + 72]
        vcvtph2ps       xmm4, xmm7
        vcvtph2ps       xmm5, xmm5
        vmovd   xmm7, ecx
        movzx   ecx, word ptr [rsp + 328]
        vmulss  xmm4, xmm5, xmm4
        vcvtps2ph       xmm4, xmm4, 4
        vcvtph2ps       xmm4, xmm4
        vcvtph2ps       xmm5, xmm6
        vaddss  xmm4, xmm4, xmm5
        vcvtps2ph       xmm4, xmm4, 4
        vmovd   xmm5, ecx
        movzx   ecx, word ptr [rsp + 208]
        vcvtph2ps       xmm2, xmm2
        vcvtph2ps       xmm6, xmm7
        vmovd   xmm7, ecx
        movzx   ecx, word ptr [rsp + 80]
        vmulss  xmm2, xmm6, xmm2
        vcvtps2ph       xmm2, xmm2, 4
        vcvtph2ps       xmm2, xmm2
        vcvtph2ps       xmm5, xmm5
        vaddss  xmm2, xmm2, xmm5
        vcvtps2ph       xmm2, xmm2, 4
        vmovd   xmm5, ecx
        movzx   ecx, word ptr [rsp + 336]
        vcvtph2ps       xmm6, xmm7
        vcvtph2ps       xmm5, xmm5
        vmovd   xmm7, ecx
        movzx   ecx, word ptr [rsp + 216]
        vmulss  xmm5, xmm5, xmm6
        vcvtps2ph       xmm5, xmm5, 4
        vcvtph2ps       xmm5, xmm5
        vcvtph2ps       xmm6, xmm7
        vaddss  xmm5, xmm5, xmm6
        vcvtps2ph       xmm5, xmm5, 4
        vmovd   xmm6, ecx
        movzx   ecx, word ptr [rsp + 88]
        vcvtph2ps       xmm6, xmm6
        vmovd   xmm7, ecx
        vcvtph2ps       xmm7, xmm7
        movzx   ecx, word ptr [rsp + 344]
        vmulss  xmm6, xmm7, xmm6
        vcvtps2ph       xmm6, xmm6, 4
        vcvtph2ps       xmm6, xmm6
        vmovd   xmm7, ecx
        vcvtph2ps       xmm7, xmm7
        vaddss  xmm6, xmm6, xmm7
        vcvtps2ph       xmm6, xmm6, 4
        vpextrw word ptr [rax + 30], xmm6, 0
        vpextrw word ptr [rax + 28], xmm5, 0
        vpextrw word ptr [rax + 26], xmm2, 0
        vpextrw word ptr [rax + 24], xmm4, 0
        vpextrw word ptr [rax + 22], xmm0, 0
        vpextrw word ptr [rax + 20], xmm1, 0
        vpextrw word ptr [rax + 18], xmm3, 0
        vpextrw word ptr [rax + 16], xmm16, 0
        vpextrw word ptr [rax + 14], xmm14, 0
        vpextrw word ptr [rax + 12], xmm13, 0
        vpextrw word ptr [rax + 10], xmm15, 0
        vpextrw word ptr [rax + 8], xmm12, 0
        vpextrw word ptr [rax + 6], xmm11, 0
        vpextrw word ptr [rax + 4], xmm10, 0
        vpextrw word ptr [rax + 2], xmm9, 0
        vpextrw word ptr [rax], xmm8, 0
        ret
        nop     word ptr cs:[rax + rax]

LLVM seems very eager to fall back on scalarized code like this. For example, here is a masked load done by casting from Float16 to Int16 and back:

        .text
        mov     rax, rdi
        mov     rsi, qword ptr [rsi]
        kmovd   k1, ecx
        vmovdqu16       ymm0 {k1} {z}, ymmword ptr [rsi + 2*rdx - 2]
        vmovdqa ymmword ptr [rdi], ymm0
        vzeroupper
        ret
        nop     dword ptr [rax]

But without manually adding these casts:

        .text
        push    rbp
        push    r15
        push    r14
        push    r13
        push    r12
        push    rbx
        mov     rax, rdi
        mov     rdx, qword ptr [rdx]
        mov     rsi, qword ptr [rsi]
        lea     rbp, [rsi + 2*rdx]
        add     rbp, -2
        movzx   r11d, word ptr [rcx]
        test    r11b, 1
        je      L113
        movzx   ecx, word ptr [rbp]
        mov     dword ptr [rsp - 4], ecx
        xor     esi, esi
        mov     dword ptr [rsp - 8], 0
        mov     ecx, esi
        test    r11b, 2
        jne     L139
L63:
        mov     word ptr [rsp - 18], si
        mov     edi, esi
        mov     word ptr [rsp - 16], si
        mov     edx, esi
        mov     ebx, esi
        mov     r12d, esi
        mov     r8d, esi
        mov     r9d, esi
        mov     r15d, esi
        mov     r14d, esi
        mov     r13d, esi
        mov     r10d, esi
        mov     word ptr [rsp - 14], si
        test    r11b, 4
        je      L204
        jmp     L191
L113:
        mov     dword ptr [rsp - 4], 0
        xor     esi, esi
        mov     dword ptr [rsp - 8], 0
        mov     ecx, esi
        test    r11b, 2
        je      L63
L139:
        mov     word ptr [rsp - 18], si
        mov     edi, esi
        mov     word ptr [rsp - 16], si
        mov     edx, esi
        mov     ebx, esi
        mov     r12d, esi
        mov     r8d, esi
        mov     r9d, esi
        mov     r15d, esi
        mov     r14d, esi
        mov     r13d, esi
        mov     r10d, esi
        mov     word ptr [rsp - 14], si
        movzx   esi, word ptr [rbp + 2]
        test    r11b, 4
        je      L204
L191:
        mov     ecx, edi
        movzx   edi, word ptr [rbp + 4]
        mov     word ptr [rsp - 18], di
        mov     edi, ecx
L204:
        test    r11b, 8
        jne     L368
        test    r11b, 16
        jne     L382
L224:
        test    r11b, 32
        jne     L401
L234:
        test    r11b, 64
        jne     L415
L244:
        test    r11b, -128
        jne     L429
L254:
        test    r11d, 256
        jne     L447
L267:
        test    r11d, 512
        jne     L465
L280:
        test    r11d, 1024
        jne     L483
L293:
        test    r11d, 2048
        jne     L501
L306:
        test    r11d, 4096
        jne     L519
L319:
        test    r11d, 8192
        mov     word ptr [rsp - 10], r13w
        je      L543
L338:
        movzx   ecx, word ptr [rbp + 26]
        mov     word ptr [rsp - 12], cx
        mov     r10d, r14d
        test    r11d, 16384
        je      L570
        jmp     L561
L368:
        movzx   edi, word ptr [rbp + 6]
        test    r11b, 16
        je      L224
L382:
        movzx   ecx, word ptr [rbp + 8]
        mov     word ptr [rsp - 16], cx
        test    r11b, 32
        je      L234
L401:
        movzx   edx, word ptr [rbp + 10]
        test    r11b, 64
        je      L244
L415:
        movzx   ebx, word ptr [rbp + 12]
        test    r11b, -128
        je      L254
L429:
        movzx   r12d, word ptr [rbp + 14]
        test    r11d, 256
        je      L267
L447:
        movzx   r8d, word ptr [rbp + 16]
        test    r11d, 512
        je      L280
L465:
        movzx   r9d, word ptr [rbp + 18]
        test    r11d, 1024
        je      L293
L483:
        movzx   r15d, word ptr [rbp + 20]
        test    r11d, 2048
        je      L306
L501:
        movzx   r14d, word ptr [rbp + 22]
        test    r11d, 4096
        je      L319
L519:
        movzx   r13d, word ptr [rbp + 24]
        test    r11d, 8192
        mov     word ptr [rsp - 10], r13w
        jne     L338
L543:
        mov     word ptr [rsp - 12], r10w
        mov     r10d, r14d
        test    r11d, 16384
        je      L570
L561:
        movzx   ecx, word ptr [rbp + 28]
        mov     word ptr [rsp - 14], cx
L570:
        movzx   r14d, word ptr [rsp - 16]
        mov     r13d, edi
        test    r11d, 32768
        je      L599
        movzx   edi, word ptr [rsp - 18]
        movzx   ebp, word ptr [rbp + 30]
        jmp     L608
L599:
        movzx   edi, word ptr [rsp - 18]
        mov     ebp, dword ptr [rsp - 8]
L608:
        mov     ecx, dword ptr [rsp - 4]
        mov     word ptr [rax], cx
        mov     word ptr [rax + 2], si
        mov     word ptr [rax + 4], di
        mov     word ptr [rax + 6], r13w
        mov     word ptr [rax + 8], r14w
        mov     word ptr [rax + 10], dx
        mov     word ptr [rax + 12], bx
        mov     word ptr [rax + 14], r12w
        mov     word ptr [rax + 16], r8w
        mov     word ptr [rax + 18], r9w
        mov     word ptr [rax + 20], r15w
        mov     word ptr [rax + 22], r10w
        movzx   ecx, word ptr [rsp - 10]
        mov     word ptr [rax + 24], cx
        movzx   ecx, word ptr [rsp - 12]
        mov     word ptr [rax + 26], cx
        movzx   ecx, word ptr [rsp - 14]
        mov     word ptr [rax + 28], cx
        mov     word ptr [rax + 30], bp
        pop     rbx
        pop     r12
        pop     r13
        pop     r14
        pop     r15
        pop     rbp
        ret
        nop     word ptr cs:[rax + rax]

(yikes!)
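The fast version is a single zero-masking word load (`vmovdqu16 ... {k1} {z}`). Its semantics can be modeled in plain Julia; this sketch (`masked_load16` is a hypothetical name) only shows why the Int16 cast is safe, not how to get the speed: `reinterpret` changes the type without touching the bits, and lanes whose mask bit is clear are zeroed without being read.

```julia
# Scalar model of a zero-masking 16-lane load of Float16 data via Int16.
function masked_load16(src::Vector{Float16}, offset::Int, mask::UInt16)
    bits = reinterpret(Int16, src)   # same memory, viewed as Int16
    lanes = ntuple(Val(16)) do i
        # Masked-off lanes are set to zero and never read, so a partial final
        # vector does not touch memory past the end of `src`.
        isone((mask >> (i - 1)) & 0x0001) ? bits[offset + i] : Int16(0)
    end
    return map(b -> reinterpret(Float16, b), lanes)   # back to Float16, bit-exact
end

src = Float16.(1:20)
v = masked_load16(src, 0, 0x0005)   # loads lanes 1 and 3 only
```

On hardware with AVX512BW the whole `ntuple` corresponds to one instruction; the cast matters because LLVM handles masked `i16` vectors well but scalarizes masked `half` vectors, as the second listing shows.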

After a few more changes, this is what I get:

julia> @btime logsumexp4D8($A);
  2.257 ms (900 allocations: 24.46 MiB)

julia> @btime logsumexp4D8($B);
  753.780 μs (901 allocations: 12.25 MiB)

julia> @btime logsumexp4D8($C);
  357.733 μs (888 allocations: 6.15 MiB)

julia> versioninfo()
Julia Version 1.8.0-DEV.282
Commit ab699b697a* (2021-07-27 19:59 UTC)
Platform Info:
  OS: Linux (x86_64-generic-linux)
  CPU: Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, cascadelake)
Environment:
  JULIA_NUM_THREADS = 36

Now Float16 is about twice as fast as Float32, which is what we’d expect for memory-bound operations, since each Float16 element takes half as many bytes.
Code:

using Tullio, LoopVectorization, BenchmarkTools  # BenchmarkTools provides @btime
function logsumexp4D8(Arr4d::Array{<:Real, 4})
  @tullio (max) max_[i,j,k] := Arr4d[i,j,k,l]
  @tullio _[i,j,k] := exp(Arr4d[i,j,k,l] - max_[i,j,k]) |> log(_) + max_[i,j,k]
end

A = rand(20, 1000, 80, 5);
B = convert(Array{Float32}, A);
C = convert(Array{Float16}, A);

@btime logsumexp4D8($A);
@btime logsumexp4D8($B);
@btime logsumexp4D8($C);

Requires this Tullio PR and this VectorizationBase commit.

EDIT:
Should work using the latest versions of VectorizationBase and Tullio.
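A back-of-the-envelope check of the memory-bound explanation, using only the array shape and the 10980XE timings reported above. It assumes each of the two `@tullio` calls allocates one output-sized array, which matches the reported allocation sizes. Float32 and Float16 come out at a similar cost per input byte, which is consistent with a memory-bound kernel; Float64 is noticeably higher per byte, so treat this as a rough consistency check only.

```julia
dims  = (20, 1000, 80, 5)   # shape of A, B, C in the benchmark
nelem = prod(dims)          # input elements
nout  = prod(dims[1:3])     # elements in max_ and in the result

alloc_mib(T)   = 2 * nout * sizeof(T) / 2^20   # two output-sized arrays
input_bytes(T) = nelem * sizeof(T)             # one pass over the input

# Timings (seconds) as reported in the post.
timings = Dict(Float64 => 2.257e-3, Float32 => 753.780e-6, Float16 => 357.733e-6)

for T in (Float64, Float32, Float16)
    ns_per_byte = timings[T] / input_bytes(T) * 1e9
    println(rpad(string(T), 8), input_bytes(T) ÷ 10^6, " MB input, ",
            round(alloc_mib(T); digits = 2), " MiB allocated, ",
            round(ns_per_byte; digits = 4), " ns per input byte")
end
```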


Hmm, I think I updated correctly, but I’m not seeing the full benefit:

julia> @btime logsumexp4D8($A);
  11.535 ms (233 allocations: 24.42 MiB)

julia> @btime logsumexp4D8($B);
  5.911 ms (233 allocations: 12.22 MiB)

julia> @btime logsumexp4D8($C);
  53.950 ms (232 allocations: 6.11 MiB)

Also, for my own understanding: what does the <:Real in the array declaration do?
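In short, `Array{<:Real, 4}` is shorthand for the UnionAll type `Array{T, 4} where {T<:Real}`: any 4-dimensional `Array` whose element type is some subtype of `Real`. That lets one method replace the three copies above, and Julia still compiles a separate specialization for each concrete element type it is called with. A small sketch (`RealArray4` and `accepts` are illustrative names):

```julia
const RealArray4 = Array{<:Real, 4}
accepts(A) = A isa RealArray4

# Every element type from the benchmark matches...
float_ok = all(accepts, (rand(Float64, 2, 2, 2, 2),
                         rand(Float32, 2, 2, 2, 2),
                         rand(Float16, 2, 2, 2, 2)))

# ...but complex data does not, and neither does a 3-D array.
complex_ok = accepts(rand(ComplexF64, 2, 2, 2, 2))
three_d_ok = accepts(rand(2, 2, 2))

# Plain Array{Real, 4} would be too strict: type parameters are invariant,
# so Array{Float64, 4} is not a subtype of Array{Real, 4}.
invariant = Array{Float64, 4} <: Array{Real, 4}
```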

LoopVectorization might require AVX512BW to do well here for now. That instruction set provides efficient masks for bytes and words.
When using masks, I have it cast Float16 to Int16 (words); otherwise performance is really bad.

But that trick probably only works if you have AVX512BW. Without it, the smallest size you can mask is 32 bits (dwords / singles).
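To check whether a machine has AVX512BW, one option is to read the CPU feature flags from the OS. This sketch (`has_avx512bw` is a hypothetical helper) reads `/proc/cpuinfo` on Linux; on macOS it assumes Intel Macs list extended features under the `machdep.cpu.leaf7_features` sysctl key, and treats a missing key (e.g. on Apple silicon) as "no".

```julia
function has_avx512bw()
    if Sys.islinux()
        # Linux lists feature flags in lower case, e.g. "avx512bw".
        return occursin(r"\bavx512bw\b", read("/proc/cpuinfo", String))
    elseif Sys.isapple()
        # Assumption: Intel Macs expose these flags (upper case) via sysctl.
        flags = try
            read(`sysctl -n machdep.cpu.leaf7_features`, String)
        catch
            ""   # key missing or sysctl failed
        end
        return occursin(r"\bAVX512BW\b", flags)
    else
        return false   # other platforms are not handled in this sketch
    end
end

has_avx512bw()
```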

I was also surprised to find a massive impact from the CPU governor.
I ran the benchmarks on a 7980XE and got poor performance for Float16 (my earlier run was on a different computer with a 10980XE, which is more or less the same CPU):

julia> @btime logsumexp4D8($A); # Float64
  3.274 ms (897 allocations: 24.46 MiB)

julia> @btime logsumexp4D8($B); # Float32
  1.098 ms (897 allocations: 12.25 MiB)

julia> @btime logsumexp4D8($C); # Float16
  1.707 ms (888 allocations: 6.15 MiB)

I reran the benchmarks on the same machine I used earlier, the 10980XE, to confirm my earlier results:

julia> @btime logsumexp4D8($A); # Float64
  2.269 ms (893 allocations: 24.46 MiB)

julia> @btime logsumexp4D8($B); # Float32
  746.328 μs (900 allocations: 12.25 MiB)

julia> @btime logsumexp4D8($C); # Float16
  364.798 μs (888 allocations: 6.15 MiB)

In trying to pin down the cause for the difference, I noticed that the clock speed on the 7980XE was very low when running the Float16 benchmark, merely 1.2 GHz!
I checked the CPU scaling governor (the following commands are Linux only, but Windows has means of controlling this as well):

cat /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor

and noticed the 7980XE was set to schedutil. Setting it to performance instead:

echo "performance" | sudo tee /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor

I now get:

julia> @btime logsumexp4D8($A);
  2.399 ms (900 allocations: 24.46 MiB)

julia> @btime logsumexp4D8($B);
  845.472 μs (904 allocations: 12.25 MiB)

julia> @btime logsumexp4D8($C);
  406.594 μs (888 allocations: 6.15 MiB)

Results are more or less in line with what I got on the 10980XE.

You could check your clock speeds as you run the benchmark and change your CPU’s performance profile, but I’m guessing this is an issue of certain operations being slow without AVX512BW.
If you dig through the assembly, you might be able to find problematic functions and come up with workarounds/performance fixes (like I did by casting loads/stores of Float16 to Int16).
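Clock speeds and (on Linux) the scaling governor can also be inspected from inside Julia, e.g. immediately after a slow benchmark. A minimal sketch with hypothetical helper names: `Sys.cpu_info()` reports a speed in MHz per logical CPU (on some platforms this may be a nominal rather than current frequency), and the governor is read from the same sysfs files as the `cat` command above.

```julia
# Reported clock speed (MHz) of each logical CPU.
cpu_mhz() = [Int(c.speed) for c in Sys.cpu_info()]

# Linux only: "cpuN" => governor (e.g. "schedutil", "performance").
# Returns an empty Dict where the sysfs files do not exist.
function scaling_governors(root::AbstractString = "/sys/devices/system/cpu")
    govs = Dict{String,String}()
    isdir(root) || return govs
    for d in readdir(root)
        occursin(r"^cpu\d+$", d) || continue
        f = joinpath(root, d, "cpufreq", "scaling_governor")
        isfile(f) && (govs[d] = strip(read(f, String)))
    end
    return govs
end

println("MHz range: ", extrema(cpu_mhz()))
println("governors: ", unique(values(scaling_governors())))
```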


I don’t think my cpu supports this.

I’m on a Mac, so I think it’s not so easy.

You overestimate my skills :sweat_smile:

Those points aside, I updated LoopVectorization v0.12.60 ⇒ v0.12.61 and things improved:

julia> @btime logsumexp4D8($A);
  7.719 ms (233 allocations: 24.42 MiB)

julia> @btime logsumexp4D8($B);
  3.657 ms (233 allocations: 12.22 MiB)

julia> @btime logsumexp4D8($C);
  19.756 ms (232 allocations: 6.11 MiB)

@Elrod I’ve now updated Tullio v0.2.14 ⇒ v0.3.2 and these are the results:

julia> @btime logsumexp4D8($A);
  7.627 ms (233 allocations: 24.42 MiB)

julia> @btime logsumexp4D8($B);
  3.578 ms (233 allocations: 12.22 MiB)

julia> @btime logsumexp4D8($C);
  2.350 ms (232 allocations: 6.11 MiB)

Awesome!


Adding this comment in case it’s helpful to anyone. I was able to get significant speedups by using LoopVectorization. The code is specialized to the array shape and dimensions, but might still be useful. See https://github.com/magerton/FastLogSumExp.jl/ and also https://github.com/JuliaSIMD/LoopVectorization.jl/issues/437
