Performance advice needed

Actually, you can get an additional boost by exploiting the fact that q is a Fermat number. This is a little more involved, of course. The idea is the following:

  1. We need to find a mod q, where 0 ≤ a ≤ (q - 1)^2 and q = 2^(2^p) + 1 for some integer p. Write k = 2^p, so that q - 1 = 2^k.
  2. a can be presented in the form a = n*2^k + r = n*(q - 1) + r, where n = a div 2^k and r = a mod 2^k.
  3. Since q - 1 ≡ -1 (mod q), we get a ≡ r - n (mod q). Here r = a mod 2^k = a & (2^k - 1) = a & (q - 2), where we used the fact that in binary representation mod 2^l is just a bitwise AND with 2^l - 1.
  4. n can be found as a div 2^k = a >>> k, where k = trailing_zeros(q - 1), using the fact that in binary representation division by a power of 2 equals the corresponding right shift.
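A small worked instance of the reduction, with q = 17 = 2^(2^2) + 1 (so k = 4) and a = 200 picked only for illustration; r is the low k bits of a and n the remaining high part:

```julia
# Worked example of the Fermat-number reduction for q = 17 (p = 2, k = 2^p = 4).
q = 17
k = trailing_zeros(q - 1)   # 4, since q - 1 = 16 = 2^4
a = 200                     # any 0 <= a <= (q - 1)^2 = 256
r = a & (q - 2)             # low k bits:  a mod 2^k = 8
n = a >>> k                 # high part:   a div 2^k = 12
x = r - n + q               # 2^k ≡ -1 (mod q), so a ≡ r - n; adding q keeps x >= 0
x = x >= q ? x - q : x      # one conditional subtraction brings x into [0, q)
println((r, n, x, mod(a, q)))  # (8, 12, 13, 13)
```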

Together, this gives us a nice function:

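```julia
function fermat_mod(a::T, q::T) where {T<:Integer}
    # q = 2^k + 1: q - 2 masks the low k bits, trailing_zeros(q - 1) == k
    x = (a & (q - T(2))) - (a >>> trailing_zeros(q - T(1))) + q
    # one conditional subtraction maps x into [0, q)
    return x >= q ? x - q : x
end
```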

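In this function it is assumed that q is a Fermat number and 0 ≤ a ≤ (q - 1)^2, which always holds for a product of two residues modulo q. In that range the first assignment gives x in [1, 2q - 2], so a single conditional subtraction of q completes the reduction. The bound a < q^2 is slightly too loose: for a = q^2 - 1 the first assignment gives x = -1.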
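To gain some confidence in that range, here is a quick sketch that restates fermat_mod (so the snippet runs on its own) and compares it with Julia's built-in mod: exhaustively for the small Fermat numbers 3, 5, 17 and 257, and on a random sample for q = 65537. The last assertion shows that a = q^2 - 1 already falls outside the safe range.

```julia
# fermat_mod restated; assumes q = 2^k + 1 is a Fermat number and 0 <= a <= (q - 1)^2.
function fermat_mod(a::T, q::T) where {T<:Integer}
    k = trailing_zeros(q - one(T))        # q - 1 = 2^k
    x = (a & (q - T(2))) - (a >>> k) + q  # (a mod 2^k) - (a div 2^k) + q, in [1, 2q - 2]
    return x >= q ? x - q : x
end

# Exhaustive comparison with mod for small Fermat numbers.
for q in (3, 5, 17, 257)
    for a in 0:(q - 1)^2
        @assert fermat_mod(a, q) == mod(a, q)
    end
end

# Random sample for q = 65537, where the full range (about 4.3e9 values) is too large.
q = 65537
for a in rand(0:(q - 1)^2, 10^6)
    @assert fermat_mod(a, q) == mod(a, q)
end

# Just outside the range: a = 5^2 - 1 = 24 gives -1 instead of mod(24, 5) == 4.
@assert fermat_mod(24, 5) == -1
println("all checks passed")
```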

Inlining it into the fnt! function yields:

```julia
function fnt2!(x::Vector{T}, g::T, q::T) where {T<:Integer}
    N = length(x)
    @assert ispow2(N)
    @assert isfermat(q)

    radix2sort!(x)  # input reordering; radix2sort! is defined elsewhere, not shown here

    logN = log2(N) |> ceil |> T

    for M in 2 .^ [0:logN-1;] # TODO: not very readable
        interval = 2M
        p = div(N, interval)
        gp = powermod(g, p, q)
        W = one(T)  # twiddle factor, kept of type T for type stability
        for m in 1:M
            for i in m:interval:N
                j = i + M
                # Wxj = W * x[j] mod q via the inlined Fermat reduction
                Wxj = W * x[j]
                Wxj = Wxj & (q - T(2)) - Wxj >>> trailing_zeros(q - T(1)) + q
                Wxj = Wxj >= q ? Wxj - q : Wxj  # reduce into [0, q)

                # butterfly: both right-hand sides use the old x[i]
                x[i], x[j] = x[i] + Wxj, x[i] - Wxj + q
                x[i] = x[i] >= q ? x[i] - q : x[i]
                x[j] = x[j] >= q ? x[j] - q : x[j]
            end
            W = mod(W*gp, q)
        end
    end

    return x
end
```
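fnt2! relies on radix2sort! and isfermat, which are not shown in this post. Below is a self-contained sketch for checking correctness against a naive O(N^2) transform. Everything apart from fnt2! is hypothetical: bitreverse_permute! stands in for radix2sort! under the assumption that it is the usual bit-reversal permutation, isfermat only checks the form 2^(2^p) + 1, and naive_dft is a reference written just for the comparison. fnt2! is restated compactly (while loop over M, shift amount precomputed). The usage part relies on q = 65537 being prime with 3 as a primitive root, so g = 3^((q - 1)/N) is a primitive N-th root of unity.

```julia
# Hypothetical stand-in for isfermat: q == 2^(2^p) + 1 for some p >= 0.
isfermat(q::Integer) = q > 2 && ispow2(q - 1) && ispow2(trailing_zeros(q - 1))

# Hypothetical stand-in for radix2sort!, assuming it is the bit-reversal permutation.
function bitreverse_permute!(x::AbstractVector)
    N = length(x)
    nbits = trailing_zeros(N)
    for i in 0:N-1
        j = 0
        for b in 0:nbits-1
            j |= ((i >> b) & 1) << (nbits - 1 - b)  # mirror bit b of i
        end
        if i < j
            x[i+1], x[j+1] = x[j+1], x[i+1]
        end
    end
    return x
end

# Compact restatement of fnt2!; bitreverse_permute! replaces radix2sort!.
function fnt2!(x::Vector{T}, g::T, q::T) where {T<:Integer}
    N = length(x)
    @assert ispow2(N)
    @assert isfermat(q)
    bitreverse_permute!(x)
    k = trailing_zeros(q - one(T))                   # q = 2^k + 1
    M = 1
    while M < N
        interval = 2M
        gp = powermod(g, N ÷ interval, q)            # interval-th root of unity if g has order N
        W = one(T)
        for m in 1:M
            for i in m:interval:N
                j = i + M
                t = W * x[j]
                t = (t & (q - T(2))) - (t >>> k) + q  # Fermat reduction, t in [1, 2q - 2]
                t = t >= q ? t - q : t
                u = x[i]
                s = u + t
                d = u - t + q
                x[i] = s >= q ? s - q : s
                x[j] = d >= q ? d - q : d
            end
            W = mod(W * gp, q)
        end
        M = interval
    end
    return x
end

# Naive O(N^2) reference: X[k+1] = sum over n of x[n+1] * g^(n*k), mod q.
naive_dft(x::Vector{T}, g::T, q::T) where {T<:Integer} =
    [mod(sum(mod(x[n+1] * powermod(g, n * k, q), q) for n in 0:length(x)-1), q)
     for k in 0:length(x)-1]

# Usage: 3 is a primitive root mod 65537, so g below has multiplicative order exactly N.
q = 65537
N = 16
g = powermod(3, (q - 1) ÷ N, q)
x = rand(0:q-1, N)
@assert fnt2!(copy(x), g, q) == naive_dft(x, g, q)
println("fnt2! agrees with the naive transform")
```

Comparing against the naive transform on random inputs catches slips such as a missing or wrong conditional subtraction, which would otherwise still run without errors.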

And the benchmark:

```julia
using BenchmarkTools
@btime fnt2($x, $169, $65537)  # 176.820 μs (4 allocations: 32.42 KiB)
```