Speeding up my logsumexp function

Here are a few faster variants, using a package of mine (Tullio) and one of @Elrod’s (LoopVectorization). If you want this to work with Zygote, then the time spent computing the gradient will tend to dominate the run time.

"""
    lsexp_mat1(mat; dims=1)

Return `log.(sum(exp.(mat), dims=dims))`, computed stably by shifting by the
maximum along `dims` and applying `log1p` to the sum of the remaining terms.

The result keeps the reduced dimension(s) with size 1, like `sum(mat, dims=dims)`.
The maximum is taken over the same `dims` as the sum. Tied maxima are handled:
one `1` is subtracted per tie and the surplus `(ties - 1)` is added back inside
`log1p`, which is exactly zero when the maximum is unique.
"""
function lsexp_mat1(mat; dims=1)
    max_ = maximum(mat, dims=dims)
    # The comparison lives in its own function so that it can be excluded from
    # differentiation (this file marks it with `Zygote.@nograd safeeq`).
    zero1_mat = safeeq(mat, max_)
    exp_mat = exp.(mat .- max_)
    # exp(0) - true == 0 exactly, so every maximal entry drops out of the sum.
    sum_exp_ = sum(exp_mat .- zero1_mat, dims=dims)
    # Number of maximal entries beyond the first along `dims`; 0 for a unique maximum.
    extra_max = sum(zero1_mat, dims=dims) .- 1
    return log1p.(sum_exp_ .+ extra_max) .+ max_
end

"""
    safeeq(mat, max_)

Return the broadcast comparison `mat .== max_`, i.e. a mask that is `true`
wherever an entry of `mat` equals the corresponding entry of `max_`.
"""
safeeq(mat, max_) = (mat .== max_)

"""
    lsexp_mat2(mat; dims=1)

Return `log.(sum(exp.(mat), dims=dims))`, computed stably by shifting by the
maximum along `dims` and applying `log1p` to the sum of the remaining terms.
The shift and the subtraction of the maximal term are fused into one broadcast,
and the reduced array is re-used for the output.

The result keeps the reduced dimension(s) with size 1. The maximum is taken over
the same `dims` as the sum. Tied maxima are handled by adding `(ties - 1)` back
inside `log1p`, which is exactly zero when the maximum is unique.
"""
function lsexp_mat2(mat; dims=1) # less memory but not really faster?
    max_ = maximum(mat, dims=dims)
    is_max = mat .== max_ # mask of maximal entries, re-used for the tie count below
    exp_mat = exp.(mat .- max_) .- is_max # fused: exp(0) - true == 0 for maximal entries
    sum_exp_ = sum(exp_mat, dims=dims)
    # Number of maximal entries beyond the first along `dims`; 0 for a unique maximum.
    extra_max = sum(is_max, dims=dims) .- 1
    sum_exp_ .= log1p.(sum_exp_ .+ extra_max) .+ max_ # re-use the reduced array as output
    return sum_exp_
end

using Tullio # ] add Tullio#master -- I just fixed a bug about == & gradients

"""
    lsexp_mat3(mat)

Column-wise log-sum-exp of the matrix `mat` (reduction over the first dimension
only), written with `@tullio` and `avx=false`. Not generic over `dims`, but
differentiable.

Every entry equal to its column maximum has `1` subtracted before the sum, so the
`log1p` form assumes the maximum of each column is attained exactly once.
"""
function lsexp_mat3(mat)
    colmax = maximum(mat, dims=1)
    # Shift by the column maximum and drop the maximal term (exp(0) - true == 0).
    # The original note says `grad=Dual` is no longer needed with the fix on Tullio master.
    @tullio shifted[r,c] := exp(mat[r,c] - colmax[1,c]) - (mat[r,c] == colmax[1,c]) avx=false
    colsum = sum(shifted, dims=1)
    @tullio result[r,c] := log1p(colsum[r,c]) + colmax[r,c] avx=false
    return result
end

using LoopVectorization

# function lsexp_mat4(mat; dims=1) # @avx broadcasting, is having a bad day
#     max_ = maximum(mat, dims=1)
#     # zero1_mat = (mat .== max_)
#     exp_mat = @avx exp.(mat .- max_) .- (mat .== max_) # has lots of NaN & Inf in it?
#     sum_exp_ = sum(exp_mat, dims=dims)
#     @avx sum_exp_ .= log1p.(sum_exp_) .+ max_ # mostly NaN
# end

"""
    lsexp_mat5(mat)

Column-wise log-sum-exp of the matrix `mat` (reduction over the first dimension
only), written with `@tullio` using its default options. Per the original note,
this variant also uses `@avx` (LoopVectorization is loaded just above).

Every entry equal to its column maximum has `1` subtracted before the sum, so the
`log1p` form assumes the maximum of each column is attained exactly once.
"""
function lsexp_mat5(mat)
    top = maximum(mat, dims=1)
    # Shift by the column maximum and drop the maximal term (exp(0) - true == 0).
    @tullio rest[a,b] := exp(mat[a,b] - top[1,b]) - (mat[a,b] == top[1,b])
    total = sum(rest, dims=1)
    @tullio lse[a,b] := log1p(total[a,b]) + top[a,b]
    return lse
end

# Test data: an n×n matrix of random numbers from `rand`.
n = 1_000; A = rand(n,n);
# Sanity checks against the reference `lsexp_mat`, which is not defined in this
# snippet (presumably the original function being sped up -- TODO confirm).
# A chained `a ≈ b ≈ c` checks both neighbouring pairs.
lsexp_mat(A) ≈ lsexp_mat1(A) ≈ lsexp_mat2(A)
# lsexp_mat(A) ≈ lsexp_mat4(A) # false? 
lsexp_mat(A) ≈ lsexp_mat3(A) ≈ lsexp_mat5(A)

# Forward-pass timings; the trailing comments are the recorded results.
# NOTE(review): `@btime` is not brought into scope in this snippet --
# presumably `using BenchmarkTools` happens earlier; confirm.
@btime lsexp_mat($A)  # 10.164 ms (13 allocations: 15.41 MiB)
@btime lsexp_mat2($A) # 10.631 ms (6 allocations: 7.65 MiB)
@btime lsexp_mat3($A) #  3.188 ms (78 allocations: 7.66 MiB)
# @btime lsexp_mat4($A) #  3.494 ms (14 allocations: 7.65 MiB)
@btime lsexp_mat5($A) #  2.069 ms (76 allocations: 7.66 MiB)

using Tracker, Zygote #, ForwardDiff
# Mark `safeeq` (the comparison mask used by lsexp_mat1) as having no gradient for Zygote.
Zygote.@nograd safeeq

# Reference gradient: Tracker applied to the sum of the reference `lsexp_mat`
# (`lsexp_mat` itself is not defined in this snippet).
gA = Tracker.gradient(sum∘lsexp_mat, A)[1];
# Zygote gradients of the faster variants should agree with the reference.
Zygote.gradient(sum∘lsexp_mat1, A)[1] ≈ gA
Zygote.gradient(sum∘lsexp_mat3, A)[1] ≈ gA
Zygote.gradient(sum∘lsexp_mat5, A)[1] ≈ gA

# Gradient timings; the trailing comments are the recorded results.
@btime Zygote.gradient(sum∘lsexp_mat1, $A); # 69.199 ms (3003130 allocations: 137.61 MiB)
@btime Zygote.gradient(sum∘lsexp_mat3, $A); # 12.547 ms (253 allocations: 38.23 MiB)
@btime Zygote.gradient(sum∘lsexp_mat5, $A); # 10.309 ms (248 allocations: 38.23 MiB)
3 Likes