Modular multiplication without overflow

I was going to say that perhaps I’m being dense, but doesn’t this definition work well for all bit sizes less than 128?

mulmod(a::Integer, b::Integer, c::Integer) =
    mod(widemul(a, b), c) % typeof(c)

The tricky case is how to do this for 128-bit arguments where c actually is too big to fit in a 64-bit integer.
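A quick check in plain Base Julia of why 128 bits is the awkward case: widemul on 64-bit arguments produces a native UInt128, but widen(UInt128) is BigInt, so the one-liner above still gives the right answer for UInt128 arguments, only by way of heap-allocated BigInt arithmetic.

```julia
# The one-liner from above, repeated so this snippet runs on its own.
mulmod(a::Integer, b::Integer, c::Integer) = mod(widemul(a, b), c) % typeof(c)

# 64-bit operands widen to a native 128-bit product...
@show typeof(widemul(typemax(UInt64), typemax(UInt64)))    # UInt128
# ...but Base has no 256-bit type, so 128-bit operands widen to BigInt.
@show typeof(widemul(typemax(UInt128), typemax(UInt128)))  # BigInt

# The result is still correct at 128 bits, just computed through BigInt.
m = typemax(UInt128) - 2
x, y = m + 2, m + 1       # x ≡ 2 and y ≡ 1 (mod m)
@show mulmod(x, y, m)     # value 2 (UInt128 values print in hex)
```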

Also note: regrettably, we followed the C definition of %, so it actually computes rem, whereas what everyone wants here is mod.
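For example, rem (and therefore %) takes the sign of the dividend while mod takes the sign of the divisor, so the two agree when both operands have the same sign and can differ otherwise:

```julia
@show -7 % 3             # -1: % is rem, result has the sign of the dividend
@show rem(-7, 3)         # -1
@show mod(-7, 3)         #  2: result has the sign of the divisor
@show mod(-7, -3)        # -1: same signs, so mod and rem agree
@show 7 % 3 == mod(7, 3) # true for nonnegative operands
```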

I agree that we should use native arithmetic (with widening) whenever possible, although I’m not sure anyone is disputing that.

At the widest native types (e.g. Int128) I imagine it’s still worthwhile to try to complete the operation in native arithmetic. A run-time overflow check can fall back to the algorithm discussed here when required, and the significant speed difference makes the “try native first” approach enticing. I don’t imagine that very large arguments are common in most domains.
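A minimal sketch of that “try native first” pattern for UInt128, using Base.Checked.mul_with_overflow, which returns the wrapped product together with an overflow flag. The slow path mulmod_slow is a hypothetical placeholder that goes through BigInt; in practice it would be one of the 256-bit algorithms from later in this thread.

```julia
using Base.Checked: mul_with_overflow

# Hypothetical slow path: always correct, but allocates BigInts.
mulmod_slow(a::UInt128, b::UInt128, m::UInt128) = UInt128(mod(big(a) * big(b), big(m)))

function mulmod_try_native(a::UInt128, b::UInt128, m::UInt128)
    p, overflowed = mul_with_overflow(a, b)  # native 128-bit multiply plus overflow flag
    overflowed || return p % m               # fast path: product fit, one native division
    return mulmod_slow(a, b, m)              # slow path: only taken when a * b overflows
end
```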

Very closely related: factorials and binomial coefficients! Both of those are useful to have modulo some other number, because both overflow very quickly.
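A minimal sketch of both, assuming Int arguments with a positive modulus and, for the binomial, a prime p > n so that k! is invertible mod p. The names factorial_mod, binomial_mod_prime and mulmod_int are hypothetical; widemul and invmod are from Base. Reducing after every multiplication keeps each product below m^2, and widemul forms it as an Int128 so it never overflows.

```julia
# Product of two residues, formed in Int128 and reduced back into Int.
mulmod_int(a::Int, b::Int, m::Int) = Int(mod(widemul(a, b), m))

# n! mod m, reducing after each multiplication.
function factorial_mod(n::Int, m::Int)
    acc = mod(1, m)               # 0 when m == 1
    for k in 2:n
        acc = mulmod_int(acc, mod(k, m), m)
    end
    return acc
end

# binomial(n, k) mod p, assuming p is prime and p > n (so k! is invertible mod p).
function binomial_mod_prime(n::Int, k::Int, p::Int)
    0 <= k <= n || return 0
    num = mod(1, p)               # n * (n-1) * ... * (n-k+1) mod p
    den = mod(1, p)               # k! mod p
    for i in 1:k
        num = mulmod_int(num, mod(n - k + i, p), p)
        den = mulmod_int(den, mod(i, p), p)
    end
    return mulmod_int(num, invmod(den, p), p)
end
```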

Modular variants of those would certainly be useful to have! I’m not sure whether they belong in Base or in a package, though. Maybe ModularMethods.jl or something like that?

The OP was specifically about UInt128.

Ah yeah, missed that!

I do think there should be a package specifically for modular arithmetic, but I think these two are basic enough that they could go in Base alongside the existing modular power function, powmod.

You say “regrettably”: are you thinking of changing it in Julia 2.0? At the very least, I think it’d be nice for the many people coming from Python, where % already behaves like mod.

Surprisingly, in microbenchmarks the more complex method I defined above is actually 4x faster on UInt64s than promoting to UInt128 via the widemul method. Corroborating this, replacing the widemul approach in Base’s powmod(::Int, ::Int, ::Int) method with my ugly version also gives a significant speedup in that slightly less micro benchmark.

This is because, although 128-bit integers are native in Julia, they are not supported natively by my hardware. YMMV.
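Base’s powmod is built on repeated modular multiplication, which is why a faster mulmod speeds it up. Below is a minimal right-to-left square-and-multiply sketch that takes the modular multiply as an argument, so the widemul baseline can be swapped for any faster UInt64 mulmod (the faster version referred to above is not reproduced here). powmod_with and mul_widemul are hypothetical names, not Base functions.

```julia
# x^e mod m by right-to-left binary exponentiation, using only the
# caller-supplied modular multiply mul(a, b, m).
function powmod_with(mul, x::T, e::Integer, m::T) where {T<:Unsigned}
    e < 0 && throw(DomainError(e, "negative exponents need a modular inverse"))
    r = one(T) % m        # running result (0 when m == 1)
    b = x % m             # current square: x^(2^i) mod m
    while e > 0
        if isodd(e)
            r = mul(r, b, m)
        end
        b = mul(b, b, m)
        e >>= 1
    end
    return r
end

# Baseline multiply: widen to UInt128, reduce, narrow back to UInt64.
mul_widemul(a::UInt64, b::UInt64, m::UInt64) = (widemul(a, b) % m) % UInt64
```

Because the loop only ever calls mul, benchmarking two multiplies against each other becomes a one-argument change, e.g. powmod_with(mul_widemul, x, e, m) versus powmod_with(other_mulmod, x, e, m).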

It would certainly be on a list of changes to consider.

Here’s a version for UInt128 that runs in 100 ns on my computer (vs. 122 ns for lilith’s version and 350 ns for the wiki version). Looking back, I think it is roughly the same algorithm as lilith’s.

Profiling shows that almost all the time is split evenly among the 8 division operations: the calls to mod, fldmod (each counts as 2) and sub_mod. I’m not sure whether it can be done with fewer divisions.

using BitIntegers: UInt256
using BenchmarkTools: @benchmark

function mul_full(x::UInt128, y::UInt128)
    # full 256-bit product, returned as (low 128 bits, high 128 bits)
    z = UInt256(x) * UInt256(y)
    return (mod(z, UInt128), mod(z >> 128, UInt128))
end

function sub_mod(x::UInt128, y::UInt128, m::UInt128)
    # (x - y) mod m, computed without unsigned underflow
    if x > y
        return mod(x - y, m)
    else
        return m - mod1(y - x, m)
    end
end

function mul_mod(x::UInt128, y::UInt128, m::UInt128)
    # fast branch: no overflow
    if iszero(x >> 64) && iszero(y >> 64)
        return mod(x * y, m)
    end
    # fast branch: no overflow after reducing mod m
    if iszero(m >> 64)
        return mod(mod(x, m) * mod(y, m), m)
    end
    # m in base 2^64
    # m = m0 + m1 * u
    # where u = 2^64
    m0 = mod(mod(m, UInt64), UInt128)
    m1 = m >> 64
    # z=x*y in base 2^128
    # z == z0 + u^2 * z1
    z0, z1 = mul_full(x, y)
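    # require z1 < m, so that a1 = fld(z1, m1) stays near 2^64
    # (it is below 2^64 + 2^64 / m1; caveat: when m1 is very small,
    # e.g. m1 == 1, a1 * m0 and b1 * m0 below can overflow UInt128)
    z1 = mod(z1, m)
    # reduce u * z1 by dividing by m1
    # and use the fact that u * m1 == -m0 (mod m)
    # so z == z0 + u * z2 (mod m)
    a1, a0 = fldmod(z1, m1)  # z1 is already reduced, so no second mod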
    z2 = sub_mod(a0 << 64, a1 * m0, m)
    # reduce u * z2 in the same way
    # so z == z0 - z3 (mod m)
    b1, b0 = fldmod(z2, m1)
    z3 = sub_mod(b1 * m0, b0 << 64, m)
    # final sub_mod gets the answer
    return sub_mod(z0, z3, m)
end

function mul_mod_big(x::UInt128, y::UInt128, m::UInt128)
    return convert(UInt128, mod(big(x) * big(y), big(m)))
end

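function mul_mod_bad(x::UInt128, y::UInt128, m::UInt128)
    # x * y silently wraps on overflow, so this is only correct for small
    # inputs; it is here as a timing baseline
    return mod(x * y, m)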
end

function mul_mod_256(x::UInt128, y::UInt128, m::UInt128)
    return convert(UInt128, mod(UInt256(x) * UInt256(y), UInt256(m)))
end

function check(x, y, m)
    z0 = mul_mod(x, y, m)
    z1 = mul_mod_big(x, y, m)
    @assert z0 == z1
end

function check_rand(n=1_000_000)
    for _ in 1:n
        check(rand(UInt128), rand(UInt128), rand(UInt128))
    end
end

function benchmark(f=mul_mod)
    @benchmark $f(x, y, m) setup=(x=rand(UInt128); y=rand(UInt128); m=rand(UInt128))
end
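For reference, here are the congruences behind the comments in mul_mod, written out with u = 2^64 and the same names as the code (all congruences are mod m, and z1 is taken after its reduction mod m):

```math
\begin{aligned}
m &= m_0 + m_1 u, \qquad z = xy = z_0 + u^2 z_1, \qquad u m_1 \equiv -m_0 \\
z_1 &= a_1 m_1 + a_0 \;\Longrightarrow\; u z_1 = a_1 (u m_1) + a_0 u \equiv a_0 u - a_1 m_0 =: z_2 \\
z_2 &= b_1 m_1 + b_0 \;\Longrightarrow\; u z_2 \equiv b_0 u - b_1 m_0 = -z_3, \qquad z_3 := b_1 m_0 - b_0 u \\
z &= z_0 + u\,(u z_1) \equiv z_0 + u z_2 \equiv z_0 - z_3
\end{aligned}
```

Each fldmod supplies one (a, b) pair, and each sub_mod evaluates one of z2, z3 and z0 - z3 modulo m.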
Here’s an 81 ns version (and it’s short!).

This version still divides out 64 bits at a time, but directly accumulates the answer in a UInt256.

The while loops run O(1) times in the average case, but I’m not sure what the worst case is. Maybe they can be improved; I didn’t think about it long enough, but they do sometimes need more than one iteration.

using BitIntegers: UInt256

function mul_mod(x::UInt128, y::UInt128, m::UInt128)
    # M is a multiple of m such that the top bit is set
    M = m << leading_zeros(m)
    @assert leading_zeros(M) == 0
    # express M = (m1 << 64) + m0
    m0 = (M % UInt64) % UInt128
    m1 = M >> 64
    @assert M == (m1 << 64) + m0
    # take the product
    z0 = UInt256(x) * UInt256(y)
    # eliminate the top 64 bits of z0
    q1 = fld(UInt128(z0 >> 128), m1)
    z1 = z0 - (UInt256(M) * UInt256(q1) << 64)
    # correct for small overflow
    while z1 > z0
        z1 += (UInt256(M) << 64)
    end
    # eliminate the next 64 bits of z0
    q2 = fld(UInt128(z1 >> 64), m1)
    z2 = z1 - (UInt256(M) * UInt256(q2))
    # correct for small overflow
    while z2 > z1
        z2 += (UInt256(M) << 0)
    end
    # final answer
    return UInt128(z2) % m
end
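The normalization at the top of this version is what makes the final % m legitimate and keeps the quotient estimates in range: M = m << leading_zeros(m) is a multiple of m, so reducing modulo M and then modulo m gives z mod m, and since M has its top bit set, m1 = M >> 64 is at least 2^63, so fld(hi, m1) for any UInt128 hi is below 2^65. The check below is a sketch of those properties; normalized_modulus and check_normalization are hypothetical names, and BigInt serves as the reference.

```julia
# Hypothetical helper isolating the normalization step of mul_mod above.
function normalized_modulus(m::UInt128)
    iszero(m) && throw(DivideError())
    s = leading_zeros(m)
    return m << s, s                       # M = m * 2^s, with its top bit set
end

function check_normalization(n = 1_000)
    for _ in 1:n
        m = rand(UInt128) >> rand(0:127) | one(UInt128)  # nonzero, varied sizes
        M, s = normalized_modulus(m)
        @assert leading_zeros(M) == 0           # top bit set...
        @assert M >> 64 >= UInt128(1) << 63     # ...so the high word m1 is >= 2^63
        @assert M >> s == m                     # no bits lost: M == m * 2^s
        z = big(rand(UInt128)) * big(rand(UInt128))
        @assert mod(mod(z, M), m) == mod(z, m)  # reducing mod a multiple of m is safe
    end
    return true
end
```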
Nice, thank you!

It is possible to do this without any divisions beyond the final % operation: q should be calculated using floating-point multiplication, not integer division. For UInt128s, it should take 3 rounds, each consisting of a Float64 multiplication of z by inv(m1) to compute q, rounding q to Int64, and multiplying q by M to compute what should be subtracted from z. Each round should be able to shave off at least 50 bits without any corrective while loops.
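A minimal sketch of the reciprocal idea, scaled down to a 64-bit modulus so that the product fits in a UInt128 and no 256-bit type is needed; it is not the three-round UInt128 scheme described above. The quotient is estimated as Float64(z) * inv(Float64(m)) and deliberately shrunk by a factor 1 - 2^-48, so that q * m never exceeds z (this stands in for rounding q to Int64 and allowing a negative remainder). z is reduced until it is below m * 2^52, after which a single % finishes the job. The name mulmod_float is hypothetical.

```julia
# Division-free reduction for a 64-bit modulus: only the final % divides.
function mulmod_float(x::UInt64, y::UInt64, m::UInt64)
    m == 0 && throw(DivideError())
    M = UInt128(m)
    z = widemul(x, y)              # exact product, z < 2^128
    minv = inv(Float64(m))         # rounded reciprocal, computed once
    shrink = 1 - 2.0^-48           # exactly representable; biases q downward
    # Each round estimates q ≈ z / m in Float64 (never above the true quotient)
    # and subtracts q * m, shrinking z by a factor of roughly 2^47 per round.
    while z >= M << 52
        q = trunc(UInt128, Float64(z) * minv * shrink)
        z -= q * M                 # q * M <= z, so no underflow
    end
    return (z % M) % UInt64        # the single integer division
end
```

The slack factor is what removes the need for correction loops: the accumulated rounding error of the Float64 operations is below 2^-50 relative, so shrinking by 2^-48 guarantees the estimate never overshoots, at the cost of a few bits of progress per round.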