Reduce(logaddexp,x) has arbitrarily better accuracy than logsumexp(x)

It occurred to me this morning that logaddexp is associative and commutative, so logsumexp can be calculated with a simple reduce. We could write

logsumexp(x) = reduce(logaddexp,x)

Isn’t that pretty!

So how does that compare to the really fancy implementation in LogExpFunctions.jl? The reduction is about 5x slower, but still surprisingly fast; it amounts to a different one-pass algorithm, the streaming one described by Sebastian Nowozin: Streaming Log-sum-exp Computation
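
For concreteness, here is a rough sketch (my own, ignoring empty inputs and -Inf elements) of the streaming idea, a running maximum plus a rescaled running sum, which is what a sequential fold of logaddexp amounts to (Julia's reduce may reassociate, as discussed further down the thread):

function logsumexp_streaming(x)
    m, s = -Inf, 0.0                    # running max and rescaled running sum
    for xi in x
        if xi <= m
            s += exp(xi - m)            # accumulate relative to the current max
        else
            s = s * exp(m - xi) + 1.0   # re-center the sum on the new max
            m = xi
        end
    end
    return m + log(s)
end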

using LogExpFunctions
using BenchmarkTools

x = rand(10^7)

@btime a1 = logsumexp(x)
  70.892 ms (1 allocation: 16 bytes)
16.659565226722435

@btime a2 = reduce(logaddexp, x)
 310.882 ms (1 allocation: 16 bytes)
16.659565226722393

But also interesting: a bit of numerical difference, at the 10^(-14) level! Which is more accurate? I looked at the same algorithms in Float32:

y = Float32.(x)
b1 = logsumexp(y)
b2 = reduce(logaddexp, y)
b1-Float32(a1)
0.013198853f0
b2-Float32(a1)
0.0f0

The reduction is much more numerically stable! For larger random vectors, logsumexp gets arbitrarily bad, albeit only at somewhat enormous vector sizes.

In a sense, the rand case is particularly bad because the maximum in this example has a lot of “multiplicity”: many values in x near the maximum contribute a non-trivial amount to the final result. So subtracting off the maximum is not enough; it’s actually better to subtract off the larger of the running result or the next number. That’s what the reduction does.
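
Concretely (a sketch of mine, not the LogExpFunctions implementation, and ignoring edge cases like -Inf), the binary step of the reduction is just

# Each application re-centers on the larger of the running result and the next element.
logaddexp_sketch(a, b) = max(a, b) + log1p(exp(-abs(a - b)))

so no single global maximum ever needs to be subtracted.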

In fact, if you know a little physics, the log-multiplicity is exactly the entropy in the stat-mech sense, S = -Σ_i p_i log p_i, i.e. the log of the effective number of states contributing to the free energy (logsumexp):

logp = x .- logsumexp(x)           # log of the softmax probabilities p_i
logmplogp = logp .+ log.(-logp)    # log(p_i * (-log p_i))
logS = logsumexp(logmplogp)        # log of the entropy S = -Σ_i p_i log p_i
S = exp(logS)
16.0774514705695

S is quite near logsumexp(x), so logsumexp(x) here is clearly “entropy driven” rather than “softmax driven”.

Anyway, there are clearly some tradeoffs due to the vectorized nature of CPUs, so I’m not sure if this is an issue. I wouldn’t necessarily want to force a 5x slowdown on the most expensive calculation in many algorithms for accuracy that may not be important. Is it possible that the reduction version can be made faster with whatever magic is being used in logsumexp? Should I open an issue?

Where is logaddexp defined?

Update - it’s here in LogExpFunctions.jl.

julia> using BenchmarkTools, LogExpFunctions

julia> x = rand(10^7)

julia> @btime a1 = logsumexp($x)
  36.006 ms (0 allocations: 0 bytes)
16.659561175877272

julia> @btime a2 = log(sum(exp,$x))
  28.252 ms (0 allocations: 0 bytes)
16.659561175877275

julia> f(x) = reduce(logaddexp,x)
f (generic function with 1 method)

julia> @btime a3 = f($x)
  240.178 ms (0 allocations: 0 bytes)
16.659561175877275

I think this simple example must not illustrate the benefits of logsumexp, since log(sum(exp, x)) is faster. You likely need different inputs, ones that would induce overflow/underflow without care.

x = rand(10^7) works with log(sum(exp.(x))) because x is confined to the nearly-linear part of exp(x), so there’s nothing to worry about.

But, say x = 1000 .+ rand(10^7). Then log(sum(exp.(x))) doesn’t work, but both reduce(logaddexp, x) and logsumexp(x) do. In the space of algorithms that DO work, reduce(logaddexp,x) seems to be more accurate for large vectors.
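
A quick way to see that (illustrative, with a smaller vector for speed):

x = 1000 .+ rand(10^5)
log(sum(exp, x))         # Inf, because exp(1000) overflows Float64
logsumexp(x)             # finite
reduce(logaddexp, x)     # finite as well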

But is this the right way to check accuracy? Converting Float64 to Float32 changes the values and should lead to some difference. Isn’t it better to generate the random values as Float32 initially and then compare the reductions in both precisions?

Sure, you’re right, it’s more intuitive to go the other way: promote, then check, and compare to eps:

y = rand(Float32,10^7)
x = Float64.(y)
a1 = logsumexp(x)
a2 = reduce(logaddexp, x)
b1 = logsumexp(y)
b2 = reduce(logaddexp, y)
(b1-Float32(a1))/eps(Float32)
111040.0f0
(b2-Float32(a1))/eps(Float32)
16.0f0

But it amounts to the same thing, since the precision of Float64 is so much higher.

It seemed intuitive to me to check accuracy by checking whether the Float32 conversion “commutes” with the operation, but I don’t know what the standard method is, as I’m not a numerics person.

To measure accuracy, the results can be compared against:

c1 = logsumexp(BigFloat.(x))

(slow for sure, but a good baseline for accuracy).

Another, faster, sweet spot on the performance/accuracy Pareto frontier is:

c2 = log(sum(exp, x; init=big"0.0"))

Another attempt at increasing accuracy; it helps slightly with this x, but that depends on the distribution of x values:

c3 = log(sum(expm1, x; init=big"0.0")+length(x))

This is susceptible to spurious overflow if the x elements can exceed 710.
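
For reporting, one could compare each candidate against c1 in units of eps at the reference value, e.g. with a quick helper (not from any package; a1 and a2 are from the posts above):

err_ulps(v, ref) = Float64((big(v) - ref) / eps(Float64(ref)))
err_ulps(a1, c1)    # logsumexp(x)
err_ulps(a2, c1)    # reduce(logaddexp, x)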

x = rand(10^9)
a1 = logsumexp(x)
a2 = reduce(logaddexp, x)
(a1-a2)/eps(a1)
  306.0
# is only around 25 for 10^8 and 16 for 10^7

I have to say the error at the Float64 level is a lot better behaved; presumably the numerics here are super well tested. I do think the reduce version really is more stable.

I wouldn’t expect better guarantees at this accuracy. But the situation really is weird at the Float16/Float32 level. I guess it’s not the first time a FloatNot64 behaves weirdly?

I think that’s because this uses pairwise summation, whereas LogExpFunctions.jl uses a “one-pass” algorithm with an init argument, which ends up calling foldl (i.e. naive summation).
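
Illustrative of that difference (not from the thread), comparing Base’s pairwise sum for plain + with a naive left fold, against a Float64 reference:

z = rand(Float32, 10^7)
ref = sum(Float64.(z))       # higher-precision reference
abs(sum(z) - ref)            # Base.sum uses pairwise summation for arrays
abs(foldl(+, z) - ref)       # naive sequential accumulation, typically much larger error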

Although it should be possible to re-implement the one-pass algorithm in a pairwise fashion, it seems a lot easier just to use the two-pass algorithm, which only seems to be about 20% slower than the one-pass algorithm and benefits from pairwise summation:

function logsumexp_2pass(x::AbstractArray)
    isempty(x) && return log(sum(exp, x)) # -Inf
    max = maximum(x)
    return max + log(sum(x -> exp(x - max), x))
end
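
For example, a quick (illustrative) check against LogExpFunctions on inputs that would overflow a naive sum of exponentials:

x = 1000 .+ rand(10^6)
logsumexp_2pass(x) ≈ logsumexp(x)   # both subtract the maximum, so this should hold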

The fact that Julia’s mapreduce uses pairwise associativity in certain cases is a two-edged sword — it makes it more accurate without losing performance, but then people are continually surprised when other cases are less accurate.

I see. I think you’re right, and the entropy explanation I gave is bogus; this is probably pure pairwise enhancement.

I just wouldn’t restrict it to AbstractArray, since I often give it generators. I also worry about summing, say, [0, 10^-16]: you might run into places where log1p is better, and logaddexp does that near the limit of what’s possible with Float64s.
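
A small illustration of the log1p point: a tiny contribution that the naive form rounds away entirely is kept by logaddexp:

logaddexp(0.0, -37.0)        # ≈ 8.5e-17, the correct tiny increment
log(exp(0.0) + exp(-37.0))   # 0.0, since 1.0 + 8.5e-17 rounds to 1.0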

I just have a few more questions:

  • Should I open an issue with LogExpFunctions?
  • Any idea what is slowing down the reduce(logaddexp, x) here?
  • Does it make sense that the errors are so much worse on the one-pass for Float16 and Float32?

Couldn’t hurt.

You notice roundoff accumulation much more quickly with lower precisions.

Doing pairwise summation is a bit trickier when you don’t have random access: use pairwise order for mapreduce on arbitrary iterators by stevengj · Pull Request #52397 · JuliaLang/julia · GitHub

You could also use Kahan/compensated summation, maybe in blocks to improve performance.
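
A minimal sketch of plain (unblocked) compensated summation, for reference; it could stand in for the inner sum of the two-pass version above:

function kahan_sum(f, x)
    s = 0.0                    # running sum
    c = 0.0                    # compensation for lost low-order bits
    for xi in x
        y = f(xi) - c
        t = s + y
        c = (t - s) - y        # recover the rounding error of s + y
        s = t
    end
    return s
end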
