Here’s a version for `UInt128` which runs in 100 ns on my computer (vs. 122 ns for lilith’s version and 350 ns for the wiki version). Looking back, I think it is roughly the same algorithm as lilith’s.
Profiling shows that almost all the time is spent, roughly evenly, in the 8 division operations (`mod`, `fldmod` (which counts for 2) and `sub_mod`). I’m not sure whether it can be done with fewer divisions.
Code
using BitIntegers: UInt256
using BenchmarkTools: @benchmark
"""
    mul_full(x::UInt128, y::UInt128) -> Tuple{UInt128,UInt128}

Compute the exact 256-bit product of `x` and `y` and split it into its
`(low, high)` 128-bit halves, so that `x * y == low + 2^128 * high`.
"""
function mul_full(x::UInt128, y::UInt128)
    # both factors are < 2^128, so the UInt256 product cannot overflow
    wide = UInt256(x) * UInt256(y)
    low = mod(wide, UInt128)
    high = mod(wide >> 128, UInt128)
    return (low, high)
end
"""
    sub_mod(x::UInt128, y::UInt128, m::UInt128) -> UInt128

Return `(x - y) mod m` in `0:m-1` for arbitrary `x` and `y`, without ever
forming a negative (wrapping) unsigned difference. `m` must be nonzero.
"""
function sub_mod(x::UInt128, y::UInt128, m::UInt128)
    # Subtract in whichever order stays non-negative. When y >= x the answer is
    # -(y - x) mod m; `mod1` maps a multiple of m to m, so `m - mod1(...)` is 0.
    return x > y ? mod(x - y, m) : m - mod1(y - x, m)
end
# Return `(c - a * m0) mod m` for the reduction steps of `mul_mod`.
# There `a = fld(w, m1)` with `w < m = m0 + m1 * 2^64`, which only bounds
# `a < 2^64 + m0 / m1 <= 2^65`, so `a * m0` itself can exceed 2^128.
# Splitting `a` at bit 64 keeps every partial product inside a UInt128.
# Assumes `a < 2^65` and `m0 < 2^64` (both hold at the call sites in `mul_mod`).
function _sub_mul_mod(c::UInt128, a::UInt128, m0::UInt128, m::UInt128)
    a_lo = a & UInt128(typemax(UInt64))
    # a_lo, m0 < 2^64, so a_lo * m0 < 2^128: no overflow
    r = sub_mod(c, a_lo * m0, m)
    # the high word of `a` is 0 or 1; when it is 1, also subtract 2^64 * m0
    if !iszero(a >> 64)
        r = sub_mod(r, m0 << 64, m)
    end
    return r
end

"""
    mul_mod(x::UInt128, y::UInt128, m::UInt128) -> UInt128

Return `mod(x * y, m)` computed without 128-bit overflow. `m` must be nonzero
(`mod` throws a `DivideError` otherwise).

For `m >= 2^64`, write `u = 2^64` and `m = m0 + m1 * u` with `m0 < u` and
`1 <= m1 < u`, so that `u * m1 ≡ -m0 (mod m)`. The exact 256-bit product
`z = z0 + u^2 * z1` is folded down one factor of `u` at a time: divide by `m1`,
then replace `u * m1` with `-m0`.
"""
function mul_mod(x::UInt128, y::UInt128, m::UInt128)
    # fast branch: both factors < 2^64, so x * y < 2^128 and cannot overflow
    if iszero(x >> 64) && iszero(y >> 64)
        return mod(x * y, m)
    end
    # fast branch: m < 2^64, so the reduced factors multiply without overflow
    if iszero(m >> 64)
        return mod(mod(x, m) * mod(y, m), m)
    end
    # m in base u = 2^64: m = m0 + m1 * u
    m0 = mod(mod(m, UInt64), UInt128)
    m1 = m >> 64
    # z = x * y in base 2^128: z == z0 + u^2 * z1
    z0, z1 = mul_full(x, y)
    # reduce z1 below m once (it does not need reducing again below)
    z1 = mod(z1, m)
    # z1 = a1 * m1 + a0 with a0 < m1 < u, and u * m1 == -m0 (mod m), so
    # u^2 * z1 == u * (a0 * u - a1 * m0) == u * z2 (mod m)
    a1, a0 = fldmod(z1, m1)
    z2 = _sub_mul_mod(a0 << 64, a1, m0, m)
    # same trick again: z2 = b1 * m1 + b0, so u * z2 == b0 * u - b1 * m0 == z3 (mod m)
    b1, b0 = fldmod(z2, m1)
    z3 = _sub_mul_mod(b0 << 64, b1, m0, m)
    # z == z0 + z3 (mod m). z0 + z3 may overflow, so compute z0 - (m - z3)
    # instead; m - z3 is in 1:m because 0 <= z3 < m.
    return sub_mod(z0, m - z3, m)
end
"""
    mul_mod_big(x::UInt128, y::UInt128, m::UInt128) -> UInt128

Reference implementation of `x * y mod m` using arbitrary-precision `BigInt`
arithmetic, so the intermediate product is exact.
"""
function mul_mod_big(x::UInt128, y::UInt128, m::UInt128)
    exact_product = big(x) * big(y)
    reduced = mod(exact_product, big(m))
    # reduced < m <= typemax(UInt128), so the conversion back is lossless
    return convert(UInt128, reduced)
end
"""
    mul_mod_bad(x::UInt128, y::UInt128, m::UInt128) -> UInt128

Naive `x * y mod m`. The `UInt128` product wraps modulo 2^128 before it is
reduced, so the result is only correct when `x * y < 2^128`. This is presumably
kept as a speed baseline for `benchmark`.
"""
mul_mod_bad(x::UInt128, y::UInt128, m::UInt128) = (x * y) % m
"""
    mul_mod_256(x::UInt128, y::UInt128, m::UInt128) -> UInt128

Compute `x * y mod m` by widening everything to `UInt256`, where the product of
two `UInt128` values is exact, and then reducing.
"""
function mul_mod_256(x::UInt128, y::UInt128, m::UInt128)
    product = UInt256(x) * UInt256(y)
    remainder = mod(product, UInt256(m))
    return convert(UInt128, remainder)
end
"""
    check(x, y, m)

Assert that `mul_mod` and the `BigInt` reference `mul_mod_big` agree on `(x, y, m)`.
"""
function check(x, y, m)
    z0, z1 = mul_mod(x, y, m), mul_mod_big(x, y, m)
    @assert z0 == z1
end
"""
    check_rand(n=1_000_000)

Run `check` on `n` random `UInt128` triples `(x, y, m)`.

Full-width random values almost never give operands below 2^64, `m < 2^64`,
or a modulus whose high word `m >> 64` is tiny. The first two are the fast
branches of `mul_mod`; the last is where its reduction steps are most likely
to overflow. So on every other iteration the operands and modulus are shifted
right by random amounts to spread them over all bit widths. Triples with
`m == 0` are skipped, since reduction modulo zero is undefined.
"""
function check_rand(n=1_000_000)
    for i in 1:n
        x, y, m = rand(UInt128), rand(UInt128), rand(UInt128)
        if iseven(i)
            x >>= rand(0:127)
            y >>= rand(0:127)
            m >>= rand(0:127)
        end
        iszero(m) && continue
        check(x, y, m)
    end
end
"""
    benchmark(f=mul_mod)

Benchmark `f(x, y, m)` with BenchmarkTools' `@benchmark`. The `setup`
expression draws fresh random `UInt128` values for `x`, `y` and `m`.
"""
function benchmark(f=mul_mod)
    return @benchmark $f(x, y, m) setup = (x = rand(UInt128); y = rand(UInt128); m = rand(UInt128))
end