Mind if I steal that and put it in GitHub - JuliaMath/IntegerMathUtils.jl ?
Be careful with code you don’t understand, especially when it is advertised as “completely opaque and unverified”. Here’s an example where it is wrong:
julia> a,b,m = (0x557f0d8d9f08cab0f5b1f6e6fe814efe, 0xff3850c2465783f9b39c716f225f7c79, 0x556ad04f016faad7333284c59a3221af)
(0x557f0d8d9f08cab0f5b1f6e6fe814efe, 0xff3850c2465783f9b39c716f225f7c79, 0x556ad04f016faad7333284c59a3221af)
julia> c = UInt128((big(a) * big(b)) % big(m)) # Correct with BigInt conversion
0x2f601dd0ad7828c5366d3f1de8757516
julia> d = wiki_mul_mod(a, b, m) # Wrong, idk why.
0x1a5fe3638d5932abb00197a805401041
Here’s another, much longer, slightly slower, fairly opaque, and reasonably well verified approach I crafted:
using Test, BenchmarkTools
"""
mul(a::T, b::T) -> x1::T, x1::unsigned(T)
Multiply `a` and `b` and return the high and low parts of the result.
``a*b = x1*2^n + x2`` where ``n`` is the number of bits in `T`.
"""
function mul(a::T, b::T) where T<:Integer
shift = sizeof(T) << 2
lo_mask = one(T) << shift - one(T)
a_lo = a & lo_mask
b_lo = b & lo_mask
a_hi = a >> shift
b_hi = b >> shift
x2 = unsigned(a_lo * b_lo)
m1 = a_lo * b_hi
m2 = a_hi * b_lo
x2, o1 = Base.add_with_overflow(x2, unsigned(m1) << shift)
x2, o2 = Base.add_with_overflow(x2, unsigned(m2) << shift)
x1 = a_hi * b_hi
x1 += m1 >> shift
x1 += m2 >> shift
x1 += o1
x1 += o2
x1, x2
end
@testset "mul" begin
for T in (UInt8, Int8), a in typemin(T):typemax(T), b in typemin(T):typemax(T)
x1, x2 = mul(a, b)
x = Int(a) * Int(b)
@test x1 * 256^sizeof(T) + x2 == x
end
end
"""
subtractand(x::T, m_hi::T, m_lo::T) -> x1::T, x2::T
Guess a `q` and compute `x1 = m_hi*q` and `x2 = m_lo*q`.
Precondition: ``leading_zeros(m_hi) = n-1`` and ``n ≤ leading_zeros(m_lo)``
Postcondition: ``0 ≤ x - y < 4*2^n`` and ``m_hi * 2^n + m_lo | y``
Where ``y = x1 * 2^n + x2`` and ``n`` is the half the number of bits in `T`.
!!! Warning
this docstring might contain lies.
"""
function subtractand(x::T, m_hi::T, m_lo::T) where T<:Unsigned
shift = sizeof(T) << 2
q = fld(x, m_hi + oneunit(T))
x1 = m_hi * q
x2 = m_lo * q
x1, x2
end
@testset "subtractand" begin
for T in (UInt8,), x in typemin(T):typemax(T), m_hi in T(0x10):T(0x1f), m_lo in T(0x00):T(0x0f)
x1, x2 = subtractand(x, m_hi, m_lo)
y = x1 * 16^sizeof(T) + x2
@test 0 <= x * 16^sizeof(T) - y < 4*256^sizeof(T)
@test y % (m_hi * 16^sizeof(T) + m_lo) == 0
end
x1, x2 = subtractand(0x57c2, 0x013d, 0x0066)
y = x1 * 16^sizeof(x1) + x2
@test 0 <= 0x57c2 * 16^sizeof(x1) - y < 4*256^sizeof(x1)
@test y % (0x013d * 16^sizeof(x1) + 0x0066) == 0
end
"""
mul_mod(a::T, b::T, m::T) -> x::T
Compute `(a * b) % m` without intermediate overflow
Precondition ``m > 0``
"""
function mul_mod(a::T, b::T, m::T) where T<:Unsigned
shift = sizeof(T) << 2
m <= one(T) << shift && return ((a%m) * (b%m)) % m
x1, x2 = mul(a, b)
# Invariant: x1 * 2^n + x2 ≡ a * b (mod m)
#answer = widemul(a, b) % m
#@assert (widen(x1) << 8sizeof(T) + x2) % m == answer
m1 = m << (leading_zeros(m) + 1)
m_hi = m1 >> shift + oneunit(T) << shift
m_lo = m1 & (oneunit(T) << shift - oneunit(T))
# m_hi and m_lo satisfy the preconditions of subtractand and form a multiple of m.
#@assert leading_zeros(m_hi) == shift-1
#@assert leading_zeros(m_lo) >= shift
#@assert (widen(m_hi) << shift + m_lo) % m == 0
s1, s2 = subtractand(x1, m_hi, m_lo)
x1 -= s1; x1 -= s2 >> shift; (x2, o) = Base.sub_with_overflow(x2, s2 << shift); x1 -= o
#@assert (widen(x1) << 8sizeof(T) + x2) % m == answer # Correct
#@assert leading_zeros(x1) ≥ shift - 2 # Status
y = x1 << (shift-2) + x2 >> (shift+2)
s1, s2 = subtractand(y, m_hi, m_lo)
x1 -= s1 >> (shift-2); x1 -= s2 >> (2shift-2); (x2, o) = Base.sub_with_overflow(x2, s1 << (shift+2)); x1 -= o; (x2, o) = Base.sub_with_overflow(x2, s2 << 2); x1 -= o
#@assert (widen(x1) << 8sizeof(T) + x2) % m == answer # Correct
#@assert leading_zeros(x1) ≥ 2shift - 4 # Status
y = x1 << (2shift-4) + x2 >> 4
m_lo >>= 1
m_lo += (m_hi & 0x01) << (shift-1)
m_hi >>= 1
q3 = fld(y, m >> 4 + one(T))
s1, s2 = mul(q3, m)
x1 -= s1; x2, o = Base.sub_with_overflow(x2, s2); x1 -= o
#@assert (widen(x1) << 8sizeof(T) + x2) % m == answer # Correct
#@assert x1 ≤ 1 # Status
x2 -= m*(x1 == 1 || x2 > m)
#@assert x2 % m == answer # Correct
x2 % m
end
function test_mul_mod(f, types)
function g(::Type{T}) where T
for _ in 1:50_000
a = rand(T)
b = rand(T)
m = rand(T)
m = m == 0 ? one(T) : m
f(a, b, m) === widemul(a, b) % m % typeof(m) || (println((a,b,m));return false)
end
true
end
for i in 1:10
@test all(g, types)
print(i)
end
for i in 1:10
for _ in 1:100
@test g(first(types))
end
print(i)
end
end
@testset "mul_mod" begin
test_mul_mod(mul_mod, (UInt16, UInt32, UInt64, UInt128))
end
function wiki_mul_mod(a::UInt128, b::UInt128, m::Integer)
magic1 = (UInt128(0xFFFFFFFF) << 32)
magic2 = UInt128(0x8000000000000000)
if iszero(((a | b) & magic1))
return (a * b) % m
end
d = zero(UInt128)
mp2 = m >> 1
if a >= m; a %= m; end
if b >= m; b %= m; end
for _ in 1:64
(d > mp2) ? d = ((d << 1) - m) : d = (d << 1)
if !iszero(a & magic2)
d += b
end
if d >= m
d -= m
end
a <<= 1
end
return d
end
function wiki_mul_mod(a::Integer, b::Integer, m::Integer)
sa, sb = UInt128.(unsigned.((a,b)))
return wiki_mul_mod(sa, sb, m)
end
# @testset "wiki_mul_mod" begin [FAILS]
# test_mul_mod(wiki_mul_mod, (UInt128,))
# end
wide_mul_mod(a, b, m) = widemul(a, b) % m % typeof(m)
@testset "wide_mul_mod" begin
test_mul_mod(wide_mul_mod, (UInt8, UInt16, UInt32, UInt64, UInt128))
end
for T in (UInt16, UInt32, UInt64, UInt128)
println(T)
for f in (mul_mod, wide_mul_mod, wiki_mul_mod)
print(lpad(f, 12))
#@btime $f(a, b, c) setup=(a=rand($T); b=rand($T); c=rand($T))
b = @benchmark $f(a, b, c) setup=(a=rand($T); b=rand($T); c=rand($T))
med = median(b) # Median rather than min to avoid timing the small input fastpath
println(" ", BenchmarkTools.prettytime(med.time),
" (", med.allocs , " allocation", med.allocs == 1 ? "" : "s", ": ",
BenchmarkTools.prettymemory(med.memory), ")")
end
end
It’s not as fast as the (broken) Wikipedia implementation on UInt128s, nor fast enough for #47577, but still pretty quick. Runtime in the 128-bit case is dominated by two challenging divisions of 128-bit numbers by 64-bit numbers. My implementation is listed as mul_mod
and widemul(a, b) % m % typeof(m)
is listed as wide_mul_mod
.
UInt16
mul_mod 33.806 ns (0 allocations: 0 bytes)
wide_mul_mod 5.940 ns (0 allocations: 0 bytes)
wiki_mul_mod 23.164 ns (0 allocations: 0 bytes)
UInt32
mul_mod 30.852 ns (0 allocations: 0 bytes)
wide_mul_mod 12.286 ns (0 allocations: 0 bytes)
wiki_mul_mod 23.498 ns (0 allocations: 0 bytes)
UInt64
mul_mod 48.651 ns (0 allocations: 0 bytes)
wide_mul_mod 204.355 ns (0 allocations: 0 bytes)
wiki_mul_mod 218.973 ns (0 allocations: 0 bytes)
UInt128
mul_mod 470.913 ns (0 allocations: 0 bytes)
wide_mul_mod 858.755 ns (15 allocations: 256 bytes)
wiki_mul_mod 234.104 ns (0 allocations: 0 bytes)