"""
    fft!(a, Ψ)

In-place forward negacyclic transform over C[X]/(Xᴺ+1) (or any ring with the required
roots of unity), using Cooley–Tukey butterflies as in Algorithm 1 of
https://eprint.iacr.org/2016/504.

`a` holds the `N = length(a)` coefficients in natural order; `N` must be a power of two.
`Ψ` holds the twiddle factors in the paper's layout shifted to 1-based indexing:
`Ψ[i+1] = ψ^brv(i)` for a primitive 2N-th root of unity `ψ`, where `brv` reverses the
low `log2(N)` bits. Only `Ψ[2:N]` is read.

On return, `a[j+1] = Σᵢ a[i+1]·ψ^((2brv(j)+1)·i)`, i.e. the evaluations are stored in
bit-reversed order (the order `ifft!` consumes). Works for any element type with `+`, `-`
and `*`. Returns `nothing`; an empty `a` is left untouched.

Throws `ArgumentError` if `N` is not a power of two or if `Ψ` is shorter than `a`.
"""
function fft!(a::Vector{T}, Ψ::Vector{T}) where {T}
    N = length(a)
    isempty(a) && return nothing
    ispow2(N) || throw(ArgumentError("length(a) must be a power of two, got $N"))
    length(Ψ) ≥ N ||
        throw(ArgumentError("Ψ must have at least $N entries, got $(length(Ψ))"))
    # m: butterfly groups in the current stage (1, 2, …, N/2); k: distance between the
    # two inputs of a butterfly (N/2, …, 1). A stage with m groups reads Ψ[m+1 : 2m],
    # so the very first stage uses Ψ[2].
    m, k = 1, N >> 1
    @inbounds while m < N
        for i in 0:m-1
            j1 = 2 * i * k + 1
            S = Ψ[m+i+1]               # one twiddle per group, hoisted out of the loop
            @simd for j in j1:j1+k-1
                t, u = a[j], a[j+k] * S
                a[j], a[j+k] = t + u, t - u
            end
        end
        m <<= 1
        k >>= 1
    end
    return nothing
end
"""
    ifft!(a, Ψinv)

In-place inverse negacyclic transform over C[X]/(Xᴺ+1) (or any ring with the required
roots of unity), using Gentleman–Sande butterflies as in Algorithm 2 of
https://eprint.iacr.org/2016/504, **without** the final multiplication by N⁻¹.

`a` holds `N = length(a)` evaluations in bit-reversed order (as produced by `fft!`); `N`
must be a power of two. `Ψinv[i+1] = ψ^(-brv(i))` is the inverse twiddle table in the
same 1-based bit-reversed layout as `fft!`'s `Ψ`. Only `Ψinv[2:N]` is read.

On return, `a` holds `N` times the coefficients in natural order; callers must scale by
N⁻¹ themselves. Works for any element type with `+`, `-` and `*`. Returns `nothing`; an
empty or length-1 `a` is left untouched.

Throws `ArgumentError` if `N` is not a power of two or if `Ψinv` is shorter than `a`.
"""
function ifft!(a::Vector{T}, Ψinv::Vector{T}) where {T}
    N = length(a)
    isempty(a) && return nothing
    ispow2(N) || throw(ArgumentError("length(a) must be a power of two, got $N"))
    length(Ψinv) ≥ N ||
        throw(ArgumentError("Ψinv must have at least $N entries, got $(length(Ψinv))"))
    # m: butterfly groups in the current stage (N/2, …, 1); k: distance between the two
    # inputs of a butterfly (1, …, N/2). Running the m = 1 stage inside the loop (rather
    # than unrolled afterwards) means N == 1 performs no accesses at all.
    m, k = N >> 1, 1
    @inbounds while m ≥ 1
        for i in 0:m-1
            j1 = 2 * i * k + 1
            S = Ψinv[m+i+1]            # one twiddle per group, hoisted out of the loop
            @simd for j in j1:j1+k-1
                t, u = a[j], a[j+k]
                a[j], a[j+k] = t + u, S * (t - u)
            end
        end
        m >>= 1
        k <<= 1
    end
    return nothing
end
I cannot guarantee its correctness without proper code for generating the roots of unity, but here is my NTT code. I also updated the code above to support addition and subtraction between ModNum values.
So should I use `Vector{UInt64}` for `a` and `Ψ`, with an extra input `Q::Modulus`? Or should I define a new data type that stores a `Vector{UInt64}` together with a `Modulus` for modular reduction?
By the way, I cannot reproduce the performance loss from using an immutable struct:
Full code (mutable struct):
module A

# Constants for Montgomery arithmetic modulo an odd `Q` with radix R = 2^64.
# Every field is `const` (Julia ≥ 1.8); the type is `mutable` so that a `ModNum` holds a
# reference to one shared `Modulus` instead of an inline copy of all of its fields.
mutable struct Modulus
    const Q::UInt64
    const Q⁻¹::UInt64     # (-Q)⁻¹ mod 2^64, the Montgomery reduction multiplier
    const R⁻¹::UInt128    # (2^64)⁻¹ mod Q, used to leave Montgomery form
    const mask::UInt128   # 2^64 - 1: selects the low 64 bits of a 128-bit product
    const logR::Int64     # log2(R) = 64
    function Modulus(Q::Integer)
        @assert isodd(Q) "Modulus should be an odd number."
        # The UInt128 literal below is R = 2^64.
        Q⁻¹ = UInt64(invmod(-Q, 0x00000000000000010000000000000000))
        R⁻¹ = UInt128(invmod(0x00000000000000010000000000000000, Q))
        new(Q, Q⁻¹, R⁻¹, 0x0000000000000000ffffffffffffffff, 64)
    end
end

# Reduce / invert plain integers modulo the wrapped odd modulus.
Base.mod(x::Integer, Q::Modulus) = mod(x, Q.Q)
Base.invmod(x::Integer, Q::Modulus) = invmod(UInt64(x), Q.Q)

# A residue modulo `Q.Q`. The `+`, `-` and `*` methods below assume the canonical
# representative `0 ≤ val < Q.Q`. `*` is the Montgomery product val_a·val_b·R⁻¹ mod Q,
# which is the ordinary product when both operands are in Montgomery form (see `mform`).
struct ModNum <: Unsigned
    val::UInt64
    Q::Modulus
end

# Convert into Montgomery form: val ↦ val·R mod Q.
mform(x::ModNum) = ModNum(UInt64(mod(UInt128(mod(x.val, x.Q)) << x.Q.logR, x.Q)), x.Q)
# Convert out of Montgomery form: val ↦ val·R⁻¹ mod Q.
imform(x::ModNum) = ModNum(UInt64(mod(x.val * x.Q.R⁻¹, x.Q)), x.Q)

# Montgomery reduction (REDC, R = 2^64): returns a·b·R⁻¹ mod q for 0 ≤ a < q, b < 2^64.
# The sum ab + m·q can exceed 2^128 when q is above roughly 0.62·2^64, so the division
# by R is done on 64-bit halves. ab + m·q ≡ 0 (mod R), so the low halves sum to 0 or to
# exactly R. The carry into the high half is therefore 1 precisely when the low half of
# ab is nonzero. The result before the final subtraction is < 2q.
mul(a::UInt64, b::UInt64, Q::Modulus) = begin
    q, q⁻¹ = Q.Q, Q.Q⁻¹
    ab = UInt128(a) * UInt128(b)
    m = UInt64(ab & Q.mask) * q⁻¹          # wraps mod 2^64 by design
    mq = q * UInt128(m)
    w = (ab >> Q.logR) + (mq >> Q.logR) + UInt128((ab & Q.mask) != 0)
    UInt64(w ≥ q ? w - q : w)
end

# Modular addition of canonical residues. When q > 2^63 the UInt64 sum may wrap; a
# wrapped sum (s < a.val) is ≥ q in exact arithmetic, and s - q (also wrapping) is then
# the correct residue.
Base.:+(a::ModNum, b::ModNum) = begin
    @assert a.Q.Q == b.Q.Q
    q = a.Q.Q
    s = a.val + b.val
    ModNum((s < a.val || s ≥ q) ? s - q : s, a.Q)
end

# Modular subtraction of canonical residues. If a.val < b.val the difference wraps
# below zero; adding q (also wrapping) brings it back into [0, q).
Base.:-(a::ModNum, b::ModNum) = begin
    @assert a.Q.Q == b.Q.Q
    d = a.val - b.val
    ModNum(a.val ≥ b.val ? d : d + a.Q.Q, a.Q)
end

# Montgomery product of two residues sharing the same modulus.
Base.:*(a::ModNum, b::ModNum) = begin
    @assert a.Q.Q == b.Q.Q
    ModNum(mul(a.val, b.val, a.Q), a.Q)
end

# Benchmark kernel: folds n successive `*` products into ModNum(1, Q).
function f(n, Q)
    a = ModNum(1, Q)
    for i in 1:n
        a *= ModNum(i, Q)
    end
    return a
end

end
I may have misled everybody, sorry: the performance loss actually comes from the NTT. Oddly enough, it does not occur in the REPL, only in code written to a file.
I’m interested, but I cannot run your fft! function, probably because I don’t know much about how it is supposed to be used:
# Reproduction attempt: a modulus and two length-10 ModNum vectors to feed to fft!.
# NOTE(review): fft!/ifft! derive their stage count from trailing_zeros(N) / N >> 1, so
# they appear to assume N is a power of two; 10 is not.
# NOTE(review): fft! needs `+` and `-` on ModNum for its butterflies; the promotion error
# below is presumably what `t + u` raises when no such methods are defined.
Q = A.Modulus(113);
a = [A.ModNum(1,Q) for i in 1:10];
Ψ = [A.ModNum(3,Q) for i in 1:10];
ERROR: promotion of types Main.A.ModNum and Main.A.ModNum failed to change any arguments
I also thought an immutable struct should be faster; I hope someone can actually explain why! Maybe it depends on the machine you are using: it gives a bigger performance improvement on my laptop and my Linux server.
This is exactly what we would expect. With the immutable struct, each ModNum has 5 fields instead of 2, so a Vector{ModNum} is an additional 2x worse. All in all, I think it probably makes sense to have an interface that requires explicitly passing the Modulus around (especially because your operations only make sense with a single one).
Thanks a lot! It’s a shame that I cannot take advantage of operator overloading for ModNum. At least there are well-known optimisation techniques that I can use with UInt64 arrays. I really appreciate your help, and thanks again to everyone involved in the discussion.