I am working on a function whose calculation has a structure similar to an FFT. I'm a bit disappointed: I hoped it would be about 10x slower than FFTW, but benchmarks show it's 100x slower. I tried adding @simd and @inbounds, but they have minimal effect. Could you tell me if there is something obviously wrong? I'm not asking for algorithm analysis, just a review of the code quality. I am not interested in adding multithreading, because this package is just part of a bigger thing, and parallelisation will potentially be done at a higher level.
This line is doubly inefficient: it first constructs 0:logN, which is cheap and fine, but then it collects that range into a new vector, and then computes an entire new vector of 2 .^ .... That's two extra vector allocations that you don't need.
Instead, you could just iterate with something like:
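For instance, here is a minimal standalone sketch (with a made-up `logN`) that iterates the exponent directly and uses a shift, so no intermediate vectors are allocated:

```julia
logN = 4            # example value; in the real code this comes from ceil(log2(N))
Ms = Int[]          # only collected here to demonstrate the values are identical
for s in 0:logN-1
    M = 1 << s      # M = 2^s, computed with a shift, no vector allocations
    push!(Ms, M)
end
@assert Ms == 2 .^ [0:logN-1;]   # same sequence of M values as the original loop
```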
@tkf is right: powermod should be avoided. mod itself is a rather costly operation, and powermod is even more expensive.
You can hoist powermod out of the inner loop and compute the twiddle factor iteratively instead, which gives the necessary speedup.
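The idea in miniature (a standalone sketch with arbitrary g, q, p): one cheap modular multiplication per step reproduces what powermod would compute inside the loop.

```julia
function check_iterative_twiddle(g, q, p, steps)
    gp = powermod(g, p, q)   # computed once, outside the hot loop
    W = 1
    for m in 0:steps
        @assert W == powermod(gp, m, q)  # iterative W matches the direct powermod
        W = mod(W * gp, q)               # one cheap modular multiply per step
    end
    return true
end
check_iterative_twiddle(169, 65537, 8, 5)
```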
function fnt2!(x::Array{T, 1}, g::T, q::T) where {T<:Integer}
    N = length(x)
    @assert ispow2(N)
    @assert isfermat(q)
    radix2sort!(x)
    logN = log2(N) |> ceil |> T
    for M in 2 .^ [0:logN-1;] # TODO: not very readable
        interval = 2M
        p = div(N, interval)
        gp = powermod(g, p, q)
        W = 1
        for m in 1:M
            for i in m:interval:N
                j = i + M
                Wxj = W * x[j]
                x[i], x[j] = x[i] + Wxj, x[i] - Wxj
                x[i] = mod(x[i], q)
                x[j] = mod(x[j], q)
            end
            W = mod(W*gp, q)
        end
    end
    return x
end
function fnt2(x::Array{T}, g::T, q::T) where {T<:Integer}
    return fnt2!(copy(x), g, q)
end
# Sanity check
const x = mod.(rand(Int, 4096), 65537);
all(fnt2(x, 169, 65537) .== fnt(x, 169, 65537)) # true
@btime fnt($x, $225, $65537) # 6.923 ms (4 allocations: 32.42 KiB)
@btime fnt2($x, $225, $65537) # 431.841 μs (4 allocations: 32.42 KiB)
You can get an additional speedup by turning the two mod operations in the inner loop into one:
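The trick relies on the operands already being reduced: if both are in [0, q), their sum is below 2q, so a single conditional subtraction replaces a full mod. A standalone sketch:

```julia
q = 65537
a, b = 40000, 30000   # both already reduced, i.e. in [0, q)
s = a + b             # s < 2q, so at most one subtraction of q is needed
s = s >= q ? s - q : s
@assert s == mod(a + b, q)
```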
function fnt3!(x::Array{T, 1}, g::T, q::T) where {T<:Integer}
    N = length(x)
    @assert ispow2(N)
    @assert isfermat(q)
    radix2sort!(x)
    logN = log2(N) |> ceil |> T
    for M in 2 .^ [0:logN-1;] # TODO: not very readable
        interval = 2M
        p = div(N, interval)
        gp = powermod(g, p, q)
        W = 1
        for m in 1:M
            for i in m:interval:N
                j = i + M
                Wxj = mod(W * x[j], q)
                x[i], x[j] = x[i] + Wxj, x[i] - Wxj + q
                x[i] = x[i] >= q ? x[i] - q : x[i]
                x[j] = x[j] >= q ? x[j] - q : x[j]
            end
            W = mod(W*gp, q)
        end
    end
    return x
end
function fnt3(x::Array{T}, g::T, q::T) where {T<:Integer}
    return fnt3!(copy(x), g, q)
end
# Sanity check
const x = mod.(rand(Int, 4096), 65537)
all(fnt3(x, 169, 65537) .== fnt(x, 169, 65537)) # true
@btime fnt3($x, $169, $65537) # 245.112 μs (4 allocations: 32.42 KiB)
EDIT: Slightly changed the definition of x[j]. It doesn't affect performance, but the initial version introduced a bug for unsigned integers, because x[i] - Wxj can wrap around below zero when Wxj > x[i]; adding q keeps the result on the correct residue.
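A standalone sketch of the unsigned pitfall (with made-up small operands):

```julia
q = UInt32(65537)
a, b = UInt32(3), UInt32(5)
# For unsigned types, a - b wraps around instead of going negative:
@assert a - b > q                    # wrapped to 2^32 - 2, not -2
# Adding q lands on the correct residue, (3 - 5) mod 65537 == 65535:
@assert a + q - b == UInt32(65535)   # no wraparound at all this way
@assert a - b + q == UInt32(65535)   # same result, via modular wraparound
```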
Actually, you can get an additional boost by exploiting the fact that q is a Fermat number. This is a little more involved, of course. The idea is the following:
We need to find a mod q, where a < q^2 and q = 2^(2^p) + 1 for some integer p.
a can be written in the form a = n*q + a', where a' = a mod q and n is some integer.
a can further be written as a = n*q + a' = n*(q - 1) + (n + a'), so n + a' = a mod (q - 1) = a & (2^(2^p) - 1); here we used the fact that mod 2^l in binary representation is just a bitwise AND with 2^l - 1.
n can be found as n = a div (q - 1) = a div 2^(2^p) = a >>> 2^p, where we have used the fact that in binary representation division by a power of 2 is the corresponding right shift.
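As a concrete standalone check of these identities (my own example, with the small Fermat prime q = 257 = 2^(2^3) + 1, i.e. p = 3):

```julia
q = 257
a = 60000                              # any a up to (q - 1)^2 works
@assert a & (q - 2) == mod(a, q - 1)   # low bits give a mod (q - 1)
@assert a >>> 2^3 == div(a, q - 1)     # shifting by 2^p gives a div (q - 1)
x = (a & (q - 2)) - (a >>> 2^3) + q    # (n + a') - n + q
x = x >= q ? x - q : x
@assert x == mod(a, q)                 # recovers a mod q with no division
```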
Together it gives us this nice function
function fermat_mod(a::T, q::T) where T <: Integer
    x = a & (q - T(2)) - a >>> trailing_zeros(q - T(1)) + q
    return x >= q ? x - q : x
end
In this function it is assumed that q is a Fermat number and that a <= (q - 1)^2, the largest product of two reduced residues, which is exactly the range that occurs in the transform.
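To convince yourself, here is an exhaustive standalone check over that range for the small Fermat prime 257 (repeating the definition above so the snippet runs on its own):

```julia
function fermat_mod(a::T, q::T) where T <: Integer
    x = a & (q - T(2)) - a >>> trailing_zeros(q - T(1)) + q
    return x >= q ? x - q : x
end
q = 257   # small Fermat prime, so checking every a is cheap
@assert all(fermat_mod(a, q) == mod(a, q) for a in 0:(q - 1)^2)
```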
Inlining it into the fnt! function yields
function fnt2!(x::Array{T, 1}, g::T, q::T) where {T<:Integer}
    N = length(x)
    @assert ispow2(N)
    @assert isfermat(q)
    radix2sort!(x)
    logN = log2(N) |> ceil |> T
    for M in 2 .^ [0:logN-1;] # TODO: not very readable
        interval = 2M
        p = div(N, interval)
        gp = powermod(g, p, q)
        W = 1
        for m in 1:M
            for i in m:interval:N
                j = i + M
                Wxj = W * x[j]
                Wxj = Wxj & (q - T(2)) - Wxj >>> trailing_zeros(q - T(1)) + q
                Wxj = Wxj >= q ? Wxj - q : Wxj
                x[i], x[j] = x[i] + Wxj, x[i] - Wxj + q
                x[i] = x[i] >= q ? x[i] - q : x[i]
                x[j] = x[j] >= q ? x[j] - q : x[j]
            end
            W = mod(W*gp, q)
        end
    end
    return x
end
Sorry for the confusion, I was typing on my phone while tired. I’m having a slightly hard time explaining what I mean, but https://gmplib.org/~tege/divcnst-pldi94.pdf does a good job of it. The TLDR is the group I was talking about was the Int64s under multiplication. Since modulo can be computed cheaply once you have division, this should speed up the repeated mod q operations.
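A sketch of that idea in Lemire's "fastmod" formulation of the Granlund-Montgomery technique (my own example, not from the thread): replace a % q by two multiplications with a magic constant precomputed once per q. This variant is valid for 32-bit a and q, with q not a power of two.

```julia
const q32 = UInt32(65537)
const magic = typemax(UInt64) ÷ q32 + 1   # ceil(2^64 / q), computed once per q
# lowbits = magic * a mod 2^64; then mulhi(lowbits, q) recovers a mod q:
fastmod(a::UInt32) = UInt32((UInt128(magic * UInt64(a)) * q32) >> 64)
# Spot-check against the ordinary remainder:
@assert all(fastmod(UInt32(a)) == a % 65537 for a in 0:1_000_000)
```

Note that, as the final post below says, this only helps for machine-word-sized integers, since the magic-constant trick needs a fixed word width.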
Thanks for explaining. I gave FFT just as an example of an algorithm that should have similar complexity, to see how slow we are with FNT. I benchmarked the whole plan creation and preallocation and it's 211 μs, so I understand it will be faster if we have a lot of vectors to transform, but you are surely right here.
I applied the suggestions from @tkf again, so now fnt() runs in 382 μs (it was 451 μs before). I think this is an amazing result for a pure Julia implementation.
BTW, are you aware of any pure Julia FFT implementations? Maybe I should compare timings also with them.
Thank you for sharing, it looks very interesting, but @Skoffer eliminated all mod operations except W = mod(W * gp, q), which is executed just a few times. I will try to use it anyway on the original source I posted.
UPDATE: as I promised, I did testing with those multiplicative inverses. I also made some further improvements, which gave:
Also, this multiplicative-inverse trick works only for machine-word-sized integers. This FNT approach is really only useful with BigInts, to get very high precision in the number representation, so this is a deal breaker for me.