Fast fixed point log

I’m trying to compute Float32 log(x) where x = (u + 0.5) / 2³² and u is an integer from 0 to 2³² - 1.
I started with:

# Float32 log((u + 0.5) / 2^32) for u::UInt32, computed with Base `log` / `log1p`.
#
# For the upper half of the input range (x ≥ 1/2) the code evaluates
# log1p(-(1 - x)) with 1 - x = (~u + 0.5) / 2^32 computed directly from ~u,
# which avoids the cancellation of forming x - 1 when x is close to 1.
# Because x and 1 - x are both at least 2^-33, the result is never 0 or -Inf.
#
# Note: Base `log`/`log1p` can throw a DomainError for out-of-domain arguments,
# so the compiler cannot prove this function nothrow (`!n` in its effects) even
# though x is always positive here.
@inline function fastlog1(u::UInt32)::Float32
    # Every fma operand must be Float32. An earlier draft used a Float64 offset in
    # the else branch, which made `fma` promote that whole branch to Float64 and
    # run slower.
    step = Float32(2)^(-32)       # width of one of the 2^32 cells of [0, 1]
    half_step = Float32(2)^(-33)  # offset to the cell midpoint
    if u < UInt32(2)^31
        x = fma(Float32(u), step, half_step)
        return log(x)
    else
        one_minus_x = fma(Float32(~u), step, half_step)
        return log1p(-one_minus_x)
    end
end

This has the nice property of never returning 0 or -Inf, because x is always at least 2⁻³³ away from both 0 and 1.

Sadly, this has bad effects (the `!n` below means the compiler cannot prove it never throws) and it doesn’t seem to SIMD:

julia> Base.infer_effects(fastlog1, (UInt32,))
(+c,+e,!n,+t,+s,+m,+u,+o,+r)

Claude was able to generate the following code. Does this make sense, and are there any obvious ways to make it even faster?

# Int32 bit pattern of Float32(sqrt(0.5)). fastlog2 subtracts/adds it in integer
# space to split x into an exponent and a mantissa in [sqrt(1/2), sqrt(2)).
const _SQRT_HALF_I32 = reinterpret(Int32, Float32(√0.5))

# Branch-free Float32 log((u + 0.5) / 2^32) for u::UInt32, following fdlibm's logf.
#
# Steps:
# 1. Reduce x = 2^k * (1 + f) with sqrt(1/2) ≤ 1 + f < sqrt(2), using integer
#    arithmetic on the bits of x.
# 2. Approximate log(1 + f) with a polynomial in z = s^2, where s = f / (2 + f).
# 3. Add k*log(2), with log(2) split into a high and a low part.
#
# Reads the global `_SQRT_HALF_I32` (bit pattern of Float32(sqrt(0.5))).
@inline function fastlog2(u::UInt32)::Float32
    # fdlibm e_logf.c constants, written as IEEE-754 bit patterns. Lg1..Lg4 are
    # minimax-adjusted versions of the Taylor coefficients 2/3, 2/5, 2/7, 2/9.
    lg1 = reinterpret(Float32, Int32(0x3f2aaaaa))     # ≈ 0.6666666
    lg2 = reinterpret(Float32, Int32(0x3eccce13))     # ≈ 0.40000972
    lg3 = reinterpret(Float32, Int32(0x3e91e9ee))     # ≈ 0.28498787
    lg4 = reinterpret(Float32, Int32(0x3e789e26))     # ≈ 0.24279079
    ln2_hi = reinterpret(Float32, Int32(0x3f317180))  # ≈ 0.6931381
    ln2_lo = reinterpret(Float32, Int32(0x3717f7d1))  # ≈ 9.058001f-6

    step = Float32(2)^(-32)
    half_step = Float32(2)^(-33)
    x = fma(Float32(u), step, half_step)

    # Integer view of x relative to sqrt(1/2). The arithmetic shift yields k; the
    # low 23 bits, re-offset by sqrt(1/2), rebuild the reduced mantissa 1 + f.
    rel = reinterpret(Int32, x) - _SQRT_HALF_I32
    k = rel >> Int32(23)
    mantissa = reinterpret(Float32, (rel & Int32(0x007fffff)) + _SQRT_HALF_I32)

    # When k == 0 (x in [sqrt(1/2), 1]), take x - 1 = -(~u + 0.5) / 2^32 straight
    # from ~u instead of `mantissa - 1`, avoiding cancellation. `ifelse` keeps
    # this branch-free.
    one_minus_x = fma(Float32(~u), step, half_step)
    f = ifelse(k == Int32(0), -one_minus_x, mantissa - 1.0f0)

    s = f / (2.0f0 + f)
    z = s * s
    w = z * z
    t_odd = z * (lg1 + w * lg3)   # z*Lg1 + z^3*Lg3
    t_even = w * (lg2 + w * lg4)  # z^2*Lg2 + z^4*Lg4
    R = t_odd + t_even
    hfsq = 0.5f0 * f * f
    kf = Float32(k)
    return kf * ln2_hi - ((hfsq - (s * (hfsq + R) + kf * ln2_lo)) - f)
end
julia> Base.infer_effects(fastlog2, (UInt32,))
(+c,+e,+n,+t,+s,+m,+u,+o,+r)

Here is a comparison with a reference Float64 version:

# Float64 reference for the fastlog* variants. It evaluates log((u + 0.5) / 2^32)
# in double precision and rounds to Float32 on return. For the upper half of the
# range it uses log1p(-(1 - x)) with 1 - x = (~u + 0.5) / 2^32; both x and 1 - x
# need at most 33 significand bits, so the fma below is exact in Float64.
function fastlogref(u::UInt32)::Float32
    step = ldexp(1.0, -32)
    half_step = ldexp(1.0, -33)
    u >= 0x80000000 || return log(fma(Float64(u), step, half_step))
    return log1p(-fma(Float64(~u), step, half_step))
end
julia> maximum(i -> abs(fastlogref(i) - fastlog2(i)), typemin(UInt32):typemax(UInt32))
1.9073486f-6

julia> using Chairmarks

julia> @b rand(UInt32, 1000_000) sum(fastlog1, _)
7.215 ms

julia> @b rand(UInt32, 1000_000) sum(fastlog2, _)
198.149 μs

I think the first question to ask is whether these two codes compute the same thing. Since the input is a UInt32, you can loop through all 2³² values and check every one; it shouldn’t take too long.

julia> for x in typemin(UInt32):typemax(UInt32)
           if !isapprox(fastlog1(x), fastlog2(x); rtol=2eps(Float32))
               error("fastlog1 and fastlog2 differ for x = $(x): fastlog1(x) = $(fastlog1(x)), fastlog2(x) = $(fastlog2(x))")
           end
       end

comes out clean (some points differ by 1 or 2 ULPs, but no more than that, which is why I used rtol=2eps(Float32)). This takes about 1 minute to run on my laptop.

Edit: I just realised you already had a comparison, but it used the absolute difference rather than the relative one, and the absolute difference alone doesn’t tell you much about how significant the error is.

Is the Float64 there on purpose?

2 Likes

No, that was a typo. Fixing it makes the original version slightly faster (about 6 ms for the sum benchmark).

2 Likes

Following up on this, here is a cleaner version using evalpoly and fma, with a reference to where the polynomial coefficients come from: openlibm/src/e_logf.c at v0.8.7 · JuliaMath/openlibm · GitHub.

# Core log algorithm (polynomial coefficients, ln2 splitting, and reconstruction)
# adapted from fdlibm's e_log.c / e_logf.c (Sun Microsystems, 1993).
# See: https://github.com/JuliaMath/openlibm/blob/v0.8.7/src/e_log.c
#      https://github.com/JuliaMath/openlibm/blob/v0.8.7/src/e_logf.c

# Int32 bit pattern of Float32(sqrt(0.5)), the lower bound of the reduced mantissa.
const _SQRT_HALF_I32 = reinterpret(Int32, Float32(sqrt(0.5)))

# fdlibm logf coefficients Lg1..Lg4 for log(1+f) ≈ f - f²/2 + s*(f²/2 + R), with
# R = z * evalpoly(z, _LOG_POLY_F32), s = f / (2 + f) and z = s². They are spelled
# with the same bit patterns as e_logf.c so they can be checked against the C source.
const _LOG_POLY_F32 = (
    reinterpret(Float32, 0x3f2aaaaa),  # Lg1 ≈ 0.6666666f0
    reinterpret(Float32, 0x3eccce13),  # Lg2 ≈ 0.40000972f0
    reinterpret(Float32, 0x3e91e9ee),  # Lg3 ≈ 0.28498787f0
    reinterpret(Float32, 0x3e789e26),  # Lg4 ≈ 0.24279079f0
)

# ln(2) = _LN2_HI_F32 + _LN2_LO_F32. The low 7 significand bits of the high part are
# zero, so k * _LN2_HI_F32 is exact for |k| < 128.
const _LN2_HI_F32 = reinterpret(Float32, 0x3f317180)  # ≈ 0.6931381f0
const _LN2_LO_F32 = reinterpret(Float32, 0x3717f7d1)  # ≈ 9.058001f-6

@inline function _fast_log(::Type{Float32}, u::Union{UInt32, UInt64})
    # Float32 log of u01(Float32, u), i.e. of the midpoint (u + 1/2) / 2^N of u's cell
    # in [0, 1]. This is branch-free, so it can vectorize.
    x = u01(Float32, u)

    # --- Range reduction --------------------------------------------------------
    # Find k and f with x = 2^k * (1 + f) and sqrt(1/2) ≤ 1 + f < sqrt(2).
    #
    # Positive Float32 bit patterns are ordered like their values. Subtracting the
    # bits of sqrt(1/2) therefore makes `rel` negative exactly when x < sqrt(1/2),
    # and each factor of 2 in x moves `rel` by 2^23. The arithmetic shift by the
    # 23 fraction bits thus yields k (k = 0 on [sqrt(1/2), sqrt(2))).
    rel = reinterpret(Int32, x) - _SQRT_HALF_I32
    k = rel >> Int32(23)

    # The low 23 bits of `rel` hold (frac(x) - frac(sqrt(1/2))) mod 2^23. Adding
    # back the bits of sqrt(1/2) restores frac(x) in the fraction field. When
    # frac(x) < frac(sqrt(1/2)), the sum carries into the exponent field and bumps
    # the exponent from -1 to 0. In both cases 1 + f lands in [sqrt(1/2), sqrt(2)).
    one_plus_f = reinterpret(Float32, (rel & Int32(0x007fffff)) + _SQRT_HALF_I32)

    # For k == 0, x is in [sqrt(1/2), 1] (u01 never exceeds 1). Then
    # f = x - 1 = -(~u + 1/2) / 2^N, which u01(Float32, ~u) computes without the
    # cancellation of `one_plus_f - 1`. Both candidates are evaluated and `ifelse`
    # picks one, keeping the code branch-free.
    f = ifelse(k == Int32(0), -u01(Float32, ~u), one_plus_f - 1.0f0)

    # --- Polynomial for log(1 + f) ------------------------------------------------
    # With s = f / (2 + f) and z = s^2:
    #     log(1 + f) = 2s + s*R,   R = z * evalpoly(z, _LOG_POLY_F32)
    # which fdlibm rewrites as f - f^2/2 + s*(f^2/2 + R).
    s = f / (2.0f0 + f)
    z = s * s
    R = z * evalpoly(z, _LOG_POLY_F32)
    hfsq = 0.5f0 * f * f

    # --- Reconstruction: log(x) = k*log(2) + log(1 + f) ---------------------------
    # log(2) is split into _LN2_HI_F32 + _LN2_LO_F32. The low part is folded into
    # the small terms, and the exact k*_LN2_HI_F32 product is added last via fma.
    # The original author reported that a single-constant variant,
    # fma(k, 0.6931472f0, fma(s, R - f, f)), was slightly less accurate in their
    # tests.
    kf = Float32(k)
    tail = f - (hfsq - fma(s, hfsq + R, kf * _LN2_LO_F32))
    return fma(kf, _LN2_HI_F32, tail)
end

# Midpoint map for 32-bit integers: (u + 1/2) / 2^32 evaluated in type F, i.e. the
# centre of the u-th of 2^32 equal cells of [0, 1].
# F(u) may itself round (Float32 keeps only 24 significand bits), and the fma rounds
# once more. As a result, for F = Float32 the result is exactly 1.0f0 when u is near
# typemax(UInt32). For Float32/Float64 it is never below 2^-33, so never 0.
@inline function u01(::Type{F}, u::UInt32)::F where F
    cell = F(2)^Int32(-32)       # width of one cell
    half_cell = F(2)^Int32(-33)  # shift to the cell midpoint
    return fma(F(u), cell, half_cell)
end

# Midpoint map for 64-bit integers: (u + 1/2) / 2^64 evaluated in type F, i.e. the
# centre of the u-th of 2^64 equal cells of [0, 1].
# F(u) and the fma each round, so for F = Float32 or Float64 the result is exactly
# 1.0 for u near typemax(UInt64). For those types it is never below 2^-65.
@inline function u01(::Type{F}, u::UInt64)::F where F
    cell = F(2)^Int32(-64)       # width of one cell
    half_cell = F(2)^Int32(-65)  # shift to the cell midpoint
    return fma(F(u), cell, half_cell)
end