How about expressing (a*b)%c
in base 2^64
but perform the calculations with 128-bit precision so you don’t need to worry about overflow?
a
, b
, and c
may be Int128
s already, and performing calculations in 256-bit precision is slower.
Mind if I steal that and put it in GitHub - JuliaMath/IntegerMathUtils.jl ?
Be careful with code you don’t understand, especially when it is advertised as “completely opaque and unverified”. Here’s an example where it is wrong:
julia> a,b,m = (0x557f0d8d9f08cab0f5b1f6e6fe814efe, 0xff3850c2465783f9b39c716f225f7c79, 0x556ad04f016faad7333284c59a3221af)
(0x557f0d8d9f08cab0f5b1f6e6fe814efe, 0xff3850c2465783f9b39c716f225f7c79, 0x556ad04f016faad7333284c59a3221af)
julia> c = UInt128((big(a) * big(b)) % big(m)) # Correct with BigInt conversion
0x2f601dd0ad7828c5366d3f1de8757516
julia> d = wiki_mul_mod(a, b, m) # Wrong, idk why.
0x1a5fe3638d5932abb00197a805401041
Here’s another, much longer, slightly slower, fairly opaque, and reasonably well verified approach I crafted:
using Test, BenchmarkTools
"""
mul(a::T, b::T) -> x1::T, x1::unsigned(T)
Multiply `a` and `b` and return the high and low parts of the result.
``a*b = x1*2^n + x2`` where ``n`` is the number of bits in `T`.
"""
function mul(a::T, b::T) where T<:Integer
shift = sizeof(T) << 2
lo_mask = one(T) << shift - one(T)
a_lo = a & lo_mask
b_lo = b & lo_mask
a_hi = a >> shift
b_hi = b >> shift
x2 = unsigned(a_lo * b_lo)
m1 = a_lo * b_hi
m2 = a_hi * b_lo
x2, o1 = Base.add_with_overflow(x2, unsigned(m1) << shift)
x2, o2 = Base.add_with_overflow(x2, unsigned(m2) << shift)
x1 = a_hi * b_hi
x1 += m1 >> shift
x1 += m2 >> shift
x1 += o1
x1 += o2
x1, x2
end
@testset "mul" begin
for T in (UInt8, Int8), a in typemin(T):typemax(T), b in typemin(T):typemax(T)
x1, x2 = mul(a, b)
x = Int(a) * Int(b)
@test x1 * 256^sizeof(T) + x2 == x
end
end
"""
subtractand(x::T, m_hi::T, m_lo::T) -> x1::T, x2::T
Guess a `q` and compute `x1 = m_hi*q` and `x2 = m_lo*q`.
Precondition: ``leading_zeros(m_hi) = n-1`` and ``n ≤ leading_zeros(m_lo)``
Postcondition: ``0 ≤ x - y < 4*2^n`` and ``m_hi * 2^n + m_lo | y``
Where ``y = x1 * 2^n + x2`` and ``n`` is the half the number of bits in `T`.
!!! Warning
this docstring might contain lies.
"""
function subtractand(x::T, m_hi::T, m_lo::T) where T<:Unsigned
shift = sizeof(T) << 2
q = fld(x, m_hi + oneunit(T))
x1 = m_hi * q
x2 = m_lo * q
x1, x2
end
@testset "subtractand" begin
for T in (UInt8,), x in typemin(T):typemax(T), m_hi in T(0x10):T(0x1f), m_lo in T(0x00):T(0x0f)
x1, x2 = subtractand(x, m_hi, m_lo)
y = x1 * 16^sizeof(T) + x2
@test 0 <= x * 16^sizeof(T) - y < 4*256^sizeof(T)
@test y % (m_hi * 16^sizeof(T) + m_lo) == 0
end
x1, x2 = subtractand(0x57c2, 0x013d, 0x0066)
y = x1 * 16^sizeof(x1) + x2
@test 0 <= 0x57c2 * 16^sizeof(x1) - y < 4*256^sizeof(x1)
@test y % (0x013d * 16^sizeof(x1) + 0x0066) == 0
end
"""
mul_mod(a::T, b::T, m::T) -> x::T
Compute `(a * b) % m` without intermediate overflow
Precondition ``m > 0``
"""
function mul_mod(a::T, b::T, m::T) where T<:Unsigned
shift = sizeof(T) << 2
m <= one(T) << shift && return ((a%m) * (b%m)) % m
x1, x2 = mul(a, b)
# Invariant: x1 * 2^n + x2 ≡ a * b (mod m)
#answer = widemul(a, b) % m
#@assert (widen(x1) << 8sizeof(T) + x2) % m == answer
m1 = m << (leading_zeros(m) + 1)
m_hi = m1 >> shift + oneunit(T) << shift
m_lo = m1 & (oneunit(T) << shift - oneunit(T))
# m_hi and m_lo satisfy the preconditions of subtractand and form a multiple of m.
#@assert leading_zeros(m_hi) == shift-1
#@assert leading_zeros(m_lo) >= shift
#@assert (widen(m_hi) << shift + m_lo) % m == 0
s1, s2 = subtractand(x1, m_hi, m_lo)
x1 -= s1; x1 -= s2 >> shift; (x2, o) = Base.sub_with_overflow(x2, s2 << shift); x1 -= o
#@assert (widen(x1) << 8sizeof(T) + x2) % m == answer # Correct
#@assert leading_zeros(x1) ≥ shift - 2 # Status
y = x1 << (shift-2) + x2 >> (shift+2)
s1, s2 = subtractand(y, m_hi, m_lo)
x1 -= s1 >> (shift-2); x1 -= s2 >> (2shift-2); (x2, o) = Base.sub_with_overflow(x2, s1 << (shift+2)); x1 -= o; (x2, o) = Base.sub_with_overflow(x2, s2 << 2); x1 -= o
#@assert (widen(x1) << 8sizeof(T) + x2) % m == answer # Correct
#@assert leading_zeros(x1) ≥ 2shift - 4 # Status
y = x1 << (2shift-4) + x2 >> 4
m_lo >>= 1
m_lo += (m_hi & 0x01) << (shift-1)
m_hi >>= 1
q3 = fld(y, m >> 4 + one(T))
s1, s2 = mul(q3, m)
x1 -= s1; x2, o = Base.sub_with_overflow(x2, s2); x1 -= o
#@assert (widen(x1) << 8sizeof(T) + x2) % m == answer # Correct
#@assert x1 ≤ 1 # Status
x2 -= m*(x1 == 1 || x2 > m)
#@assert x2 % m == answer # Correct
x2 % m
end
function test_mul_mod(f, types)
function g(::Type{T}) where T
for _ in 1:50_000
a = rand(T)
b = rand(T)
m = rand(T)
m = m == 0 ? one(T) : m
f(a, b, m) === widemul(a, b) % m % typeof(m) || (println((a,b,m));return false)
end
true
end
for i in 1:10
@test all(g, types)
print(i)
end
for i in 1:10
for _ in 1:100
@test g(first(types))
end
print(i)
end
end
@testset "mul_mod" begin
test_mul_mod(mul_mod, (UInt16, UInt32, UInt64, UInt128))
end
function wiki_mul_mod(a::UInt128, b::UInt128, m::Integer)
magic1 = (UInt128(0xFFFFFFFF) << 32)
magic2 = UInt128(0x8000000000000000)
if iszero(((a | b) & magic1))
return (a * b) % m
end
d = zero(UInt128)
mp2 = m >> 1
if a >= m; a %= m; end
if b >= m; b %= m; end
for _ in 1:64
(d > mp2) ? d = ((d << 1) - m) : d = (d << 1)
if !iszero(a & magic2)
d += b
end
if d >= m
d -= m
end
a <<= 1
end
return d
end
function wiki_mul_mod(a::Integer, b::Integer, m::Integer)
sa, sb = UInt128.(unsigned.((a,b)))
return wiki_mul_mod(sa, sb, m)
end
# @testset "wiki_mul_mod" begin [FAILS]
# test_mul_mod(wiki_mul_mod, (UInt128,))
# end
wide_mul_mod(a, b, m) = widemul(a, b) % m % typeof(m)
@testset "wide_mul_mod" begin
test_mul_mod(wide_mul_mod, (UInt8, UInt16, UInt32, UInt64, UInt128))
end
for T in (UInt16, UInt32, UInt64, UInt128)
println(T)
for f in (mul_mod, wide_mul_mod, wiki_mul_mod)
print(lpad(f, 12))
#@btime $f(a, b, c) setup=(a=rand($T); b=rand($T); c=rand($T))
b = @benchmark $f(a, b, c) setup=(a=rand($T); b=rand($T); c=rand($T))
med = median(b) # Median rather than min to avoid timing the small input fastpath
println(" ", BenchmarkTools.prettytime(med.time),
" (", med.allocs , " allocation", med.allocs == 1 ? "" : "s", ": ",
BenchmarkTools.prettymemory(med.memory), ")")
end
end
It’s not as fast as the (broken) Wikipedia implementation on UInt128s, nor fast enough for #47577, but still pretty quick. Runtime in the 128-bit case is dominated by two challenging divisions of 128-bit numbers by 64-bit numbers. My implementation is listed as mul_mod
and widemul(a, b) % m % typeof(m)
is listed as wide_mul_mod
.
UInt16
mul_mod 33.806 ns (0 allocations: 0 bytes)
wide_mul_mod 5.940 ns (0 allocations: 0 bytes)
wiki_mul_mod 23.164 ns (0 allocations: 0 bytes)
UInt32
mul_mod 30.852 ns (0 allocations: 0 bytes)
wide_mul_mod 12.286 ns (0 allocations: 0 bytes)
wiki_mul_mod 23.498 ns (0 allocations: 0 bytes)
UInt64
mul_mod 48.651 ns (0 allocations: 0 bytes)
wide_mul_mod 204.355 ns (0 allocations: 0 bytes)
wiki_mul_mod 218.973 ns (0 allocations: 0 bytes)
UInt128
mul_mod 470.913 ns (0 allocations: 0 bytes)
wide_mul_mod 858.755 ns (15 allocations: 256 bytes)
wiki_mul_mod 234.104 ns (0 allocations: 0 bytes)
This looks wrong. In the original code there was a loop for (i = 0; i < 64; ++i)
, but that was for uint64_t
, whereas you are working with UInt128
. If you look at the stackoverflow post that I linked above, you’ll see that it is describing a variant of long multiplication in binary which operates on the bits a
(of which there are now 128).
I would try adapting the stackoverflow code, which is a lot better explained and seems like it should have similar or better performance.
And the magic
numbers need to be upgraded also:
magic1 = (UInt128(0xFFFFFFFFFFFFFFFF) << 64)
magic2 = UInt128(0x80000000000000000000000000000000)
Testing with the (a,b,m)
given above with 128 round loop gives the right answer.
I would try adapting the stackoverflow code, which is a lot better explained and seems like it should have similar or better performance.
I copied and pasted @jd-forester’s code verbatim to perform benchmarks and tests and devised a different algorithm from scratch; I did not adapt any existing code.
Changing the 64 to 128 as @stevengj suggested and the magic numbers as @Dan suggested yields the counterexample (0x350caac24c5180855a2891ec9a655250, 0x5d043f9d0c547441717da21f58a17d54, 0xac4dd9d313c605d3e9c794646437b1b9)
, which I expect will stand as the wiki page notes that m must have fewer than 64 bits and I imagine that carries forward to m must have fewer than 128 bits once extended.
In addition to still not being correct over the whole domain, I benchmark the revised version as comparable in speed to the correct (I think) version I posted.
mul_mod wide_mul_mod steve_dan_wiki_mul_mod (still broken)
UInt16 43.684 ns 8.239 ns 31.100 ns
UInt32 40.280 ns 14.272 ns 30.405 ns
UInt64 60.862 ns 268.941 ns 274.094 ns
UInt128 568.878 ns 1.000 μs 560.923 ns
My version does not include a loop over every binary digit and as far as I can tell, all the linked implementations do (except the extended precision floating point trick).
Managed to get wiki_mul_mod
to not break over new test-case. The updated code is:
Updated `wiki_mul_mod`
function wiki_mul_mod(a::UInt128, b::UInt128, m::UInt128)
magic1 = (UInt128(0xFFFFFFFFFFFFFFFF) << 64)
magic2 = UInt128(0x80000000000000000000000000000000)
if iszero(((a | b) & magic1))
return (a * b) % m
end
d = zero(UInt128)
mp2 = m >> 1
if a >= m
a %= m
end
if b >= m
b %= m
end
mc = (typemax(UInt128) - m) + 1
for t = 1:128
d2 =
d <= (typemax(UInt128) >> 1) ? d << 1 :
typemax(UInt128) - ((typemax(UInt128) - d + 1) << 1) + 1 + mc
d = (d2 >= m) ? (d2 - m) : d2
if !iszero(a & magic2)
if b > (typemax(UInt128) - d)
d += b
d >= m && (d -= m)
d += mc
else
d += b
end
end
d >= m && (d -= m)
a <<= 1
end
return d
end
It gives for test-case:
julia> c = UInt128((big(a) * big(b)) % big(m)) # Correct with BigInt conversion
0x659f2a534b3977bb8841b239b086c2b4
julia> d = wiki_mul_mod(a, b, m)
0x659f2a534b3977bb8841b239b086c2b4
julia> d = wiki_mul_mod(b, a, m)
0x659f2a534b3977bb8841b239b086c2b4
In any case, these sort of primitives should be formally verified against an arithmetic model (using SMT). Because bugs and edge cases can be hard to detect.
A cursory check with my fuzzer seems to agree:
julia> gen = PropCheck.tuple(3, igen(UInt128));
julia> function prop(a,b,c)
wiki_mul_mod(a,b,c) == ((big(a) * big(b)) % big(c))
end
prop (generic function with 1 method)
julia> check(Splat(prop), gen)
true
for comparison, it fails almost immediately for earlier versions:
julia> function prop2(a,b,c)
mul_mod(a,b,c) == ((big(a) * big(b)) % big(c))
end
prop2 (generic function with 1 method)
julia> check(Splat(prop2), gen)
[ Info: Found counterexample for 'Splat(prop2)', beginning shrinking...
┌ Info: 75 counterexamples found, of which 3 threw 2 distinct exception types
│ Errors =
│ 2-element Vector{Any}:
│ false
└ DivideError()
((0x00000000000000000000000000000000, 0x00000000000000000000000000000000, 0x00000000000000000000000000000000), DivideError())
julia> check(Splat(prop2), gen)
[ Info: Found counterexample for 'Splat(prop2)', beginning shrinking...
[ Info: 134 counterexamples found
(0x00000000000000010000000000000000, 0x00000000000000000000000100000000, 0x00000000000000010000000000000001)
To be fair, the DivideError
also happens with the updated versions:
julia> prop(rand(UInt128), rand(UInt128), zero(UInt128))
ERROR: DivideError: integer division error
Stacktrace:
[1] rem
@ ./int.jl:1033 [inlined]
[2] wiki_mul_mod(a::UInt128, b::UInt128, m::UInt128)
@ Main ./REPL[2]:14
[3] prop(a::UInt128, b::UInt128, c::UInt128)
@ Main ./REPL[18]:2
[4] top-level scope
@ REPL[33]:1
so that should be guarded behind some checks and I should finally get around to documenting my fuzzer (+adding more features so it’s actually useable)
Checked this counterexample and it satisfied property. Any more details?
julia> a, b, m = 0x00000000000000010000000000000000, 0x00000000000000000000000100000000, 0x00000000000000010000000000000001
(0x00000000000000010000000000000000, 0x00000000000000000000000100000000, 0x00000000000000010000000000000001)
julia> d = wiki_mul_mod(a, b, m)
0x0000000000000000ffffffff00000001
julia> c = UInt128((big(a) * big(b)) % big(m)) # Correct with BigInt conversion
0x0000000000000000ffffffff00000001
julia> c-d
0x00000000000000000000000000000000
UPADTE: I get it… this is counterexample for mul_mod
, not wiki_mul_mod
.
The fuzzer tools looks awesome. Any link? (or perhaps more dev needed).
Which function is this a counterexample for? If you could provide source code or link to the post that contains it that would be great.
In addition, making it general for all type sizes and brute force checking for the smaller ones is a fairly reliable system that is simple and efficient enough to be included in CI. I generalized your code to
function wiki_mul_mod_post_18(a::T, b::T, m::T) where T <: Base.BitUnsigned
magic1 = (one(T) << (sizeof(T)<<2) - one(T)) << (sizeof(T)<<2)
magic2 = typemax(T)÷2+one(T)
if iszero(((a | b) & magic1))
return (a * b) % m
end
d = zero(T)
mp2 = m >> 1
if a >= m
a %= m
end
if b >= m
b %= m
end
mc = (typemax(T) - m) + one(T)
for t = 1:(sizeof(T)<<3)
d2 =
d <= (typemax(T) >> 1) ? d << 1 :
typemax(T) - ((typemax(T) - d + one(T)) << 1) + one(T) + mc
d = (d2 >= m) ? (d2 - m) : d2
if !iszero(a & magic2)
if b > (typemax(T) - d)
d += b
d >= m && (d -= m)
d += mc
else
d += b
end
end
d >= m && (d -= m)
a <<= 1
end
return d
end
and found that it passes brute force tests for all BitUnsigned
types.
It’s also pretty efficient for UInt128!
mul_mod wide_mul_mod wiki_mul_mod_post_18
UInt64 46.574 ns 211.096 ns 136.679 ns
UInt128 437.209 ns 659.184 ns 396.000 ns
It’s a quickcheck/hedgehog inspired property based testing tool! You can find the repo here, though the documentation that is up on github is much too bare. There’s also lots of bugs that need fixing, like not being able to generate ranges of things properly. I encountered that bug while writing more documentation, and haven’t gotten around to fixing it yet
That’s for the original code in this post you mentioned was wrong but didn’t know why. I wanted to give a smaller example, which is (in part) exactly what that fuzzing tool is for
When you can widen a
and b
to another native type, the problem is trivial. Even at max-width types, the problem is trivial when a*b
does not overflow. If a*b
does overflow, you can still use something like a (hi,lo) == hilomul(a,b)
such that a*b == hi*z+lo
with z=typemax(T)+1
(we don’t seem to have this function in Base
? but it’s pretty easy in any case) and then use the identity mod(hi*z+lo,c) == mod(mod(mod(hi,c)*mod(z,c),c) + mod(lo,c),c)
(I just ripped the mod
identities off wikipedia – check me) so long as the intermediate product does not overflow (sufficient that c*c
does not overflow). Obviously mod(z,c)
has to be computed carefully here since z
is unrepresentable, but this isn’t so hard.
It seems that you’d only need to resort to the complicated (and presumably much slower) version when all the above contingencies fail. If we have to fall back to some slow path for c
that are unrepresentable in 64-bits, we’d probably be willing to live with that.
I was going to say that perhaps I’m being dense but doesn’t this definition work well for all bit sizes less that 128:
mulmod(a::Integer, b::Integer, c::Integer) =
mod(widemul(a, b), c) % typeof(c)
The tricky case is how to do this for 128-bit arguments where c
actually is too big to fit in a 64-bit integer.
Also note: regrettably we followed the C definition of %
and it actually does rem
whereas what everyone wants here is mod
.
I agree that we should use native arithmetic (with widening) whenever possible. Although I’m not sure anyone is disputing that.
At the widest native types (eg Int128
) I imagine it’s still worthwhile to try to complete the operation in native arithmetic. A run-time overflow check can slow-path to the algorithm discussed here when required, but the significant speed difference makes the “try native first” approach enticing. I don’t imagine that very large arguments are common in most domains.
Very closely related: factorials and binomial coefficients! Both of those are useful to have modulo some other number, because both overflow very quickly.
Modular variants of those would certainly be useful to have! Not sure if they belong in base or a package though. Maybe ModularMethods.jl
or something like that?
The OP was specifically about UInt128.
Ah yeah, missed that!
I do think there should be a package specifically for modular arithmetic, but I think these two are basic enough they could go in base alongside the current modular power.
You say regrettably–are you thinking of changing it in Julia 2.0? I think it’d be nice for a lot of people coming from Python, at the very least.