I’m trying to perform a convolution using CUDA FFTs, and something puzzles me when comparing the CPU and GPU results:
using Adapt, CUDA, FFTW, AbstractFFTs
T = Float32
sz = (2048,2048)
x = randn(T,sz);
x_gpu = adapt(CuArray,x);
d = randn(Complex{T},(1025,2048))  # half-spectrum size along dim 1: 2048 ÷ 2 + 1 = 1025
d_gpu = adapt(CuArray,d)
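For reference, d is laid out like the output of rfft on a real 2048×2048 array, which keeps ⌊n/2⌋ + 1 rows along the first dimension. That row count is the same for n = 2048 and n = 2049, which is why brfft takes the output length as an explicit argument. The one structural difference between the two cases is that for even n the last stored row is the Nyquist frequency n/2, while for odd n there is no Nyquist row:

$$
\left\lfloor \frac{2048}{2} \right\rfloor + 1 \;=\; 1025 \;=\; \left\lfloor \frac{2049}{2} \right\rfloor + 1
$$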
The GPU and CPU outputs differ when the backward transform brfft produces a (2048,2048) output, but they agree when it produces a (2049,2048) output:
julia> adapt(Array, brfft(d_gpu .* rfft(x_gpu), 2048)) ≈ brfft(d .* rfft(x),2048)
false
julia> adapt(Array, brfft(d_gpu .* rfft(x_gpu), 2049)) ≈ brfft(d .* rfft(x),2049)
true
julia> adapt(Array, brfft(rfft(x_gpu), 2048)) ≈ brfft(rfft(x),2048)
true
julia> adapt(Array, d_gpu) ≈ d
true
julia> adapt(Array, d_gpu .* rfft(x_gpu)) ≈ (d .* rfft(x))
true
And it is not just rounding error; the differences are of order 10^5:
julia> adapt(Array, brfft(d_gpu .* rfft(x_gpu), 2048)) .- brfft(d .* rfft(x),2048)
2048×2048 Matrix{Float32}:
-8465.38 13274.8 -119964.0 1.24876f5 7778.25 -1.24958f5 … 33341.4 18242.2 -75873.0 -30010.5 -1052.38 228162.0 -1.99176f5
-1.00414f5 -51067.1 1.1154f5 -54599.9 25838.5 -96901.0 -18502.0 26938.0 -22623.0 -31004.0 -36856.9 -5554.5 48698.5
-8464.75 13277.2 -1.19968f5 1.24872f5 7778.88 -124956.0 33340.8 18242.8 -75874.0 -30015.0 -1050.0 2.28162f5 -1.99173f5
-1.00412f5 -51064.5 1.11541f5 -54601.5 25842.0 -96899.5 -18504.0 26935.0 -22622.5 -31004.0 -36852.8 -5557.5 48695.5
-8466.75 13277.8 -1.19967f5 1.24874f5 7777.0 -1.24954f5 33341.5 18242.5 -75875.1 -30013.0 -1050.0 2.2816f5 -199175.0
-1.00414f5 -51066.6 111541.0 -54601.5 25840.0 -96897.4 … -18502.2 26937.6 -22622.8 -31005.4 -36857.1 -5553.75 48695.0
⋮ ⋮ ⋱ ⋮
-8465.0 13274.6 -1.19964f5 1.24875f5 7781.88 -1.24954f5 33339.8 18240.2 -75877.8 -30018.4 -1054.5 2.2816f5 -1.99176f5
-100414.0 -51064.0 1.11544f5 -54599.1 25837.2 -96902.5 -18507.0 26936.2 -22623.2 -31008.0 -36857.5 -5552.5 48692.8
-8463.5 13273.8 -1.19966f5 1.24875f5 7776.12 -1.24952f5 33341.0 18241.5 -75874.0 -30012.0 -1052.75 2.28162f5 -199176.0
-1.00414f5 -51064.0 1.11542f5 -54601.0 25840.9 -96903.0 … -18503.0 26940.1 -22624.0 -31004.8 -36853.8 -5556.0 48693.8
-8462.0 13276.5 -119968.0 1.24873f5 7778.75 -1.24956f5 33337.2 18241.0 -75875.8 -30011.4 -1052.75 2.28161f5 -1.99174f5
-100413.0 -51062.0 1.1154f5 -54601.5 25843.5 -96899.5 -18498.5 26935.5 -22623.8 -31006.8 -36856.8 -5554.5 48695.5
I can’t see where my issue is. Any hint?