# Modular multiplication without overflow

Let’s say I have three integers of the same type and want to compute `(a * b) % c`. `a * b` might overflow producing wrong results even though the mathematically correct answer (e.g. using `BigInt`s) would not overflow. What is the best way to compute this expression?

`a`, `b`, and `c` may be `Int128`s, and conversion to `BigInt` is not okay for performance reasons.


There isn’t a great way to do this, unfortunately (at least if `c` is somewhat close to `typemax`). If `c` is constant for a lot of operations, you can do some fun stuff, though.
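One well-known trick of this kind (not necessarily the one meant here) is Montgomery reduction. A one-time setup for a fixed odd modulus (`c` in the question, `m` in the code) lets each product be reduced with multiplications and shifts instead of a 128-bit division. Below is a minimal sketch for `UInt64` operands that widens to `UInt128` internally. It assumes the modulus is odd and below `2^63`, so the intermediate sum in `redc` cannot wrap. `MontgomeryModulus`, `redc`, and `mont_mul_mod` are hypothetical names. Applying the same idea to `Int128` operands would need a 256-bit product, which is the hard part of the original question.

```julia
# Hypothetical precomputed data for a fixed odd modulus m < 2^63, with R = 2^64.
struct MontgomeryModulus
    m::UInt64
    minv::UInt64   # -m^(-1) mod 2^64
    r2::UInt64     # R^2 mod m
end

function MontgomeryModulus(m::UInt64)
    (isodd(m) && m < (UInt64(1) << 63)) || throw(ArgumentError("m must be odd and below 2^63"))
    x = m                               # m * m ≡ 1 (mod 8), so x is correct to 3 bits
    for _ in 1:5                        # each Newton step doubles the correct bits: 3 → 96
        x *= UInt64(2) - m * x          # wrapping UInt64 arithmetic
    end
    r = (typemax(UInt64) % m + one(UInt64)) % m   # R mod m
    r2 = UInt64(widemul(r, r) % m)                # the only UInt128 division
    return MontgomeryModulus(m, -x, r2)
end

# Montgomery reduction: for t < m * 2^64, returns t * 2^(-64) mod m.
@inline function redc(M::MontgomeryModulus, t::UInt128)
    u = (t % UInt64) * M.minv             # chosen so that t + u*m ≡ 0 (mod 2^64)
    s = (t + widemul(u, M.m)) >> 64       # exact division; the sum is < 2m * 2^64 < 2^128
    s64 = s % UInt64                      # s < 2m < 2^64
    return s64 >= M.m ? s64 - M.m : s64
end

# (a * b) % m via two reductions: redc(a*b) = a*b/R, then redc((a*b/R) * R^2) = a*b.
function mont_mul_mod(M::MontgomeryModulus, a::UInt64, b::UInt64)
    a, b = a % M.m, b % M.m               # UInt64 reductions only
    x = redc(M, widemul(a, b))            # a * b * R^(-1) mod m
    return redc(M, widemul(x, M.r2))      # a * b mod m
end
```

After the setup, `mont_mul_mod` needs no `UInt128` division; the only divisions left are the cheap `UInt64` reductions of the inputs.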


My initial thought was to do

``````
((a % c) * (b % c)) % c
``````

but the bracketed product can still overflow.
cf. Modular arithmetic - Wikipedia, in particular its "Example implementations" section.
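For what it's worth, the bracketed product cannot overflow whenever `(c - 1)^2` fits in the type (for example `c ≤ 2^64` for `UInt128`), because both reduced factors are below `c`. Here is a minimal sketch that uses this as a fast path and reports when a slower method is needed. `small_mul_mod` is a hypothetical name. It checks the actual product with `Base.Checked.mul_with_overflow` rather than comparing `c` against a threshold, so it only gives up when the reduced product really overflows.

```julia
# Hypothetical helper: returns (a * b) % c when the product of the reduced
# operands fits in T, or `nothing` when a slower method is needed.
function small_mul_mod(a::T, b::T, c::T) where {T<:Base.BitUnsigned}
    iszero(c) && throw(DivideError())
    ra, rb = a % c, b % c                 # both are now < c
    p, overflowed = Base.Checked.mul_with_overflow(ra, rb)
    overflowed && return nothing          # caller must fall back
    return p % c
end
```

Returning `nothing` leaves the choice of fallback (`BigInt`, `widemul`, or one of the loop-based versions from later in the thread) to the caller.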

Here’s a reference, which I don’t claim to fully understand, that may describe a strategy for preventing overflow by multiplying out: Labor of Division (Episode III): Faster Unsigned Division by Constants


Translating the (completely opaque and unverified) Wikipedia code to an equally dodgy Julia version:

Listing: `mul_mod(a,b,c)`
``````
function mul_mod(a::UInt128, b::UInt128, m::Integer)

    magic1 = (UInt128(0xFFFFFFFF) << 32)
    magic2 = UInt128(0x8000000000000000)

    if iszero(((a | b) & magic1))
        return (a * b) % m
    end

    d = zero(UInt128)
    mp2 = m >> 1

    if a >= m; a %= m; end
    if b >= m; b %= m; end

    for _ in 1:64
        (d > mp2) ? d = ((d << 1) - m) : d = (d << 1)
        if !iszero(a & magic2)
            d += b
        end
        if d >= m
            d -= m
        end
        a <<= 1
    end

    return d
end

function mul_mod(a::Integer, b::Integer, m::Integer)
    sa, sb = UInt128.(unsigned.((a,b)))
    return mul_mod(sa, sb, m)
end
``````

“Testing”:

``````
julia> a, b, c = (4856388669903036, 6493747681980967, 3860134847225580)
(4856388669903036, 6493747681980967, 3860134847225580)

julia> (big(a) * big(b)) % big(c)  ## Correct, but using BigInt
3059575088694852

julia> (a * b) % c ## Incorrect
2569083244915888

julia> ((a % c) * (b % c)) % c  ## Also incorrect due to overflow of multiplication
226995574897440

julia> mul_mod(a, b, c) |> Int128  ## Correct, without BigInt conversion
3059575088694852
``````

Wow! That’s 3x faster than the BigInt version! Mind if I steal that and put it in JuliaMath/IntegerMathUtils.jl on GitHub?


An unrolled arbitrary precision multiplication and an unrolled arbitrary precision division with remainder could do this, but I imagine it would be pretty ugly and take about 40-80 operations. Integrating the two might help, though.

Sure, it’s a line-by-line translation from the Wikipedia page and I make no claims of correctness!


Looks kind of like the algorithm explained in the Stack Overflow question "Ways to do modulo multiplication with primitive types" (C++).


How about expressing `(a*b)%c` in base `2^64` but performing the calculations with 128-bit precision so you don’t need to worry about overflow?

`a`, `b`, and `c` may be `Int128`s already, and performing calculations in 256-bit precision is slower.


> Mind if I steal that and put it in JuliaMath/IntegerMathUtils.jl on GitHub?

Be careful with code you don’t understand, especially when it is advertised as “completely opaque and unverified”. Here’s an example where it is wrong:

``````
julia> a,b,m = (0x557f0d8d9f08cab0f5b1f6e6fe814efe, 0xff3850c2465783f9b39c716f225f7c79, 0x556ad04f016faad7333284c59a3221af)
(0x557f0d8d9f08cab0f5b1f6e6fe814efe, 0xff3850c2465783f9b39c716f225f7c79, 0x556ad04f016faad7333284c59a3221af)

julia> c = UInt128((big(a) * big(b)) % big(m)) # Correct with BigInt conversion
0x2f601dd0ad7828c5366d3f1de8757516

julia> d = wiki_mul_mod(a, b, m) # Wrong, idk why.
0x1a5fe3638d5932abb00197a805401041
``````

Here’s another, much longer, slightly slower, fairly opaque, and reasonably well verified approach I crafted:

``````
using Test, BenchmarkTools

"""
    mul(a::T, b::T) -> x1::T, x2::unsigned(T)

Multiply `a` and `b` and return the high and low parts of the result.

``a*b = x1*2^n + x2`` where ``n`` is the number of bits in `T`.
"""
function mul(a::T, b::T) where T<:Integer
    shift = sizeof(T) << 2
    lo_mask = one(T) << shift - one(T)
    a_lo = a & lo_mask
    b_lo = b & lo_mask
    a_hi = a >> shift
    b_hi = b >> shift

    x2 = unsigned(a_lo * b_lo)

    m1 = a_lo * b_hi
    m2 = a_hi * b_lo

    x2, o1 = Base.add_with_overflow(x2, unsigned(m1) << shift)
    x2, o2 = Base.add_with_overflow(x2, unsigned(m2) << shift)

    x1 = a_hi * b_hi
    x1 += m1 >> shift
    x1 += m2 >> shift
    x1 += o1
    x1 += o2

    x1, x2
end

@testset "mul" begin
    for T in (UInt8, Int8), a in typemin(T):typemax(T), b in typemin(T):typemax(T)
        x1, x2 = mul(a, b)
        x = Int(a) * Int(b)
        @test x1 * 256^sizeof(T) + x2 == x
    end
end

"""
    subtractand(x::T, m_hi::T, m_lo::T) -> x1::T, x2::T

Guess a `q` and compute `x1 = m_hi*q` and `x2 = m_lo*q`.

Precondition: ``leading_zeros(m_hi) = n-1`` and ``n ≤ leading_zeros(m_lo)``

Postcondition: ``0 ≤ x - y < 4*2^n`` and ``m_hi * 2^n + m_lo | y``

Where ``y = x1 * 2^n + x2`` and ``n`` is half the number of bits in `T`.

!!! Warning
    this docstring might contain lies.
"""
function subtractand(x::T, m_hi::T, m_lo::T) where T<:Unsigned
    shift = sizeof(T) << 2
    q = fld(x, m_hi + oneunit(T))
    x1 = m_hi * q
    x2 = m_lo * q

    x1, x2
end

@testset "subtractand" begin
    for T in (UInt8,), x in typemin(T):typemax(T), m_hi in T(0x10):T(0x1f), m_lo in T(0x00):T(0x0f)
        x1, x2 = subtractand(x, m_hi, m_lo)
        y = x1 * 16^sizeof(T) + x2
        @test 0 <= x * 16^sizeof(T) - y < 4*256^sizeof(T)
        @test y % (m_hi * 16^sizeof(T) + m_lo) == 0
    end
    x1, x2 = subtractand(0x57c2, 0x013d, 0x0066)
    y = x1 * 16^sizeof(x1) + x2
    @test 0 <= 0x57c2 * 16^sizeof(x1) - y < 4*256^sizeof(x1)
    @test y % (0x013d * 16^sizeof(x1) + 0x0066) == 0
end

"""
    mul_mod(a::T, b::T, m::T) -> x::T

Compute `(a * b) % m` without intermediate overflow.

Precondition: ``m > 0``
"""
function mul_mod(a::T, b::T, m::T) where T<:Unsigned
    shift = sizeof(T) << 2

    m <= one(T) << shift && return ((a%m) * (b%m)) % m

    x1, x2 = mul(a, b)
    # Invariant: x1 * 2^n + x2 ≡ a * b (mod m)
    #answer = widemul(a, b) % m
    #@assert (widen(x1) << 8sizeof(T) + x2) % m == answer

    m1 = m << (leading_zeros(m) + 1)
    m_hi = m1 >> shift + oneunit(T) << shift
    m_lo = m1 & (oneunit(T) << shift - oneunit(T))

    # m_hi and m_lo satisfy the preconditions of subtractand and form a multiple of m.
    #@assert leading_zeros(m_hi) == shift-1
    #@assert leading_zeros(m_lo) >= shift
    #@assert (widen(m_hi) << shift + m_lo) % m == 0

    s1, s2 = subtractand(x1, m_hi, m_lo)

    x1 -= s1
    x1 -= s2 >> shift
    x2, o = Base.sub_with_overflow(x2, s2 << shift)
    x1 -= o

    #@assert (widen(x1) << 8sizeof(T) + x2) % m == answer # Correct
    #@assert leading_zeros(x1) ≥ shift - 2                # Status

    y = x1 << (shift-2) + x2 >> (shift+2)
    s1, s2 = subtractand(y, m_hi, m_lo)

    x1 -= s1 >> (shift-2)
    x1 -= s2 >> (2shift-2)
    x2, o = Base.sub_with_overflow(x2, s1 << (shift+2))
    x1 -= o
    x2, o = Base.sub_with_overflow(x2, s2 << 2)
    x1 -= o

    #@assert (widen(x1) << 8sizeof(T) + x2) % m == answer # Correct
    #@assert leading_zeros(x1) ≥ 2shift - 4               # Status

    y = x1 << (2shift-4) + x2 >> 4

    m_lo >>= 1
    m_lo += (m_hi & 0x01) << (shift-1)
    m_hi >>= 1

    q3 = fld(y, m >> 4 + one(T))
    s1, s2 = mul(q3, m)

    x1 -= s1
    x2, o = Base.sub_with_overflow(x2, s2)
    x1 -= o

    #@assert (widen(x1) << 8sizeof(T) + x2) % m == answer # Correct
    #@assert x1 ≤ 1                                       # Status

    x2 -= m*(x1 == 1 || x2 > m)

    #@assert x2 % m == answer                             # Correct

    x2 % m
end

function test_mul_mod(f, types)
    function g(::Type{T}) where T
        for _ in 1:50_000
            a = rand(T)
            b = rand(T)
            m = rand(T)
            m = m == 0 ? one(T) : m
            f(a, b, m) === widemul(a, b) % m % typeof(m) || (println((a,b,m)); return false)
        end
        true
    end
    for i in 1:10
        @test all(g, types)
        print(i)
    end
    for i in 1:10
        for _ in 1:100
            @test g(first(types))
        end
        print(i)
    end
end

@testset "mul_mod" begin
    test_mul_mod(mul_mod, (UInt16, UInt32, UInt64, UInt128))
end

function wiki_mul_mod(a::UInt128, b::UInt128, m::Integer)

    magic1 = (UInt128(0xFFFFFFFF) << 32)
    magic2 = UInt128(0x8000000000000000)

    if iszero(((a | b) & magic1))
        return (a * b) % m
    end

    d = zero(UInt128)
    mp2 = m >> 1

    if a >= m; a %= m; end
    if b >= m; b %= m; end

    for _ in 1:64
        (d > mp2) ? d = ((d << 1) - m) : d = (d << 1)
        if !iszero(a & magic2)
            d += b
        end
        if d >= m
            d -= m
        end
        a <<= 1
    end

    return d
end

function wiki_mul_mod(a::Integer, b::Integer, m::Integer)
    sa, sb = UInt128.(unsigned.((a,b)))
    return wiki_mul_mod(sa, sb, m)
end

# @testset "wiki_mul_mod" begin [FAILS]
#     test_mul_mod(wiki_mul_mod, (UInt128,))
# end

wide_mul_mod(a, b, m) = widemul(a, b) % m % typeof(m)

@testset "wide_mul_mod" begin
    test_mul_mod(wide_mul_mod, (UInt8, UInt16, UInt32, UInt64, UInt128))
end

for T in (UInt16, UInt32, UInt64, UInt128)
    println(T)
    for f in (mul_mod, wide_mul_mod, wiki_mul_mod)
        print(lpad(f, 12))
        #@btime $f(a, b, c) setup=(a=rand($T); b=rand($T); c=rand($T))
        b = @benchmark $f(a, b, c) setup=(a=rand($T); b=rand($T); c=rand($T))
        med = median(b) # Median rather than min to avoid timing the small input fastpath
        println("  ", BenchmarkTools.prettytime(med.time),
                " (", med.allocs , " allocation", med.allocs == 1 ? "" : "s", ": ",
                BenchmarkTools.prettymemory(med.memory), ")")
    end
end
``````

It’s not as fast as the (broken) Wikipedia implementation on UInt128s, nor fast enough for #47577, but still pretty quick. Runtime in the 128-bit case is dominated by two challenging divisions of 128-bit numbers by 64-bit numbers. My implementation is listed as `mul_mod` and `widemul(a, b) % m % typeof(m)` is listed as `wide_mul_mod`.

``````
UInt16
     mul_mod  33.806 ns (0 allocations: 0 bytes)
wide_mul_mod  5.940 ns (0 allocations: 0 bytes)
wiki_mul_mod  23.164 ns (0 allocations: 0 bytes)
UInt32
     mul_mod  30.852 ns (0 allocations: 0 bytes)
wide_mul_mod  12.286 ns (0 allocations: 0 bytes)
wiki_mul_mod  23.498 ns (0 allocations: 0 bytes)
UInt64
     mul_mod  48.651 ns (0 allocations: 0 bytes)
wide_mul_mod  204.355 ns (0 allocations: 0 bytes)
wiki_mul_mod  218.973 ns (0 allocations: 0 bytes)
UInt128
     mul_mod  470.913 ns (0 allocations: 0 bytes)
wide_mul_mod  858.755 ns (15 allocations: 256 bytes)
wiki_mul_mod  234.104 ns (0 allocations: 0 bytes)
``````

This looks wrong. In the original code there was a loop `for (i = 0; i < 64; ++i)`, but that was for `uint64_t`, whereas you are working with `UInt128`. If you look at the Stack Overflow post that I linked above, you’ll see that it describes a variant of binary long multiplication that operates on the bits of `a` (of which there are now 128).

I would try adapting the stackoverflow code, which is a lot better explained and seems like it should have similar or better performance.
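A minimal sketch of the binary long-multiplication (double-and-add) idea described above, written generically over Julia's fixed-width unsigned types. It is not a translation of the Stack Overflow or Wikipedia code, and `add_mod` and `binary_mul_mod` are hypothetical names. The point of `add_mod` is that it never computes `x + y` when that sum could wrap, which keeps the loop correct even when `m` is close to `typemax(T)`.

```julia
# Hypothetical helper: (x + y) % m for x, y < m, without ever wrapping.
@inline function add_mod(x::T, y::T, m::T) where {T<:Base.BitUnsigned}
    # x + y >= m  <=>  x >= m - y  (and m - y cannot underflow because y < m)
    return x >= m - y ? x - (m - y) : x + y
end

# Hypothetical sketch: (a * b) % m by scanning the bits of `a` from the top,
# doubling the accumulator and adding `b` whenever the current bit is set.
function binary_mul_mod(a::T, b::T, m::T) where {T<:Base.BitUnsigned}
    iszero(m) && throw(DivideError())   # same behaviour as `%` for m == 0
    a %= m
    b %= m
    d = zero(T)
    # Skip the leading zero bits of `a`; the range is empty when a == 0.
    for i in (8 * sizeof(T) - leading_zeros(a) - 1):-1:0
        d = add_mod(d, d, m)            # d = 2d mod m
        if isodd(a >> i)                # bit i of `a` is set
            d = add_mod(d, b, m)        # d = d + b mod m
        end
    end
    return d
end
```

Like the Wikipedia-derived versions in this thread, it does one loop iteration per significant bit of `a`, so its cost grows with the bit width in the same way.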


And the `magic` numbers need to be upgraded also:

``````
magic1 = (UInt128(0xFFFFFFFFFFFFFFFF) << 64)
magic2 = UInt128(0x80000000000000000000000000000000)
``````

Testing with the `(a,b,m)` given above and a 128-round loop gives the right answer.

> I would try adapting the stackoverflow code, which is a lot better explained and seems like it should have similar or better performance.

I copied and pasted @jd-forester’s code verbatim to perform benchmarks and tests and devised a different algorithm from scratch; I did not adapt any existing code.

Changing the 64 to 128 as @stevengj suggested and updating the magic numbers as @Dan suggested yields the counterexample `(0x350caac24c5180855a2891ec9a655250, 0x5d043f9d0c547441717da21f58a17d54, 0xac4dd9d313c605d3e9c794646437b1b9)`. I expect this one will stand: the wiki page notes that m must have fewer than 64 bits, and I imagine that carries forward to m needing fewer than 128 bits once extended.

Besides still not being correct over the whole domain, the revised version benchmarks as comparable in speed to the (I think) correct version I posted.

``````
             mul_mod  wide_mul_mod  steve_dan_wiki_mul_mod (still broken)
UInt16     43.684 ns      8.239 ns  31.100 ns
UInt32     40.280 ns     14.272 ns  30.405 ns
UInt64     60.862 ns    268.941 ns  274.094 ns
UInt128   568.878 ns      1.000 μs  560.923 ns
``````

My version does not include a loop over every binary digit, and as far as I can tell, all the linked implementations do (except the extended-precision floating-point trick).

I managed to get `wiki_mul_mod` to pass the new test case. The updated code is:

Updated `wiki_mul_mod`
``````
function wiki_mul_mod(a::UInt128, b::UInt128, m::UInt128)

    magic1 = (UInt128(0xFFFFFFFFFFFFFFFF) << 64)
    magic2 = UInt128(0x80000000000000000000000000000000)

    if iszero(((a | b) & magic1))
        return (a * b) % m
    end

    d = zero(UInt128)
    mp2 = m >> 1

    if a >= m
        a %= m
    end
    if b >= m
        b %= m
    end

    mc = (typemax(UInt128) - m) + 1

    for t = 1:128
        d2 =
            d <= (typemax(UInt128) >> 1) ? d << 1 :
            typemax(UInt128) - ((typemax(UInt128) - d + 1) << 1) + 1 + mc
        d = (d2 >= m) ? (d2 - m) : d2
        if !iszero(a & magic2)
            if b > (typemax(UInt128) - d)
                d += b
                d >= m && (d -= m)
                d += mc
            else
                d += b
            end
        end
        d >= m && (d -= m)
        a <<= 1
    end

    return d
end
``````

For the test case, it gives:

``````
julia> c = UInt128((big(a) * big(b)) % big(m)) # Correct with BigInt conversion
0x659f2a534b3977bb8841b239b086c2b4

julia> d = wiki_mul_mod(a, b, m)
0x659f2a534b3977bb8841b239b086c2b4

julia> d = wiki_mul_mod(b, a, m)
0x659f2a534b3977bb8841b239b086c2b4
``````

In any case, these sorts of primitives should be formally verified against an arithmetic model (e.g. with an SMT solver), because bugs and edge cases can be hard to detect.


A cursory check with my fuzzer seems to agree:

``````
julia> gen = PropCheck.tuple(3, igen(UInt128));

julia> function prop(a,b,c)
           wiki_mul_mod(a,b,c) == ((big(a) * big(b)) % big(c))
       end
prop (generic function with 1 method)

julia> check(Splat(prop), gen)
true
``````

For comparison, it fails almost immediately for earlier versions:

``````
julia> function prop2(a,b,c)
           mul_mod(a,b,c) == ((big(a) * big(b)) % big(c))
       end
prop2 (generic function with 1 method)

julia> check(Splat(prop2), gen)
[ Info: Found counterexample for 'Splat(prop2)', beginning shrinking...
┌ Info: 75 counterexamples found, of which 3 threw 2 distinct exception types
│   Errors =
│    2-element Vector{Any}:
│     false
└          DivideError()
((0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000), DivideError())

julia> check(Splat(prop2), gen)
[ Info: Found counterexample for 'Splat(prop2)', beginning shrinking...
[ Info: 134 counterexamples found
(0x00000000000000010000000000000000, 0x00000000000000000000000100000000, 0x00000000000000010000000000000001)
``````

To be fair, the `DivideError` also happens with the updated versions:

``````
julia> prop(rand(UInt128), rand(UInt128), zero(UInt128))
ERROR: DivideError: integer division error
Stacktrace:
 [1] rem
   @ ./int.jl:1033 [inlined]
 [2] wiki_mul_mod(a::UInt128, b::UInt128, m::UInt128)
   @ Main ./REPL[2]:14
 [3] prop(a::UInt128, b::UInt128, c::UInt128)
   @ Main ./REPL[18]:2
 [4] top-level scope
   @ REPL[33]:1
``````

so that should be guarded behind some checks, and I should finally get around to documenting my fuzzer (and adding more features so it’s actually usable).


I checked this counterexample and it satisfies the property. Any more details?

``````
julia> a, b, m = 0x00000000000000010000000000000000, 0x00000000000000000000000100000000, 0x00000000000000010000000000000001
(0x00000000000000010000000000000000, 0x00000000000000000000000100000000, 0x00000000000000010000000000000001)

julia> d = wiki_mul_mod(a, b, m)
0x0000000000000000ffffffff00000001

julia> c = UInt128((big(a) * big(b)) % big(m)) # Correct with BigInt conversion
0x0000000000000000ffffffff00000001

julia> c-d
0x00000000000000000000000000000000
``````

UPDATE: I get it… this is a counterexample for `mul_mod`, not `wiki_mul_mod`.

The fuzzer tool looks awesome. Any link? (Or does it need more development first?)

Which function is this a counterexample for? If you could provide source code or link to the post that contains it that would be great.

In addition, making it generic over all type sizes and brute-force checking the smaller ones is a fairly reliable system that is simple and efficient enough to be included in CI. I generalized your code to

``````
function wiki_mul_mod_post_18(a::T, b::T, m::T) where T <: Base.BitUnsigned
    magic1 = (one(T) << (sizeof(T)<<2) - one(T)) << (sizeof(T)<<2)
    magic2 = typemax(T)÷2+one(T)

    if iszero(((a | b) & magic1))
        return (a * b) % m
    end

    d = zero(T)
    mp2 = m >> 1

    if a >= m
        a %= m
    end
    if b >= m
        b %= m
    end

    mc = (typemax(T) - m) + one(T)

    for t = 1:(sizeof(T)<<3)
        d2 =
            d <= (typemax(T) >> 1) ? d << 1 :
            typemax(T) - ((typemax(T) - d + one(T)) << 1) + one(T) + mc
        d = (d2 >= m) ? (d2 - m) : d2
        if !iszero(a & magic2)
            if b > (typemax(T) - d)
                d += b
                d >= m && (d -= m)
                d += mc
            else
                d += b
            end
        end
        d >= m && (d -= m)
        a <<= 1
    end

    return d
end
``````

and found that it passes brute force tests for all `BitUnsigned` types.

It’s also pretty efficient for UInt128!

``````
           mul_mod  wide_mul_mod  wiki_mul_mod_post_18
UInt64   46.574 ns    211.096 ns            136.679 ns
UInt128 437.209 ns    659.184 ns            396.000 ns
``````
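A minimal sketch of the brute-force half of that strategy for `UInt8`, where all 256 × 256 × 255 triples with `m > 0` can be enumerated. The harness name `exhaustive_counterexample` is hypothetical. It compares against `widemul(a, b) % m`, which is exact for `UInt8` inputs because the product always fits in `UInt16`.

```julia
# Hypothetical harness: returns `nothing` if f(a, b, m) == (a * b) % m for
# every UInt8 triple with m > 0, otherwise the first failing (a, b, m).
function exhaustive_counterexample(f)
    for m in 0x01:0xff, a in 0x00:0xff, b in 0x00:0xff
        expected = widemul(a, b) % m       # exact: the product fits in UInt16
        f(a, b, m) == expected || return (a, b, m)
    end
    return nothing
end
```

Any three-argument function can be passed in, for example `wiki_mul_mod_post_18`. For `UInt16` and wider types, exhaustive enumeration stops being practical (2^48 triples for `UInt16`), so random or property-based testing has to take over there.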

It’s a QuickCheck/Hedgehog-inspired property-based testing tool! You can find the repo here, though the documentation that is up on GitHub is much too bare. There are also lots of bugs that need fixing, like not being able to generate ranges of things properly. I encountered that bug while writing more documentation and haven’t gotten around to fixing it yet.

That’s for the original code in this post, which you mentioned was wrong but didn’t know why. I wanted to give a smaller example, which is (in part) exactly what that fuzzing tool is for.


When you can widen `a` and `b` to another native type, the problem is trivial. Even at max-width types, the problem is trivial when `a*b` does not overflow. If `a*b` does overflow, you can still use something like `(hi,lo) = hilomul(a,b)` such that `a*b == hi*z+lo` with `z=typemax(T)+1` (we don’t seem to have this function in `Base`, but it’s pretty easy to write in any case) and then use the identity `mod(hi*z+lo,c) == mod(mod(mod(hi,c)*mod(z,c),c) + mod(lo,c),c)` (I just ripped the `mod` identities off Wikipedia, so check me) as long as the intermediate product does not overflow (it is sufficient that `c*c` does not overflow). Obviously `mod(z,c)` has to be computed carefully here since `z` is unrepresentable, but this isn’t so hard.
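A minimal sketch of such a `hilomul`, using the same half-word splitting as the `mul` function posted earlier in the thread but restricted to unsigned types. The name and the `(hi, lo)` return order follow the description above; this is a hypothetical helper, not an existing `Base` function. Each of the four partial products multiplies two half-width values, so none of them can overflow, and the middle column is summed with room to spare before it is split across `lo` and `hi`.

```julia
# Hypothetical helper: returns (hi, lo) with a * b == hi * 2^n + lo,
# where n = 8 * sizeof(T). No intermediate result can overflow.
function hilomul(a::T, b::T) where {T<:Base.BitUnsigned}
    h = 4 * sizeof(T)                        # half the bit width
    mask = (one(T) << h) - one(T)            # low-half mask
    a0, a1 = a & mask, a >> h                # a == a1 * 2^h + a0
    b0, b1 = b & mask, b >> h                # b == b1 * 2^h + b0
    # Each partial product has two factors below 2^h, so it fits in T.
    p00, p01, p10, p11 = a0 * b0, a0 * b1, a1 * b0, a1 * b1
    # Middle column plus the carry out of p00; it stays below 3 * 2^h.
    mid = (p00 >> h) + (p01 & mask) + (p10 & mask)
    lo = (p00 & mask) | (mid << h)           # bits of mid above h are shifted out
    hi = p11 + (p01 >> h) + (p10 >> h) + (mid >> h)
    return hi, lo
end
```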

It seems that you’d only need to resort to the complicated (and presumably much slower) version when all the above contingencies fail. If we have to fall back to some slow path for `c` that are unrepresentable in 64-bits, we’d probably be willing to live with that.
