It occurred to me this morning that logsumexp
is associative and commutative and can be calculated with a simple reduce. We could write
logsumexp(x) = reduce(logaddexp,x)
Isn’t that pretty!
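For anyone who hasn't met it, logaddexp(a, b) = log(exp(a) + exp(b)), computed stably by factoring out the larger argument. A minimal sketch with my own definitions (not the LogExpFunctions.jl implementations):

```julia
# Stable pairwise log-add-exp: log(exp(a) + exp(b)).
# Factoring out max(a, b) keeps the argument of exp non-positive.
my_logaddexp(a, b) = max(a, b) + log1p(exp(-abs(a - b)))

# logsumexp as a plain reduction over the pairwise operation.
my_logsumexp(x) = reduce(my_logaddexp, x)

x = randn(1000)
naive = log(sum(exp.(x)))        # fine here, since the values are small
println(my_logsumexp(x) ≈ naive) # true
```

Because logaddexp is associative and commutative (up to rounding), the reduction can also be reordered or parallelized freely.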
So how does that compare to the really fancy implementation in LogExpFunctions.jl? The reduction is about 5x slower, but still surprisingly fast; it is equivalent to a different one-pass algorithm, the streaming one described by Sebastian Nowozin: Streaming Log-sum-exp Computation
using LogExpFunctions
using BenchmarkTools
x = rand(10^7)
@btime a1 = logsumexp(x)
70.892 ms (1 allocation: 16 bytes)
16.659565226722435
@btime a2 = reduce(logaddexp, x)
310.882 ms (1 allocation: 16 bytes)
16.659565226722393
But also, interesting: a bit of numerical difference, at the 10^(-14) level! Which is more accurate? I looked at the same algorithms at Float32:
y = Float32.(x)
b1 = logsumexp(y)
b2 = reduce(logaddexp, y)
b1-Float32(a1)
0.013198853f0
b2-Float32(a1)
0.0f0
The reduction is much more numerically stable! For larger vectors of rand values the logsumexp gets arbitrarily bad, albeit only at somewhat enormous vector lengths.
In a sense, the rand case is particularly bad because the maximum in this example has a lot of “multiplicity”: many values in x near the maximum contribute a non-trivial amount to the final result. So subtracting off the maximum is not enough; it’s actually better to subtract off the larger of the running sum or the next number, which is what the reduction does.
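That running-rescale idea is the core of the streaming one-pass algorithm. A sketch of it (the function and variable names are mine, not Nowozin's or LogExpFunctions.jl's):

```julia
# One-pass streaming logsumexp: keep a running maximum m and a running
# sum s of exp(xi - m); rescale s whenever a new maximum appears.
function streaming_logsumexp(x)
    m = -Inf   # running maximum
    s = 0.0    # running sum of exp(xi - m)
    for xi in x
        if xi <= m
            s += exp(xi - m)
        else
            s = s * exp(m - xi) + 1.0   # rescale old sum to the new maximum
            m = xi
        end
    end
    return m + log(s)
end

streaming_logsumexp([1.0, 2.0, 3.0])  # ≈ log(exp(1) + exp(2) + exp(3))
```

Every exp here sees a non-positive argument, so nothing overflows, and the subtraction always uses the running maximum seen so far rather than a single global maximum.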
In fact, if you know a little physics, the log-multiplicity is exactly the entropy in the stat-mech sense (S = -Σ_i p_i log p_i), i.e. the log of the effective number of states contributing to the free energy (the logsumexp):
logp = x .- logsumexp(x)
logmplogp = logp .+ log.(-logp)
logS = logsumexp(logmplogp)
S = exp(logS)
16.0774514705695
quite near the logsumexp(x), so logsumexp(x) here is clearly “entropy driven” rather than “softmax driven”
Anyway, there are clearly some tradeoffs due to the vectorial nature of CPUs, so I’m not sure if this is an issue. I wouldn’t necessarily want to force a 5x slowdown on the most expensive calculation in many algorithms for accuracy that may not be important. Is it possible that the reduction version can be made faster with whatever magic is being used on logsumexp? Should I open an issue?