Here are a few faster variants, using a package of mine and one of @Elrod’s. If you want this to work with Zygote, then the time spent working out the gradient will tend to dominate.
"""
    lsexp_mat1(mat; dims=1)

Compute the log-sum-exp of `mat` along `dims`, i.e. `log.(sum(exp.(mat), dims=dims))`,
using the max-shift trick.

The slice maximum is subtracted before exponentiating. The `exp(0) == 1` term produced by
the maximum entry is removed from the sum (via the mask returned by `safeeq`) and added
back by `log1p`.

# Keywords
- `dims`: dimension(s) to reduce over; used for both the maximum and the sum.

# Returns
An array with the same shape as `maximum(mat, dims=dims)`.

# Notes
The mask removes one unit for *every* entry equal to the slice maximum, while `log1p`
adds back only one, so slices whose maximum is attained more than once are
under-counted.
"""
function lsexp_mat1(mat; dims=1)
    # Reduce over the requested dimension; this was hard-coded to 1, which made any
    # other `dims` return a wrongly shaped (and wrong-valued) result.
    max_ = maximum(mat, dims=dims)
    zero1_mat = safeeq(mat, max_) # working around a Zygote bug, today?
    exp_mat = exp.(mat .- max_)
    # Drop the exp(0) == 1 contribution of the maximum; log1p restores it below.
    sum_exp_ = sum(exp_mat .- zero1_mat, dims=dims)
    return log1p.(sum_exp_) .+ max_
end
"""
    safeeq(mat, max_)

Return the elementwise (broadcast) equality mask of `mat` and `max_`.

Kept as a separate function so that it can be excluded from differentiation
(the file marks it with `Zygote.@nograd`).
"""
function safeeq(mat, max_)
    return broadcast(==, mat, max_)
end
"""
    lsexp_mat2(mat; dims=1)

Compute the log-sum-exp of `mat` along `dims`, i.e. `log.(sum(exp.(mat), dims=dims))`,
with the shift-and-mask step fused into a single broadcast and the reduction buffer
reused for the result.

# Keywords
- `dims`: dimension(s) to reduce over; used for both the maximum and the sum.

# Returns
The array allocated by the `sum` reduction, overwritten in place with the result. It
has the same shape as `maximum(mat, dims=dims)`.

# Notes
One unit is subtracted for *every* entry equal to the slice maximum, while `log1p`
adds back only one, so slices whose maximum is attained more than once are
under-counted.
"""
function lsexp_mat2(mat; dims=1) # less memory but not really faster?
    # Reduce over the requested dimension; this was hard-coded to 1, which made
    # `dims=2` fail with a DimensionMismatch in the in-place assignment below.
    max_ = maximum(mat, dims=dims)
    # Fused broadcast: exp of the shifted values minus the exp(0) == 1 of the maximum.
    exp_mat = exp.(mat .- max_) .- (mat .== max_) # fuse this broadcast, @Oscar_Smith beat me to it!
    sum_exp_ = sum(exp_mat, dims=dims)
    sum_exp_ .= log1p.(sum_exp_) .+ max_ # re-use this array?
    return sum_exp_
end
using Tullio # ] add Tullio#master -- I just fixed a bug about == & gradients
"""
    lsexp_mat3(mat)

Column-wise log-sum-exp of the matrix `mat` (reduction fixed to the first dimension),
with both elementwise steps written as `@tullio` expressions and `avx=false` passed
explicitly.

Each column's maximum is subtracted before exponentiating, the `exp(0) == 1` term of
the maximum entry is removed, and `log1p` adds it back after the column sum.

Returns a one-row matrix with one entry per column of `mat`.
"""
function lsexp_mat3(mat) # not generic over dims, but differentiable
    colmax = maximum(mat, dims=1)
    # Shifted exponentials with the maximum's own unit contribution masked out.
    @tullio shifted[r,c] := exp(mat[r,c] - colmax[1,c]) - (mat[r,c] == colmax[1,c]) avx=false # grad=Dual # fixed on master
    colsum = sum(shifted, dims=1)
    # `r` only ranges over the single row of `colsum` / `colmax` here.
    @tullio result[r,c] := log1p(colsum[r,c]) + colmax[r,c] avx=false
    return result
end
using LoopVectorization
# function lsexp_mat4(mat; dims=1) # @avx broadcasting, is having a bad day
# max_ = maximum(mat, dims=1)
# # zero1_mat = (mat .== max_)
# exp_mat = @avx exp.(mat .- max_) .- (mat .== max_) # has lots of NaN & Inf in it?
# sum_exp_ = sum(exp_mat, dims=dims)
# @avx sum_exp_ .= log1p.(sum_exp_) .+ max_ # mostly NaN
# end
"""
    lsexp_mat5(mat)

Column-wise log-sum-exp of the matrix `mat` (reduction fixed to the first dimension),
written with `@tullio` and leaving its `avx` option at the default.

Each column's maximum is subtracted before exponentiating, the `exp(0) == 1` term of
the maximum entry is removed, and `log1p` adds it back after the column sum.

Returns a one-row matrix with one entry per column of `mat`.
"""
function lsexp_mat5(mat) # also using @avx
    # NOTE(review): presumably Tullio vectorises via LoopVectorization because that
    # package is loaded before this definition — confirm against Tullio's docs.
    top = maximum(mat, dims=1)
    @tullio expm[a,b] := exp(mat[a,b] - top[1,b]) - (mat[a,b] == top[1,b])
    total = sum(expm, dims=1)
    # `a` only ranges over the single row of `total` / `top` here.
    @tullio lse[a,b] := log1p(total[a,b]) + top[a,b]
    return lse
end
# Consistency checks and timings on a random 1000×1000 matrix.
# `lsexp_mat` is the baseline implementation being compared against; it is not
# defined in this excerpt. `@btime` is presumably BenchmarkTools' macro, which
# must be loaded separately (no `using BenchmarkTools` is visible here).
n = 1_000; A = rand(n,n);
# Chained `≈`: true only if each adjacent pair agrees approximately.
lsexp_mat(A) ≈ lsexp_mat1(A) ≈ lsexp_mat2(A)
# lsexp_mat(A) ≈ lsexp_mat4(A) # false?
lsexp_mat(A) ≈ lsexp_mat3(A) ≈ lsexp_mat5(A)
# Timings recorded by the original author; `$A` interpolates the global into the
# benchmark expression.
@btime lsexp_mat($A) # 10.164 ms (13 allocations: 15.41 MiB)
@btime lsexp_mat2($A) # 10.631 ms (6 allocations: 7.65 MiB)
@btime lsexp_mat3($A) # 3.188 ms (78 allocations: 7.66 MiB)
# @btime lsexp_mat4($A) # 3.494 ms (14 allocations: 7.65 MiB)
@btime lsexp_mat5($A) # 2.069 ms (76 allocations: 7.66 MiB)
using Tracker, Zygote #, ForwardDiff
# Gradient checks and timings.
# Exclude `safeeq` from Zygote's differentiation, so the equality mask used by
# `lsexp_mat1` is not differentiated through.
Zygote.@nograd safeeq
# Reference gradient: Tracker applied to the baseline `lsexp_mat` (defined outside
# this excerpt).
gA = Tracker.gradient(sum∘lsexp_mat, A)[1];
# Each Zygote gradient is compared with the Tracker reference.
# NOTE(review): `lsexp_mat2` is not checked — presumably because its in-place `.=`
# is not supported by Zygote; confirm.
Zygote.gradient(sum∘lsexp_mat1, A)[1] ≈ gA
Zygote.gradient(sum∘lsexp_mat3, A)[1] ≈ gA
Zygote.gradient(sum∘lsexp_mat5, A)[1] ≈ gA
# Gradient timings recorded by the original author.
@btime Zygote.gradient(sum∘lsexp_mat1, $A); # 69.199 ms (3003130 allocations: 137.61 MiB)
@btime Zygote.gradient(sum∘lsexp_mat3, $A); # 12.547 ms (253 allocations: 38.23 MiB)
@btime Zygote.gradient(sum∘lsexp_mat5, $A); # 10.309 ms (248 allocations: 38.23 MiB)