Fast fixed point log

I’m trying to compute the Float32 value of log(x), where x = (u + 0.5) / 2³² and u is an integer from 0 to 2³² - 1.
I started with:

"""
    fastlog1(u::UInt32)::Float32

Return a `Float32` approximation of `log(x)` for `x = (u + 0.5) / 2^32`.

For `u < 2^31`, `x` is formed directly and passed to `log`. Otherwise the complement
`(~u + 0.5) / 2^32` (which equals `1 - x`) is formed and `log1p` of its negation is
returned, which avoids forming a value that rounds to exactly 1 before taking the log.
"""
@inline function fastlog1(u::UInt32)::Float32
    scale = Float32(2)^(-32)
    half_step = Float32(2)^(-33)
    if u < UInt32(2)^31
        return log(fma(Float32(u), scale, half_step))
    else
        # All three `fma` operands are `Float32`. Passing a `Float64` constant here
        # would promote this branch (and the `log1p` call) to `Float64`.
        return log1p(-fma(Float32(~u), scale, half_step))
    end
end

This has the nice property of never returning 0 or -Inf

Sadly, the compiler infers poor effects for this function (the `!n` below means it is not proven `nothrow`), and it doesn’t seem to SIMD-vectorize:

julia> Base.infer_effects(fastlog1, (UInt32,))
(+c,+e,!n,+t,+s,+m,+u,+o,+r)

Claude was able to generate the following code. Does this make sense, and are there any obvious ways to make it even faster?

# Bit pattern of Float32(sqrt(0.5)); the split point used when reducing the argument
# to the form 2^k * (1 + f).
const _SQRT_HALF_I32 = reinterpret(Int32, Float32(sqrt(0.5)))

"""
    fastlog2(u::UInt32)::Float32

Return a `Float32` approximation of `log(x)` for `x = (u + 0.5) / 2^32`, selecting
between its two candidate reductions with `ifelse` instead of a branch.

The argument is split as `x = 2^k * (1 + f)` by integer arithmetic on the bit pattern
of `x`, and `log(1 + f)` is evaluated from a polynomial in `s = f / (2 + f)`. When
`k == 0`, `f` is taken from the complement `-(~u + 0.5) / 2^32` rather than from the
reconstructed mantissa.
"""
@inline function fastlog2(u::UInt32)::Float32
    scale = Float32(2)^(-32)
    half_step = Float32(2)^(-33)
    x = fma(Float32(u), scale, half_step)

    # Offsetting the bit pattern by sqrt(1/2) lets the arithmetic shift over the 23
    # fraction bits produce the exponent k.
    shifted = reinterpret(Int32, x) - _SQRT_HALF_I32
    k = shifted >> Int32(23)

    # Rebuild 1 + f from the low 23 bits, then choose which variant of f to use.
    one_plus_f = reinterpret(Float32, (shifted & Int32(0x007fffff)) + _SQRT_HALF_I32)
    f_from_x = one_plus_f - 1.0f0
    f_from_complement = -fma(Float32(~u), scale, half_step)
    f = ifelse(k == Int32(0), f_from_complement, f_from_x)

    # Polynomial coefficients and the two-part log(2), given as raw Float32 bit patterns.
    c1 = reinterpret(Float32, Int32(0x3f2aaaaa))  # ≈ 2/3
    c2 = reinterpret(Float32, Int32(0x3eccce13))  # ≈ 2/5
    c3 = reinterpret(Float32, Int32(0x3e91e9ee))  # ≈ 2/7
    c4 = reinterpret(Float32, Int32(0x3e789e26))  # ≈ 2/9
    ln2_hi = reinterpret(Float32, Int32(0x3f317180))
    ln2_lo = reinterpret(Float32, Int32(0x3717f7d1))

    s = f / (2.0f0 + f)
    z = s * s
    w = z * z
    odd_terms = z * (c1 + w * c3)
    even_terms = w * (c2 + w * c4)
    R = odd_terms + even_terms
    hfsq = 0.5f0 * f * f

    kf = Float32(k)
    return kf * ln2_hi - ((hfsq - (s * (hfsq + R) + kf * ln2_lo)) - f)
end
julia> Base.infer_effects(fastlog2, (UInt32,))
(+c,+e,+n,+t,+s,+m,+u,+o,+r)

Here is a comparison with a reference Float64 version:

"""
    fastlogref(u::UInt32)::Float32

Reference value of `log(x)` for `x = (u + 0.5) / 2^32`, evaluated in `Float64` and
converted to `Float32` on return.

Inputs below `2^31` use `log(x)` directly; the remaining inputs use `log1p` applied to
the negated complement `(~u + 0.5) / 2^32`.
"""
function fastlogref(u::UInt32)::Float32
    scale = Float64(2)^(-32)
    half_step = Float64(2)^(-33)
    in_lower_half = u < UInt32(2)^31
    return in_lower_half ? log(fma(Float64(u), scale, half_step)) :
           log1p(-fma(Float64(~u), scale, half_step))
end
julia> maximum(i -> abs(fastlogref(i) - fastlog2(i)), typemin(UInt32):typemax(UInt32))
1.9073486f-6

julia> using Chairmarks

julia> @b rand(UInt32, 1000_000) sum(fastlog1, _)
7.215 ms

julia> @b rand(UInt32, 1000_000) sum(fastlog2, _)
198.149 μs

I think the first question to ask is whether these two implementations compute the same thing. Since the input is a UInt32, you can loop through all possible inputs and check every one of them; it shouldn’t take too long.

julia> for x in typemin(UInt32):typemax(UInt32)
           if !isapprox(fastlog1(x), fastlog2(x); rtol=2eps(Float32))
               error("fastlog1 and fastlog2 differ for x = $(x): fastlog1(x) = $(fastlog1(x)), fastlog2(x) = $(fastlog2(x))")
           end
       end

comes out clear (there are some points which differ by 1 or 2 ULPs, but not more than that, hence why I used rtol=2eps(Float32)). This takes 1 minute to run on my laptop.

Edit: I just realised you already had a comparison, but it used the absolute difference rather than the relative one, which doesn’t tell you much about how significant the difference is.

Is the Float64 there on purpose?

No, that was a typo. Fixing it makes the original version slightly faster (about 6 ms for the sum benchmark).

Following up on this, here is a cleaner version using evalpoly and fma and a reference to where the polynomial coefficients are coming from: openlibm/src/e_logf.c at v0.8.7 · JuliaMath/openlibm · GitHub.

# Core log algorithm (polynomial coefficients, ln2 splitting, and reconstruction)
# adapted from fdlibm's e_log.c / e_logf.c (Sun Microsystems, 1993).
# See: https://github.com/JuliaMath/openlibm/blob/v0.8.7/src/e_log.c
#      https://github.com/JuliaMath/openlibm/blob/v0.8.7/src/e_logf.c

# Bit pattern of Float32(sqrt(0.5)); the split point for the mantissa range used below.
const _SQRT_HALF_I32 = reinterpret(Int32, Float32(sqrt(0.5)))
# Coefficients handed to `evalpoly` (lowest order first) in the log(1 + f) approximation.
const _LOG_POLY_F32 = (0.6666666f0, 0.40000972f0, 0.28498787f0, 0.24279079f0)
# log(2) split into two parts: log(2) = _LN2_HI_F32 + _LN2_LO_F32.
const _LN2_HI_F32 = 0.6931381f0
const _LN2_LO_F32 = 9.058001f-6

"""
    _fast_log(::Type{Float32}, u::Union{UInt32, UInt64})

Return a `Float32` approximation of `log(x)` for `x = u01(Float32, u)`.

The argument is reduced to `x = 2^k * (1 + f)` with `sqrt(2)/2 <= 1 + f < sqrt(2)`,
`log(1 + f)` is approximated with a polynomial in `s = f / (2 + f)`, and the result is
assembled as `k * log(2) + log(1 + f)` using a two-part `log(2)`.
"""
@inline function _fast_log(::Type{Float32}, u::Union{UInt32, UInt64})
    x = u01(Float32, u)

    # Range reduction. Positive Float32 values are ordered like their Int32 bit
    # patterns, so after subtracting the pattern of sqrt(0.5) and shifting out the 23
    # fraction bits we get k: it is 0 at x = 1, becomes -1 at
    # x = prevfloat(sqrt(0.5f0)), and changes by one per power of two in x.
    offset_bits = reinterpret(Int32, x) - _SQRT_HALF_I32
    k = offset_bits >> Int32(23)

    # The mask clears the sign and exponent fields. The subtraction and the addition of
    # _SQRT_HALF_I32 cancel in the low 23 bits, so the result carries x's fraction bits.
    # If x's fraction was below that of sqrt(0.5), the subtraction borrowed 2^23 from
    # the exponent field; that borrow survives the mask and, once _SQRT_HALF_I32 is
    # added back, bumps the exponent from -1 to 0. Either way the value is 1 + f.
    mantissa_bits = (offset_bits & Int32(0x007fffff)) + _SQRT_HALF_I32
    f_direct = reinterpret(Float32, mantissa_bits) - 1.0f0

    # When k == 0, f is taken from the complement of u instead, which is more accurate
    # for x near 1.
    f_complement = -u01(Float32, ~u)
    f = ifelse(k == Int32(0), f_complement, f_direct)

    # With s = f / (2 + f), z = s^2 and R = z * evalpoly(z, _LOG_POLY_F32):
    #   log(1 + f) = 2s + s * R = f - f^2/2 + s * (f^2/2 + R)
    s = f / (2.0f0 + f)
    z = s * s
    R = z * evalpoly(z, _LOG_POLY_F32)
    hfsq = (0.5f0 * f) * f

    # log(x) = k * log(2) + log(1 + f). The shorter form
    #   fma(kf, 0.6931472f0, fma(s, R - f, f))
    # is simpler but fails the mean test by 2E-9, hence the two-part log(2).
    kf = Float32(k)
    low_part = fma(s, hfsq + R, kf * _LN2_LO_F32)
    tail = f - (hfsq - low_part)
    return fma(kf, _LN2_HI_F32, tail)
end

"""
    u01(::Type{F}, u::UInt32) where {F}

Map a 32-bit unsigned integer to `(u + 0.5) / 2^32`, evaluated in type `F` with a
single `fma`.
"""
@inline function u01(::Type{F}, u::UInt32)::F where {F}
    scale = F(2)^Int32(-32)
    half_step = F(2)^Int32(-33)
    return fma(F(u), scale, half_step)
end

"""
    u01(::Type{F}, u::UInt64) where {F}

Map a 64-bit unsigned integer to `(u + 0.5) / 2^64`, evaluated in type `F` with a
single `fma`.
"""
@inline function u01(::Type{F}, u::UInt64)::F where {F}
    scale = F(2)^Int32(-64)
    half_step = F(2)^Int32(-65)
    return fma(F(u), scale, half_step)
end