How to construct a struct with common field values efficiently

# Cooley-Tukey over C[X]/(Xᴺ+1).
# Based on https://eprint.iacr.org/2016/504
"""
    fft!(a, Ψ)

Forward negacyclic transform of `a`, computed in place with Cooley–Tukey
butterflies `(t, u) ↦ (t + S·u, t − S·u)`; returns `a`.

`length(a)` must be a power of two `N`. Twiddle layout (1-based): the stage with
`m` blocks uses `Ψ[m+i+1]` for block `i = 0:m-1`. So only `Ψ[2:N]` is read and
`Ψ[1]` is ignored (presumably the paper's twiddle table shifted by one index —
confirm against the root-of-unity generator).

Works for any element type supporting `+`, `-` and `*` (e.g. `ModNum`).

Throws `ArgumentError` if `length(a)` is not a power of two or if
`length(Ψ) < length(a)`. Inputs of length 0 or 1 are returned unchanged.
"""
function fft!(a::Vector{T}, Ψ::Vector{T}) where {T}
    N = length(a)
    # Nothing to transform; also avoids trailing_zeros(0) == 64 below.
    N ≤ 1 && return a
    ispow2(N) || throw(ArgumentError("length(a) must be a power of two, got $N"))
    length(Ψ) ≥ N ||
        throw(ArgumentError("length(Ψ) must be at least length(a) = $N, got $(length(Ψ))"))

    # First stage (m = 1): a single block pairing a[j] with a[j+N/2].
    # Its twiddle is Ψ[m+1] = Ψ[2]; this mirrors the last stage of ifft!.
    k = N >> 1
    ψ1 = Ψ[2]
    @inbounds @simd for j = 1 : k
        t, u = a[j], a[j+k] * ψ1
        a[j], a[j+k] = t + u, t - u
    end

    # Remaining stages: m blocks of 2k entries; block i starts at i·2k + 1
    # (2k == 1 << logkp1) and uses twiddle Ψ[m+i+1].
    m, logkp1, k = 2, trailing_zeros(N) - 1, k >> 1
    while logkp1 > 0
        @inbounds for i = 0 : m-1
            j1 = i << logkp1 + 1; j2 = j1 + k - 1
            S = Ψ[m+i+1]
            @simd for j = j1 : j2
                t, u = a[j], a[j+k] * S
                a[j], a[j+k] = t + u, t - u
            end
        end
        m <<= 1; logkp1 -= 1; k >>= 1
    end
    return a
end

# Gentleman-Sande over C[X]/(Xᴺ+1).
# Based on https://eprint.iacr.org/2016/504
"""
    ifft!(a, Ψinv)

Inverse negacyclic transform of `a`, computed in place with Gentleman–Sande
butterflies `(t, u) ↦ (t + u, S·(t − u))`; returns `a`.

`length(a)` must be a power of two `N`. Twiddle layout (1-based): the stage with
`m` blocks uses `Ψinv[m+i+1]` for block `i = 0:m-1`, and the final stage
(`m = 1`) uses `Ψinv[2]`. So only `Ψinv[2:N]` is read.

No final multiplication by N⁻¹ is performed here. NOTE(review): the caller
presumably folds that factor into `Ψinv` or applies it afterwards — confirm.

Throws `ArgumentError` if `length(a)` is not a power of two or if
`length(Ψinv) < length(a)`. Inputs of length 0 or 1 are returned unchanged.
"""
function ifft!(a::Vector{T}, Ψinv::Vector{T}) where {T}
    N = length(a)
    # Without this guard, N ≤ 1 would reach the final stage and read a[2].
    N ≤ 1 && return a
    ispow2(N) || throw(ArgumentError("length(a) must be a power of two, got $N"))
    length(Ψinv) ≥ N ||
        throw(ArgumentError("length(Ψinv) must be at least length(a) = $N, got $(length(Ψinv))"))

    # Stages with m = N/2, N/4, …, 2 blocks of 2k entries; block i starts at
    # i·2k + 1 (2k == 1 << logkp1) and uses twiddle Ψinv[m+i+1].
    m, logkp1, k = N >> 1, 1, 1
    while m > 1
        @inbounds for i = 0 : m - 1
            j1 = i << logkp1 + 1; j2 = j1 + k - 1
            S = Ψinv[m+i+1]
            @simd for j = j1 : j2
                t, u = a[j], a[j+k]
                a[j], a[j+k] = t + u, S * (t - u)
            end
        end
        m >>= 1; logkp1 += 1; k <<= 1
    end

    # Final stage (m == 1, k == N/2): pair a[j] with a[j+N/2] using Ψinv[2].
    S = Ψinv[m+1]
    @inbounds @simd for j = 1 : k
        t, u = a[j], a[j+k]
        a[j], a[j+k] = t + u, S * (t - u)
    end
    return a
end

I cannot guarantee the correctness without proper code for root-of-unity generation, but here is my NTT code. I also updated the code above for addition and subtraction between ModNum.

Oh, this makes perfect sense: `Vector{ModNum}` is a bad data layout because it wastes half the space.

So I should use `Vector{UInt64}` for `a` and `Ψ`, with an extra input `Q::Modulus`? Or perhaps define a new data type that stores a `Vector{UInt64}` together with the `Modulus` used for modular reduction?

By the way I cannot reproduce the performance loss upon using an immutable struct:

Full code (mutable struct with `const` fields):
module A
    """
        Modulus(Q::Integer)

    Constants for Montgomery arithmetic modulo the odd integer `Q`, using the
    Montgomery radix R = 2^64 (`logR = 64`).

    Fields:
    - `Q`: the modulus.
    - `Q⁻¹`: (-Q)⁻¹ mod 2^64, the REDC constant used by `mul`.
    - `R⁻¹`: (2^64)⁻¹ mod Q, used by `imform` to leave Montgomery form.
    - `mask`: selects the low 64 bits of a `UInt128`.
    - `logR`: log2 of the radix (64).

    `Q` must be odd and satisfy `0 < Q < 2^63`. `mul` forms `a*b + m*Q` in a
    `UInt128` and converts the shifted result to `UInt64` before its single
    conditional subtraction. That stays in range only for such moduli. The bound
    also keeps the sum in `+` below 2^64.
    """
    mutable struct Modulus
        const Q::UInt64
        const Q⁻¹::UInt64
        const R⁻¹::UInt128
        const mask::UInt128
        const logR::Int64

        function Modulus(Q::Integer)
            @assert isodd(Q) "Modulus should be an odd number."
            0 < Q < 0x8000000000000000 ||
                throw(ArgumentError("Modulus must satisfy 0 < Q < 2^63, got $Q"))

            # 0x00000000000000010000000000000000 is 2^64 as a UInt128 (the radix R).
            Q⁻¹ = UInt64(invmod(-Q, 0x00000000000000010000000000000000))
            R⁻¹ = UInt128(invmod(0x00000000000000010000000000000000, Q))
            new(Q, Q⁻¹, R⁻¹, 0x0000000000000000ffffffffffffffff, 64)
        end
    end

    Base.mod(x::Integer, Q::Modulus) = mod(x, Q.Q)
    Base.invmod(x::Integer, Q::Modulus) = invmod(UInt64(x), Q.Q)

    # A residue paired with its modulus. `*` is Montgomery multiplication, so
    # operands are presumably kept in Montgomery form (see `mform`) and reduced (< Q).
    struct ModNum <: Unsigned
        val::UInt64
        Q::Modulus
    end

    # Into Montgomery form: x ↦ x·2^64 mod Q.
    mform(x::ModNum) = ModNum(UInt64(mod(UInt128(mod(x.val, x.Q)) << x.Q.logR, x.Q)), x.Q)
    # Out of Montgomery form: x ↦ x·2^-64 mod Q.
    imform(x::ModNum) = ModNum(UInt64(mod(x.val * x.Q.R⁻¹, x.Q)), x.Q)

    # Montgomery product (REDC, R = 2^64): returns a·b·2^-64 mod Q, fully
    # reduced when a·b < Q·2^64 (e.g. a, b < Q).
    mul(a::UInt64, b::UInt64, Q::Modulus) = begin
        q, q⁻¹ = Q.Q, Q.Q⁻¹

        ab = UInt128(a) * UInt128(b)
        # m = (ab mod 2^64)·(-Q)⁻¹ mod 2^64 makes ab + m·Q divisible by 2^64.
        w = UInt64((ab + q * UInt128(UInt64(ab & Q.mask) * q⁻¹)) >> Q.logR)
        w ≥ q ? w - q : w
    end

    Base.:*(a::ModNum, b::ModNum) = begin
        @assert a.Q.Q == b.Q.Q
        ModNum(mul(a.val, b.val, a.Q), a.Q)
    end

    # Modular addition (valid in Montgomery form too, since xR + yR = (x + y)R).
    # Assumes reduced operands (< Q); Q < 2^63 rules out UInt64 overflow.
    Base.:+(a::ModNum, b::ModNum) = begin
        @assert a.Q.Q == b.Q.Q
        q = a.Q.Q
        s = a.val + b.val
        ModNum(s ≥ q ? s - q : s, a.Q)
    end

    # Modular subtraction. Assumes reduced operands (< Q).
    Base.:-(a::ModNum, b::ModNum) = begin
        @assert a.Q.Q == b.Q.Q
        ModNum(a.val ≥ b.val ? a.val - b.val : a.val + (a.Q.Q - b.val), a.Q)
    end

    # Benchmark kernel: folds ModNum(1, Q), …, ModNum(n, Q) together with `*`.
    function f(n, Q)
        a = ModNum(1,Q)
        for i in 1:n
            a *= ModNum(i,Q)
        end
        a
    end
end
julia> Q = A.Modulus(113);

julia> @btime A.f(1000, $Q);
  6.239 μs (0 allocations: 0 bytes)

while

Immutable struct:
    """
        Modulus(Q::Integer)

    Immutable variant of the Montgomery constants for the odd modulus `Q`, with
    radix R = 2^64 (`logR = 64`).

    Fields: `Q` (the modulus), `Q⁻¹` = (-Q)⁻¹ mod 2^64, `R⁻¹` = (2^64)⁻¹ mod Q,
    `mask` (low-64-bit mask for a `UInt128`) and `logR`.

    `Q` must satisfy `0 < Q < 2^63`. Montgomery reduction that accumulates
    `a*b + m*Q` in a `UInt128` and stores the shifted result in a `UInt64`
    stays in range only for such moduli.
    """
    struct Modulus
        Q::UInt64
        Q⁻¹::UInt64
        R⁻¹::UInt128
        mask::UInt128
        logR::Int64

        function Modulus(Q::Integer)
            @assert isodd(Q) "Modulus should be an odd number."
            0 < Q < 0x8000000000000000 ||
                throw(ArgumentError("Modulus must satisfy 0 < Q < 2^63, got $Q"))

            # 0x00000000000000010000000000000000 is 2^64 as a UInt128 (the radix R).
            Q⁻¹ = UInt64(invmod(-Q, 0x00000000000000010000000000000000))
            R⁻¹ = UInt128(invmod(0x00000000000000010000000000000000, Q))
            new(Q, Q⁻¹, R⁻¹, 0x0000000000000000ffffffffffffffff, 64)
        end
    end

gives

WARNING: replacing module A.
julia> Q = A.Modulus(113);

julia> @btime A.f(1000, $Q);
  6.232 μs (0 allocations: 0 bytes)

weirdly,

Mutable struct without `const` fields:
    """
        Modulus(Q::Integer)

    Mutable (non-`const` fields) variant of the Montgomery constants for the
    odd modulus `Q`, with radix R = 2^64 (`logR = 64`).

    Fields: `Q` (the modulus), `Q⁻¹` = (-Q)⁻¹ mod 2^64, `R⁻¹` = (2^64)⁻¹ mod Q,
    `mask` (low-64-bit mask for a `UInt128`) and `logR`.

    `Q` must satisfy `0 < Q < 2^63`. Montgomery reduction that accumulates
    `a*b + m*Q` in a `UInt128` and stores the shifted result in a `UInt64`
    stays in range only for such moduli.
    """
    mutable struct Modulus
        Q::UInt64
        Q⁻¹::UInt64
        R⁻¹::UInt128
        mask::UInt128
        logR::Int64

        function Modulus(Q::Integer)
            @assert isodd(Q) "Modulus should be an odd number."
            0 < Q < 0x8000000000000000 ||
                throw(ArgumentError("Modulus must satisfy 0 < Q < 2^63, got $Q"))

            # 0x00000000000000010000000000000000 is 2^64 as a UInt128 (the radix R).
            Q⁻¹ = UInt64(invmod(-Q, 0x00000000000000010000000000000000))
            R⁻¹ = UInt128(invmod(0x00000000000000010000000000000000, Q))
            new(Q, Q⁻¹, R⁻¹, 0x0000000000000000ffffffffffffffff, 64)
        end
    end

is even a bit faster

WARNING: replacing module A.
julia> Q = A.Modulus(113);

julia> @btime A.f(1000, $Q);
  6.132 μs (0 allocations: 0 bytes)

I may have misled everybody, sorry. The performance loss actually comes from the NTT. Oddly enough, it does not happen in the REPL, only in code run from a file.

  27.458 μs (0 allocations: 0 bytes)
  27.916 μs (0 allocations: 0 bytes)

This is the NTT test with dimension N = 1024 and an immutable `Modulus`.

  21.542 μs (0 allocations: 0 bytes)
  25.833 μs (0 allocations: 0 bytes)

This is the NTT test with dimension N = 1024 and a mutable `Modulus`.

I’m interested, but I cannot run your fft! function, probably because I don’t know much about how it is supposed to be used:

# NOTE(review): module A defines only `*` for ModNum, while fft! also needs `+`
# and `-` on ModNum; that is the likely source of the promotion error below.
# fft! also derives its stage count from trailing_zeros(length(a)), so it
# presumably expects a power-of-two length; 10 is not one.
Q = A.Modulus(113);
a = [A.ModNum(1,Q) for i in 1:10];
Ψ = [A.ModNum(3,Q) for i in 1:10];
ERROR: promotion of types Main.A.ModNum and Main.A.ModNum failed to change any arguments

Can you give an example of how to benchmark it?

# Benchmark fft!/ifft! at N = 1024 with an odd modulus well below 2^63.
# Globals are interpolated with `$` so @btime measures the kernels rather than
# dynamic lookups of untyped global variables.
# NOTE: both calls mutate `a` in place, so every sample runs on the previous
# output; this changes the values but not the number of butterfly operations.
Q = Modulus(1088516530177)
a = [ModNum(1, Q) for _ = 1 : 1024];
Ψ = [ModNum(1, Q) for _ = 1 : 1024];
@btime fft!($a, $Ψ)
@btime ifft!($a, $Ψ)

This should work.

1 Like

Ah, looks like I was missing a few functions, now I can reproduce the strange performance difference:


WARNING: replacing module A.
[ Info: This the mutable const struct

  57.214 μs (0 allocations: 0 bytes)
WARNING: replacing module A.
[ Info: This the immutable struct

  62.827 μs (0 allocations: 0 bytes)

I also thought the immutable struct would be faster; I hope someone can explain why! Maybe it depends on the machine: the performance improvement is larger on my laptop and on my Linux server.

This is exactly what we would expect. With the immutable struct, `ModNum` effectively has 5 fields instead of 2, because the `Modulus` is stored inline rather than as a pointer. That makes `Vector{ModNum}` an extra 2x worse. All in all, I think it probably makes sense to have an interface that requires explicitly passing the `Modulus` around (especially because your operations only make sense with one of them).

1 Like

Thanks a lot! It’s a shame that I cannot take advantage of operator overloading for `ModNum`. At least there are well-known optimisation techniques I can use with `UInt64` arrays. I really appreciate your help, and thanks again to everyone involved in the discussion.

What if you use the immutable structure, but make the Modulus a reference instead of the actual value?

1 Like

I also tried that, but such adjustment made the code a little bit slower.