Performance advice needed

Hi,

I am working on a function whose calculation has a structure similar to the FFT. I'm a bit disappointed: I hoped it would be about 10x slower than FFTW, but benchmarks show it's about 100x slower. I tried adding @simd and @inbounds, but they have minimal effect. Could you tell me if there is something obviously wrong? I'm not asking for an analysis of the algorithm, just about the code quality. I'm not interested in adding multithreading, because this package is just a part of a bigger thing, and parallelisation will potentially be done at a higher level.

The code is in the fnt! function.

https://github.com/jakubwro/NumberTheoreticTransforms.jl/blob/ed0e99694da5a64a47ff301a9e9c8ea0ae2737d7/src/fnt.jl#L56

Example benchmark code:

using NumberTheoreticTransforms, FFTW

const x = mod.(rand(Int, 4096), 65537);
@btime fnt($x, $169, $65537); # 6.306 ms (4 allocations: 32.42 KiB)
@btime fft($x); #61.354 μs (53 allocations: 131.02 KiB)

const x2 = mod.(rand(Int, 8192), 65537);
@btime fnt($x2, $225, $65537); #15.061 ms (4 allocations: 64.45 KiB)
@btime fft($x2); #130.380 μs (53 allocations: 259.02 KiB)
for M in 2 .^ [0:logN-1;]

This line is doubly inefficient: it first constructs the range 0:logN-1, which is cheap and fine, but then it collects that into a new vector, and then it computes yet another new vector of 2 .^ .... That's two extra vector allocations that you don't need.

Instead, you could just iterate with something like:

for i in 0:logN-1
    M = 2^i
    ...
end
1 Like

No performance change, since this vector is really short. But thanks, I didn't like this line anyway :slight_smile:

Did you try loading and storing x[i] and x[j] only once per iteration (by assigning them to local variables)?
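
For example, the butterfly body could look roughly like this (just a sketch of the idea, not the actual code from the package):

xi = x[i]                   # load each element once
xj = x[j]
Wxj = mod(W * xj, q)
x[i] = mod(xi + Wxj, q)     # store each element once
x[j] = mod(xi - Wxj + q, q)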

1 Like

Also, can’t powermod be pulled out of the inner loop?

2 Likes

@tkf is right, powermod should be avoided. The mod operation itself is rather costly, and powermod is even more expensive.
You can remove powermod from the inner loop and calculate it iteratively, which gives the necessary speedup.

function fnt2!(x::Array{T, 1}, g::T, q::T) where {T<:Integer}
    N = length(x)
    @assert ispow2(N)
    @assert isfermat(q)

    radix2sort!(x)

    logN = log2(N) |> ceil |> T

    for M in 2 .^ [0:logN-1;] # TODO: not very readable
        interval = 2M
        p = div(N, interval)
        gp = powermod(g, p, q)  # one powermod per stage; the inner loop updates W iteratively
        W = 1
        for m in 1:M
            for i in m:interval:N
               j = i + M
               Wxj = W * x[j]
               x[i], x[j] = x[i] + Wxj, x[i] - Wxj
               x[i] = mod(x[i], q)
               x[j] = mod(x[j], q)
            end
            W = mod(W*gp, q)
        end
    end

    return x
end

function fnt2(x::Array{T}, g::T, q::T) where {T<:Integer}
    return fnt2!(copy(x), g, q)
end

# Sanity check
const x = mod.(rand(Int, 4096), 65537);
all(fnt2(x, 169, 65537) .== fnt(x, 169, 65537)) # true

@btime fnt($x, $169, $65537) # 6.923 ms (4 allocations: 32.42 KiB)
@btime fnt2($x, $169, $65537) # 431.841 μs (4 allocations: 32.42 KiB)
6 Likes

You can get an additional speedup by turning the two mod operations in the inner loop into one:

function fnt3!(x::Array{T, 1}, g::T, q::T) where {T<:Integer}
    N = length(x)
    @assert ispow2(N)
    @assert isfermat(q)

    radix2sort!(x)

    logN = log2(N) |> ceil |> T

    for M in 2 .^ [0:logN-1;] # TODO: not very readable
        interval = 2M
        p = div(N, interval)
        gp = powermod(g, p, q)
        W = 1
        for m in 1:M
            for i in m:interval:N
               j = i + M
               Wxj = mod(W * x[j], q)                     # the only mod in the inner loop
               x[i], x[j] = x[i] + Wxj, x[i] - Wxj + q    # both results stay in [0, 2q)
               x[i] = x[i] >= q ? x[i] - q : x[i]         # conditional subtract instead of mod
               x[j] = x[j] >= q ? x[j] - q : x[j]
            end
            W = mod(W*gp, q)
        end
    end

    return x
end

function fnt3(x::Array{T}, g::T, q::T) where {T<:Integer}
    return fnt3!(copy(x), g, q)
end

# Sanity check
const x = mod.(rand(Int, 4096), 65537)
all(fnt3(x, 169, 65537) .== fnt(x, 169, 65537)) # true

@btime fnt3($x, $169, $65537) # 245.112 μs (4 allocations: 32.42 KiB)

EDIT: Slightly changed the definition of x[j]. It doesn't affect performance, but the initial version introduced a bug for unsigned integers, because

mod(UInt16(7) - UInt16(9), UInt16(11))  # 0x0007
mod(UInt16(7) - UInt16(9) + UInt16(11), UInt16(11)) # 0x0009

# and also
UInt16(7) - UInt16(9) < 0 # false
6 Likes

The next performance improvement would be to calculate q^-1 (in the group sense), so your mod operations can be done as x[j]-x[j]q^-1.

3 Likes

@Skoffer, it’s brilliant, all unit tests passed.

julia> @btime fnt($x2, $225, $65537); #15.061 ms (4 allocations: 64.45 KiB)
  644.937 μs (2 allocations: 64.08 KiB)

@tkf, after storing x[i] and x[j] in variables

julia> @btime fnt($x2, $225, $65537); #15.061 ms (4 allocations: 64.45 KiB)
  593.164 μs (2 allocations: 64.08 KiB)

I also added @inbounds to outer loop:

julia> @btime fnt($x2, $225, $65537); #15.061 ms (4 allocations: 64.45 KiB)
  574.171 μs (2 allocations: 64.08 KiB)

Adding @simd has no effect, probably because my laptop is from 2013. I will check tomorrow on a desktop.

Summing up: a 26x speedup, and we're now only about 5x slower than FFTW. Amazing, I had no idea that mod operations were so costly.

I assume I can commit your suggestions to my MIT-licensed lib, is that fine?

2 Likes

@Oscar_Smith, I am not sure I understand. q is the modulus, so q == 0 mod q, so it has no inverse.

Actually, you can get an additional boost by exploiting the fact that q is a Fermat number. This is a little more involved, of course. The idea is the following:

  1. We need to find a mod q, where a < q^2 and q = 2^(2^p) + 1 for some integer p.
  2. a can be written in the form a = n*q + a', where a' = a mod q and n is some integer.
  3. a can further be written as a = n*q + a' = n*(q - 1) + n + a', so n + a' = a mod (q - 1) = a & (2^(2^p) - 1); here we used the fact that taking a value mod 2^l in binary representation is just a logical AND with 2^l - 1.
  4. n can be found as a div (q - 1) = a div 2^(2^p) = a >>> 2^p, where we used the fact that in binary representation, division by a power of 2 is the corresponding right shift.

Putting it together gives us this nice function:

function fermat_mod(a::T, q::T) where T <: Integer
    # lo = a & (q - 2), hi = a >>> 2^p; since q - 1 = 2^(2^p) ≡ -1 (mod q), a ≡ lo - hi (mod q)
    x = a & (q - T(2)) - a >>> trailing_zeros(q - T(1)) + q  # the + q keeps x nonnegative
    x = x >= q ? x - q : x
end

In this function it is assumed that q is a Fermat number and a < q^2.
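
For example, one can sanity-check it against Base.mod on the kind of values that show up in the inner loop (products of two residues below q); this is just an illustrative check:

# Sanity check of fermat_mod (illustrative), q = 2^16 + 1
q = 65537
as = [rand(0:q-1) * rand(0:q-1) for _ in 1:10^5]   # values like W * x[j] in the loop
all(fermat_mod(a, q) == mod(a, q) for a in as)     # true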

Adding it to the fnt! function yields

function fnt2!(x::Array{T, 1}, g::T, q::T) where {T<:Integer}
    N = length(x)
    @assert ispow2(N)
    @assert isfermat(q)

    radix2sort!(x)

    logN = log2(N) |> ceil |> T

    for M in 2 .^ [0:logN-1;] # TODO: not very readable
        interval = 2M
        p = div(N, interval)
        gp = powermod(g, p, q)
        W = 1
        for m in 1:M
            for i in m:interval:N
               j = i + M
               Wxj = W * x[j]
               Wxj = Wxj & (q - T(2)) - Wxj >>> trailing_zeros(q - T(1)) + q
               Wxj = Wxj >= q ? Wxj + q : Wxj

               x[i], x[j] = x[i] + Wxj, x[i] - Wxj + q
               x[i] = x[i] >= q ? x[i] - q : x[i]
               x[j] = x[j] >= q ? x[j] - q : x[j]
            end
            W = mod(W*gp, q)
        end
    end

    return x
end

And the benchmark:

@btime fnt2($x, $169, $65537)  # 176.820 μs (4 allocations: 32.42 KiB) 
7 Likes

Excellent! You will beat FFTW by tomorrow :slight_smile:

@btime fnt($x2, $225, $65537);
600.204 μs (4 allocations: 64.45 KiB)

@btime fnt3($x2, $225, $65537);
451.002 μs (4 allocations: 64.45 KiB)

@btime fft($x2); #130.380 μs (53 allocations: 259.02 KiB)
130.582 μs (53 allocations: 259.02 KiB)

451/130 # fnt / fft
3.4692307692307693

Your reasoning is correct, but there was a minor typo that caused errors for inputs close to q.

Wxj = Wxj >= q ? Wxj + q : Wxj   # original
Wxj = Wxj >= q ? Wxj - q : Wxj   # corrected

This is performance I could not even imagine when writing original post.

5 Likes

Sorry for the confusion, I was typing on my phone while tired. I'm having a hard time explaining what I mean, but https://gmplib.org/~tege/divcnst-pldi94.pdf does a good job of it. The TL;DR is that the group I was talking about was the Int64s under multiplication. Since modulo can be computed cheaply once you have division, this should speed up the repeated mod q operations.

1 Like

Note that Julia already has something like this in https://github.com/JuliaLang/julia/blob/master/base/multinverses.jl.

julia> smi = Base.multiplicativeinverse(3432)
Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}(3432, 5503923639708211205, 0, 0x0a)

julia> div(21313221, smi)
6210

julia> div(21313221, 3432)
6210
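
The remainder can then be built from that fast division, e.g. (a small sketch, for nonnegative machine-word integers):

qi = Base.multiplicativeinverse(65537)
fastmod(a, qi) = a - div(a, qi) * qi.divisor   # a mod q via the multiply/shift division
fastmod(21313221, qi) == mod(21313221, 65537)  # true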
3 Likes

Good to know, that makes this much easier to actually use.

Note that by calling fft(x2) you are not really seeing FFTW’s full speed. Try:

p = plan_fft(x2, flags=FFTW.PATIENT)
@btime $p * $x2

or with pre-allocated output and pre-allocated conversion of x2 to the type that FFTW uses:

x2c = ComplexF64.(x2); y2c = copy(x2c);
@btime mul!($y2c, $p, $x2c);

or taking advantage of the fact that the input is real:

pr = plan_rfft(x2, flags=FFTW.PATIENT);
@btime $pr * $x2;

plus pre-allocated input/output:

x2f = Float64.(x2); y2cr = Array{ComplexF64}(undef, length(x2f)÷2+1);
@btime mul!($y2cr, $pr, $x2f);

With all of the tricks, for the final result I get about a factor of 5–8 faster than fft(x2).

4 Likes

Thanks for explaining. I gave FFT just as an example of an algorithm that should have similar complexity, to see how slow we are with FNT. I benchmarked this whole plan creation and preallocation and it's 211 μs, so I understand it will be faster when we have a lot of vectors to transform, but you are surely right here.
I applied the suggestions from @tkf again, so now fnt() runs in 382 μs (it was 451 μs before). I think this is an amazing result for a pure Julia implementation.
BTW, are you aware of any pure Julia FFT implementations? Maybe I should also compare timings with them.

Thank you for sharing, it looks very interesting, but @Skoffer eliminated all mod operations except W = mod(W * gp, q), which is executed just a few times. I will try to use it anyway on the original source I posted.

UPDATE: as I promised, I did testing with those multiplicative inverses. I also did some further improvements, which gave:

  • 349.411 μs for @Skoffer’s solution
  • 385.504 μs for @Oscar_Smith’s solution
  • 596.341 μs for using Base.mod

Also, this multiplicative inverse works just for machine-word-sized integers. This FNT approach is really only useful with BigInts, to get very high precision in the number representation, so this is a deal breaker for me.