Modular multiplication without overflow

How about expressing (a*b)%c in base 2^64 but performing the calculations with 128-bit precision, so you don’t need to worry about overflow?

a, b, and c may be Int128s already, and performing calculations in 256-bit precision is slower.


Mind if I steal that and put it in JuliaMath/IntegerMathUtils.jl on GitHub?

Be careful with code you don’t understand, especially when it is advertised as “completely opaque and unverified”. Here’s an example where it is wrong:

julia> a,b,m = (0x557f0d8d9f08cab0f5b1f6e6fe814efe, 0xff3850c2465783f9b39c716f225f7c79, 0x556ad04f016faad7333284c59a3221af)
(0x557f0d8d9f08cab0f5b1f6e6fe814efe, 0xff3850c2465783f9b39c716f225f7c79, 0x556ad04f016faad7333284c59a3221af)

julia> c = UInt128((big(a) * big(b)) % big(m)) # Correct with BigInt conversion
0x2f601dd0ad7828c5366d3f1de8757516

julia> d = wiki_mul_mod(a, b, m) # Wrong, idk why.
0x1a5fe3638d5932abb00197a805401041

Here’s another, much longer, slightly slower, fairly opaque, and reasonably well verified approach I crafted:

using Test, BenchmarkTools

"""
    mul(a::T, b::T) -> x1::T, x2::unsigned(T)

Multiply `a` and `b` and return the high and low parts of the result.

``a*b = x1*2^n + x2`` where ``n`` is the number of bits in `T`.
"""
function mul(a::T, b::T) where T<:Integer
    shift = sizeof(T) << 2
    lo_mask = one(T) << shift - one(T)
    a_lo = a & lo_mask
    b_lo = b & lo_mask
    a_hi = a >> shift
    b_hi = b >> shift

    x2 = unsigned(a_lo * b_lo)

    m1 = a_lo * b_hi
    m2 = a_hi * b_lo

    x2, o1 = Base.add_with_overflow(x2, unsigned(m1) << shift)
    x2, o2 = Base.add_with_overflow(x2, unsigned(m2) << shift)

    x1 = a_hi * b_hi
    x1 += m1 >> shift
    x1 += m2 >> shift
    x1 += o1
    x1 += o2

    x1, x2
end

@testset "mul" begin
    for T in (UInt8, Int8), a in typemin(T):typemax(T), b in typemin(T):typemax(T)
        x1, x2 = mul(a, b)
        x = Int(a) * Int(b)
        @test x1 * 256^sizeof(T) + x2 == x
    end
end

"""
    subtractand(x::T, m_hi::T, m_lo::T) -> x1::T, x2::T

Guess a `q` and compute `x1 = m_hi*q` and `x2 = m_lo*q`.

Precondition: ``leading_zeros(m_hi) = n-1`` and ``n ≤ leading_zeros(m_lo)``

Postcondition: ``0 ≤ x * 2^n - y < 4 * 2^(2n)`` and ``m_hi * 2^n + m_lo | y``

Where ``y = x1 * 2^n + x2`` and ``n`` is half the number of bits in `T`.

!!! Warning
    this docstring might contain lies.
"""
function subtractand(x::T, m_hi::T, m_lo::T) where T<:Unsigned
    shift = sizeof(T) << 2
    q = fld(x, m_hi + oneunit(T))
    x1 = m_hi * q
    x2 = m_lo * q

    x1, x2
end

@testset "subtractand" begin
    for T in (UInt8,), x in typemin(T):typemax(T), m_hi in T(0x10):T(0x1f), m_lo in T(0x00):T(0x0f)
        x1, x2 = subtractand(x, m_hi, m_lo)
        y = x1 * 16^sizeof(T) + x2
        @test 0 <= x * 16^sizeof(T) - y < 4*256^sizeof(T)
        @test y % (m_hi * 16^sizeof(T) + m_lo) == 0
    end
    x1, x2 = subtractand(0x57c2, 0x013d, 0x0066)
    y = x1 * 16^sizeof(x1) + x2
    @test 0 <= 0x57c2 * 16^sizeof(x1) - y < 4*256^sizeof(x1)
    @test y % (0x013d * 16^sizeof(x1) + 0x0066) == 0
end

"""
    mul_mod(a::T, b::T, m::T) -> x::T

Compute `(a * b) % m` without intermediate overflow

Precondition ``m > 0``
"""
function mul_mod(a::T, b::T, m::T) where T<:Unsigned
    shift = sizeof(T) << 2

    m <= one(T) << shift && return ((a%m) * (b%m)) % m

    x1, x2 = mul(a, b)
    # Invariant: x1 * 2^n + x2 ≡ a * b (mod m)
    #answer = widemul(a, b) % m
    #@assert (widen(x1) << 8sizeof(T) + x2) % m == answer

    m1 = m << (leading_zeros(m) + 1)
    m_hi = m1 >> shift + oneunit(T) << shift
    m_lo = m1 & (oneunit(T) << shift - oneunit(T))

    # m_hi and m_lo satisfy the preconditions of subtractand and form a multiple of m.
    #@assert leading_zeros(m_hi) == shift-1
    #@assert leading_zeros(m_lo) >= shift
    #@assert (widen(m_hi) << shift + m_lo) % m == 0

    s1, s2 = subtractand(x1, m_hi, m_lo)

    x1 -= s1; x1 -= s2 >> shift; (x2, o) = Base.sub_with_overflow(x2, s2 << shift); x1 -= o

    #@assert (widen(x1) << 8sizeof(T) + x2) % m == answer # Correct
    #@assert leading_zeros(x1) ≥ shift - 2                # Status

    y = x1 << (shift-2) + x2 >> (shift+2)
    s1, s2 = subtractand(y, m_hi, m_lo)

    x1 -= s1 >> (shift-2); x1 -= s2 >> (2shift-2); (x2, o) = Base.sub_with_overflow(x2, s1 << (shift+2)); x1 -= o; (x2, o) = Base.sub_with_overflow(x2, s2 << 2); x1 -= o

    #@assert (widen(x1) << 8sizeof(T) + x2) % m == answer # Correct
    #@assert leading_zeros(x1) ≥ 2shift - 4               # Status

    y = x1 << (2shift-4) + x2 >> 4

    m_lo >>= 1
    m_lo += (m_hi & 0x01) << (shift-1)
    m_hi >>= 1

    q3 = fld(y, m >> 4 + one(T))
    s1, s2 = mul(q3, m)

    x1 -= s1; x2, o = Base.sub_with_overflow(x2, s2); x1 -= o

    #@assert (widen(x1) << 8sizeof(T) + x2) % m == answer # Correct
    #@assert x1 ≤ 1                                       # Status

    x2 -= m*(x1 == 1 || x2 > m)

    #@assert x2 % m == answer                             # Correct

    x2 % m
end

function test_mul_mod(f, types)
    function g(::Type{T}) where T
        for _ in 1:50_000
            a = rand(T)
            b = rand(T)
            m = rand(T)
            m = m == 0 ? one(T) : m
            f(a, b, m) === widemul(a, b) % m % typeof(m) || (println((a,b,m));return false)
        end
        true
    end
    for i in 1:10
        @test all(g, types)
        print(i)
    end
    for i in 1:10
        for _ in 1:100
            @test g(first(types))
        end
        print(i)
    end
end

@testset "mul_mod" begin
    test_mul_mod(mul_mod, (UInt16, UInt32, UInt64, UInt128))
end


function wiki_mul_mod(a::UInt128, b::UInt128, m::Integer)

    magic1 = (UInt128(0xFFFFFFFF) << 32)
    magic2 = UInt128(0x8000000000000000)

    if iszero(((a | b) & magic1))
        return (a * b) % m
    end

    d = zero(UInt128)
    mp2 = m >> 1

    if a >= m; a %= m; end
    if b >= m; b %= m; end

    for _ in 1:64
        (d > mp2) ? d = ((d << 1) - m) : d = (d << 1)
        if !iszero(a & magic2)
            d += b
        end
        if d >= m
            d -= m
        end
        a <<= 1
    end

    return d
end

function wiki_mul_mod(a::Integer, b::Integer, m::Integer)
    sa, sb = UInt128.(unsigned.((a,b)))
    return wiki_mul_mod(sa, sb, m)
end

# @testset "wiki_mul_mod" begin [FAILS]
#     test_mul_mod(wiki_mul_mod, (UInt128,))
# end

wide_mul_mod(a, b, m) = widemul(a, b) % m % typeof(m)

@testset "wide_mul_mod" begin
    test_mul_mod(wide_mul_mod, (UInt8, UInt16, UInt32, UInt64, UInt128))
end

for T in (UInt16, UInt32, UInt64, UInt128)
    println(T)
    for f in (mul_mod, wide_mul_mod, wiki_mul_mod)
        print(lpad(f, 12))
        #@btime $f(a, b, c) setup=(a=rand($T); b=rand($T); c=rand($T))
        b = @benchmark $f(a, b, c) setup=(a=rand($T); b=rand($T); c=rand($T))
        med = median(b) # Median rather than min to avoid timing the small input fastpath
        println("  ", BenchmarkTools.prettytime(med.time),
            " (", med.allocs , " allocation", med.allocs == 1 ? "" : "s", ": ",
            BenchmarkTools.prettymemory(med.memory), ")")
    end
end

It’s not as fast as the (broken) Wikipedia implementation on UInt128s, nor fast enough for #47577, but still pretty quick. Runtime in the 128-bit case is dominated by two challenging divisions of 128-bit numbers by 64-bit numbers. My implementation is listed as mul_mod and widemul(a, b) % m % typeof(m) is listed as wide_mul_mod.

UInt16
     mul_mod  33.806 ns (0 allocations: 0 bytes)
wide_mul_mod  5.940 ns (0 allocations: 0 bytes)
wiki_mul_mod  23.164 ns (0 allocations: 0 bytes)
UInt32
     mul_mod  30.852 ns (0 allocations: 0 bytes)
wide_mul_mod  12.286 ns (0 allocations: 0 bytes)
wiki_mul_mod  23.498 ns (0 allocations: 0 bytes)
UInt64
     mul_mod  48.651 ns (0 allocations: 0 bytes)
wide_mul_mod  204.355 ns (0 allocations: 0 bytes)
wiki_mul_mod  218.973 ns (0 allocations: 0 bytes)
UInt128
     mul_mod  470.913 ns (0 allocations: 0 bytes)
wide_mul_mod  858.755 ns (15 allocations: 256 bytes)
wiki_mul_mod  234.104 ns (0 allocations: 0 bytes)

This looks wrong. The original code had a loop for (i = 0; i < 64; ++i), but that was for uint64_t, whereas you are working with UInt128. If you look at the Stack Overflow post that I linked above, you’ll see that it describes a variant of long multiplication in binary that operates on the bits of a (of which there are now 128).

I would try adapting the stackoverflow code, which is a lot better explained and seems like it should have similar or better performance.
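
For reference, the binary long multiplication idea can be sketched generically as follows. This is an illustrative double-and-add version, not the Stack Overflow or Wikipedia code, and the names `addmod` and `binary_mul_mod` are made up for this sketch. It consumes one bit of `b` per iteration (so up to 128 iterations for `UInt128`) and keeps every intermediate value below `m`, so nothing can wrap around.

```julia
# Hypothetical helper: (x + y) mod m for x, y < m, without overflowing T.
# Comparing x against m - y avoids forming x + y when that sum would wrap.
addmod(x::T, y::T, m::T) where {T<:Unsigned} = x >= m - y ? x - (m - y) : x + y

# Hypothetical sketch of binary long multiplication modulo m.
function binary_mul_mod(a::T, b::T, m::T) where {T<:Unsigned}
    iszero(m) && throw(DivideError())
    a %= m
    b %= m
    d = zero(T)                            # accumulated result, always < m
    while !iszero(b)
        isodd(b) && (d = addmod(d, a, m))  # add the current multiple of a
        a = addmod(a, a, m)                # double a modulo m
        b >>= 1                            # move on to the next bit of b
    end
    return d
end

# The first failing input reported earlier in the thread:
a, b, m = (0x557f0d8d9f08cab0f5b1f6e6fe814efe,
           0xff3850c2465783f9b39c716f225f7c79,
           0x556ad04f016faad7333284c59a3221af)
binary_mul_mod(a, b, m) == (big(a) * big(b)) % big(m)   # compare against BigInt
```

Because the loop length depends on the bit width, a version like this is simple to reason about but is not expected to beat a division-based approach on speed.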


And the magic numbers need to be upgraded also:

magic1 = (UInt128(0xFFFFFFFFFFFFFFFF) << 64)
magic2 = UInt128(0x80000000000000000000000000000000)

Testing with the (a,b,m) given above with 128 round loop gives the right answer.

“I would try adapting the stackoverflow code, which is a lot better explained and seems like it should have similar or better performance.”

I copied and pasted @jd-forester’s code verbatim to perform benchmarks and tests and devised a different algorithm from scratch; I did not adapt any existing code.

Changing the 64 to 128 as @stevengj suggested and the magic numbers as @Dan suggested yields the counterexample (0x350caac24c5180855a2891ec9a655250, 0x5d043f9d0c547441717da21f58a17d54, 0xac4dd9d313c605d3e9c794646437b1b9). I expect this counterexample to stand: the wiki page notes that m must have fewer than 64 bits, and I imagine that restriction carries forward to m needing fewer than 128 bits once the code is extended.

In addition to still not being correct over the whole domain, I benchmark the revised version as comparable in speed to the correct (I think) version I posted.

          mul_mod   wide_mul_mod  steve_dan_wiki_mul_mod (still broken)
UInt16    43.684 ns    8.239 ns   31.100 ns
UInt32    40.280 ns   14.272 ns   30.405 ns
UInt64    60.862 ns  268.941 ns  274.094 ns
UInt128  568.878 ns    1.000 μs  560.923 ns

My version does not include a loop over every binary digit and as far as I can tell, all the linked implementations do (except the extended precision floating point trick).

I managed to get wiki_mul_mod to handle the new test case. The updated code is:

Updated `wiki_mul_mod`
function wiki_mul_mod(a::UInt128, b::UInt128, m::UInt128)
                              
    magic1 = (UInt128(0xFFFFFFFFFFFFFFFF) << 64)
    magic2 = UInt128(0x80000000000000000000000000000000)
                              
    if iszero(((a | b) & magic1))
        return (a * b) % m
    end

    d = zero(UInt128)
    mp2 = m >> 1

    if a >= m
        a %= m
    end
    if b >= m
        b %= m
    end

    mc = (typemax(UInt128) - m) + 1

    for t = 1:128
        d2 =
            d <= (typemax(UInt128) >> 1) ? d << 1 :
            typemax(UInt128) - ((typemax(UInt128) - d + 1) << 1) + 1 + mc
        d = (d2 >= m) ? (d2 - m) : d2
        if !iszero(a & magic2) 
            if b > (typemax(UInt128) - d)
                d += b
                d >= m && (d -= m)
                d += mc
            else
                d += b
            end
        end
        d >= m && (d -= m)
        a <<= 1
    end

    return d
end

For the test case, it gives:

julia> c = UInt128((big(a) * big(b)) % big(m)) # Correct with BigInt conversion
0x659f2a534b3977bb8841b239b086c2b4

julia> d = wiki_mul_mod(a, b, m)
0x659f2a534b3977bb8841b239b086c2b4

julia> d = wiki_mul_mod(b, a, m)
0x659f2a534b3977bb8841b239b086c2b4

In any case, these sorts of primitives should be formally verified against an arithmetic model (using SMT), because bugs and edge cases can be hard to detect.


A cursory check with my fuzzer seems to agree:

julia> gen = PropCheck.tuple(3, igen(UInt128));

julia> function prop(a,b,c)
         wiki_mul_mod(a,b,c) == ((big(a) * big(b)) % big(c))
       end
prop (generic function with 1 method)

julia> check(Splat(prop), gen)
true

For comparison, it fails almost immediately for earlier versions:

julia> function prop2(a,b,c)
         mul_mod(a,b,c) == ((big(a) * big(b)) % big(c))
       end
prop2 (generic function with 1 method)

julia> check(Splat(prop2), gen)
[ Info: Found counterexample for 'Splat(prop2)', beginning shrinking...
┌ Info: 75 counterexamples found, of which 3 threw 2 distinct exception types
│   Errors =
│    2-element Vector{Any}:
│     false
└          DivideError()
((0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000), DivideError())

julia> check(Splat(prop2), gen)
[ Info: Found counterexample for 'Splat(prop2)', beginning shrinking...
[ Info: 134 counterexamples found
(0x00000000000000010000000000000000, 0x00000000000000000000000100000000, 0x00000000000000010000000000000001)

To be fair, the DivideError also happens with the updated versions:

julia> prop(rand(UInt128), rand(UInt128), zero(UInt128))
ERROR: DivideError: integer division error
Stacktrace:
 [1] rem
   @ ./int.jl:1033 [inlined]
 [2] wiki_mul_mod(a::UInt128, b::UInt128, m::UInt128)
   @ Main ./REPL[2]:14
 [3] prop(a::UInt128, b::UInt128, c::UInt128)
   @ Main ./REPL[18]:2
 [4] top-level scope
   @ REPL[33]:1

So that should be guarded behind some checks, and I should finally get around to documenting my fuzzer (and adding more features so it’s actually usable) 🙂


I checked this counterexample and it satisfies the property. Any more details?

julia> a, b, m = 0x00000000000000010000000000000000, 0x00000000000000000000000100000000, 0x00000000000000010000000000000001
(0x00000000000000010000000000000000, 0x00000000000000000000000100000000, 0x00000000000000010000000000000001)

julia> d = wiki_mul_mod(a, b, m)
0x0000000000000000ffffffff00000001

julia> c = UInt128((big(a) * big(b)) % big(m)) # Correct with BigInt conversion
0x0000000000000000ffffffff00000001

julia> c-d
0x00000000000000000000000000000000

UPDATE: I get it… this is a counterexample for mul_mod, not wiki_mul_mod.

The fuzzer tool looks awesome. Is there a link? (Or perhaps it needs more development first.)

Which function is this a counterexample for? If you could provide source code or link to the post that contains it that would be great.

In addition, making it general for all type sizes and brute force checking for the smaller ones is a fairly reliable system that is simple and efficient enough to be included in CI. I generalized your code to

function wiki_mul_mod_post_18(a::T, b::T, m::T) where T <: Base.BitUnsigned
    magic1 = (one(T) << (sizeof(T)<<2) - one(T)) << (sizeof(T)<<2)
    magic2 = typemax(T)÷2+one(T)

    if iszero(((a | b) & magic1))
        return (a * b) % m
    end

    d = zero(T)
    mp2 = m >> 1

    if a >= m
        a %= m
    end
    if b >= m
        b %= m
    end

    mc = (typemax(T) - m) + one(T)

    for t = 1:(sizeof(T)<<3)
        d2 =
            d <= (typemax(T) >> 1) ? d << 1 :
            typemax(T) - ((typemax(T) - d + one(T)) << 1) + one(T) + mc
        d = (d2 >= m) ? (d2 - m) : d2
        if !iszero(a & magic2)
            if b > (typemax(T) - d)
                d += b
                d >= m && (d -= m)
                d += mc
            else
                d += b
            end
        end
        d >= m && (d -= m)
        a <<= 1
    end

    return d
end

and found that it passes brute force tests for all BitUnsigned types.

It’s also pretty efficient for UInt128!

            mul_mod  wide_mul_mod  wiki_mul_mod_post_18
UInt64    46.574 ns   211.096 ns   136.679 ns
UInt128  437.209 ns   659.184 ns   396.000 ns
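
The brute-force idea mentioned above can be sketched roughly like this. The helper names `oracle`, `first_counterexample`, and `naive` are hypothetical; `wide_mul_mod` is the widening one-liner from earlier in the thread, used here as the function under test. The sketch assumes exhaustive checking is only practical for `UInt8`, where there are 256 × 256 × 255 (about 16.7 million) triples.

```julia
# The widening one-liner from earlier in the thread (function under test).
wide_mul_mod(a, b, m) = widemul(a, b) % m % typeof(m)

# Hypothetical oracle: plain Int arithmetic, which is exact for 8-bit inputs.
oracle(a, b, m) = (Int(a) * Int(b)) % Int(m)

# Try every (a, b, m) with m > 0. Return the first mismatch as a tuple,
# or `nothing` if f agrees with the oracle everywhere.
# Practical only for UInt8; wider types need random sampling instead.
function first_counterexample(f, ::Type{T}) where {T<:Unsigned}
    for a in typemin(T):typemax(T), b in typemin(T):typemax(T), m in one(T):typemax(T)
        f(a, b, m) == oracle(a, b, m) || return (a, b, m)
    end
    return nothing
end

# A deliberately broken candidate: a * b wraps around in UInt8.
naive(a, b, m) = (a * b) % m

first_counterexample(wide_mul_mod, UInt8)   # no mismatch expected
first_counterexample(naive, UInt8)          # returns a failing (a, b, m) triple
```

Returning the failing triple instead of a plain `false` makes a CI failure immediately reproducible, which is the same reason the test helper in the thread prints `(a, b, m)` before giving up.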

It’s a QuickCheck/Hedgehog-inspired property-based testing tool! You can find the repo here, though the documentation that is up on GitHub is much too bare. There are also lots of bugs that need fixing, like not being able to generate ranges of things properly. I encountered that bug while writing more documentation and haven’t gotten around to fixing it yet 😅

That’s for the original code, which you mentioned in this post was wrong but didn’t know why. I wanted to give a smaller example, which is (in part) exactly what that fuzzing tool is for 🙂


When you can widen a and b to another native type, the problem is trivial. Even at max-width types, the problem is trivial when a*b does not overflow. If a*b does overflow, you can still use something like a (hi,lo) == hilomul(a,b) such that a*b == hi*z+lo with z=typemax(T)+1 (we don’t seem to have this function in Base? but it’s pretty easy in any case) and then use the identity mod(hi*z+lo,c) == mod(mod(mod(hi,c)*mod(z,c),c) + mod(lo,c),c) (I just ripped the mod identities off wikipedia – check me) so long as the intermediate product does not overflow (sufficient that c*c does not overflow). Obviously mod(z,c) has to be computed carefully here since z is unrepresentable, but this isn’t so hard.

It seems that you’d only need to resort to the complicated (and presumably much slower) version when all the above contingencies fail. If we have to fall back to some slow path for c that are unrepresentable in 64-bits, we’d probably be willing to live with that.
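
A minimal sketch of the identity described above, assuming unsigned inputs and a modulus small enough that c*c cannot overflow. The names `hilomul`, `zmod`, and `mul_mod_small_c` are hypothetical (`hilomul` plays the same role as the `mul` function posted earlier). Since z = typemax(T) + 1 is not representable, `zmod` reduces typemax(T) first and then adds one.

```julia
# Hypothetical sketch: full product a*b == hi*z + lo with z = typemax(T) + 1,
# built from half-width pieces so that no partial product overflows T.
function hilomul(a::T, b::T) where {T<:Unsigned}
    h = 4 * sizeof(T)                       # half the bit width of T
    mask = (one(T) << h) - one(T)
    a_lo, a_hi = a & mask, a >> h
    b_lo, b_hi = b & mask, b >> h
    m1 = a_lo * b_hi                        # cross terms, each weighted by 2^h
    m2 = a_hi * b_lo
    lo, c1 = Base.add_with_overflow(a_lo * b_lo, m1 << h)
    lo, c2 = Base.add_with_overflow(lo, m2 << h)
    hi = a_hi * b_hi + (m1 >> h) + (m2 >> h) + c1 + c2
    return hi, lo
end

# mod(z, c) for the unrepresentable z = typemax(T) + 1, using z - 1 = typemax(T).
zmod(c::T) where {T<:Unsigned} = (typemax(T) % c + one(T)) % c

function mul_mod_small_c(a::T, b::T, c::T) where {T<:Unsigned}
    iszero(c) && throw(DivideError())
    # Sufficient condition from the post: c*c must not overflow T.
    c <= typemax(T) >> (4 * sizeof(T)) ||
        throw(ArgumentError("c is too large for this fast path"))
    hi, lo = hilomul(a, b)
    return (((hi % c) * zmod(c)) % c + lo % c) % c
end

a = 0x557f0d8d9f08cab0f5b1f6e6fe814efe
b = 0xff3850c2465783f9b39c716f225f7c79
c = UInt128(0xffffffffffffffc5)             # any nonzero modulus below 2^64 qualifies
mul_mod_small_c(a, b, c) == (big(a) * big(b)) % big(c)
```

For `UInt128` this covers every modulus that fits in 64 bits, so only larger moduli would need the slower general routine.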


I was going to say that perhaps I’m being dense but doesn’t this definition work well for all bit sizes less than 128:

mulmod(a::Integer, b::Integer, c::Integer) =
    mod(widemul(a, b), c) % typeof(c)

The tricky case is how to do this for 128-bit arguments where c actually is too big to fit in a 64-bit integer.

Also note: regrettably we followed the C definition of % and it actually does rem whereas what everyone wants here is mod.
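
A short illustration of the difference, using the `mulmod` definition from this post. The two operations only disagree when signs are involved, so the unsigned code elsewhere in the thread is unaffected.

```julia
# Definition from the post above.
mulmod(a::Integer, b::Integer, c::Integer) =
    mod(widemul(a, b), c) % typeof(c)

# % is rem: the result takes the sign of the dividend, as in C.
-7 % 3                          # -1
rem(-7, 3)                      # -1

# mod: the result takes the sign of the divisor.
mod(-7, 3)                      # 2

# With a negative product the two conventions give different answers.
mulmod(-7, 5, 3)                # 1, since -35 == -12 * 3 + 1
(widemul(-7, 5) % 3) % Int      # -2, since -35 == -11 * 3 - 2

# For unsigned operands, rem and mod agree.
mod(0x07, 0x03) == 0x07 % 0x03  # true
```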


I agree that we should use native arithmetic (with widening) whenever possible. Although I’m not sure anyone is disputing that.

At the widest native types (eg Int128) I imagine it’s still worthwhile to try to complete the operation in native arithmetic. A run-time overflow check can slow-path to the algorithm discussed here when required, but the significant speed difference makes the “try native first” approach enticing. I don’t imagine that very large arguments are common in most domains.
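
The “try native first” approach could look something like the sketch below. The name `mul_mod_native_first` is hypothetical, and the `BigInt` slow path is only a placeholder standing in for whichever overflow-free routine is eventually chosen.

```julia
function mul_mod_native_first(a::T, b::T, m::T) where {T<:Unsigned}
    iszero(m) && throw(DivideError())
    # Fast path: multiply natively and check whether the product wrapped.
    p, overflowed = Base.mul_with_overflow(a, b)
    overflowed || return p % m
    # Slow path (placeholder): exact arithmetic via BigInt, then narrow back to T.
    return ((big(a) * big(b)) % big(m)) % T
end

mul_mod_native_first(UInt128(3), UInt128(4), UInt128(5))   # fast path
x = typemax(UInt128)
mul_mod_native_first(x, x, x - 0x02)                       # slow path
```

Whether this pays off depends on how often the product actually overflows for the inputs in question, which the sketch makes no claim about.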


Very closely related: factorials and binomial coefficients! Both of those are useful to have modulo some other number, because both overflow very quickly.

Modular variants of those would certainly be useful to have! Not sure if they belong in base or a package though. Maybe ModularMethods.jl or something like that?


The OP was specifically about UInt128.


Ah yeah, missed that!

I do think there should be a package specifically for modular arithmetic, but I think these two are basic enough they could go in base alongside the current modular power.

You say regrettably–are you thinking of changing it in Julia 2.0? I think it’d be nice for a lot of people coming from Python, at the very least.