For my Bachelor's thesis I need to solve a nonlinear wave equation with the pseudospectral method, which relies on a basic property of the Fourier transform:
\mathcal{F}\left(\frac{d^n f}{dx^n}\right) = (ik)^n \mathcal{F}(f)
where i is the imaginary unit and k is the wave number. I started by writing the code in Python, but I need 2^{12} points in space and there are two coupled PDEs, so I end up with 2^{13} coupled ODEs and the solver was very slow. While looking for ways to speed it up I found Julia, but no equivalent of the diff function from scipy.fftpack.
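For reference, the identity follows from differentiating the Fourier series term by term. This is only a short sketch for a function with period L; the symbols L, m and \hat{f}_m are introduced here for illustration and do not appear in the code.

```latex
f(x) = \sum_{m} \hat{f}_m \, e^{i k_m x}, \qquad k_m = \frac{2\pi m}{L}
\quad\Longrightarrow\quad
\frac{d^n f}{dx^n} = \sum_{m} (i k_m)^n \, \hat{f}_m \, e^{i k_m x}
```

So the m-th Fourier coefficient of the n-th derivative is (i k_m)^n \hat{f}_m. On a grid with an even number N of points, m runs over 0, ..., N/2-1, -N/2, ..., -1 in the usual FFT ordering, which is where the factor 2π/period in the wave numbers comes from.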
Here is my implementation of the diff function:
module DiffFFT
export diff, DiffPlan
using FFTW
@doc raw"""
    diff(v, n, period)

Differentiate a vector `v`, sampled on a uniform grid with periodic boundary conditions, using the identity
```math
F\left(\frac{\partial^n}{\partial x^n} v(x)\right) = (ik)^n F(v(x)), \qquad
\frac{\partial^n}{\partial x^n} v(x) = F^{-1}\left((ik)^n F(v(x))\right)
```
- `v`: vector of samples
- `n`: order of the derivative
- `period`: length of the periodic domain
"""
function diff(v, n, period)
    N = length(v)                     # number of points
    k = fftfreq(N, 2*pi/period)*N     # wave number
    d = ifft((1im*k).^n .* fft(v))    # calculate the derivative
    return real.(d)
end

struct DiffPlan
    N::Int64
    period::Float64
    k::Vector{Float64}
    fft::FFTW.cFFTWPlan
    ifft::FFTW.ScaledPlan
    function DiffPlan(N, period)
        new(N,
            period,
            fftfreq(N, 2π/period)*N,
            plan_fft(collect(Float64, 1:N), flags=FFTW.PATIENT),
            plan_ifft(fft(collect(Float64, 1:N)), flags=FFTW.PATIENT)
        )
    end
end

function diff(plan::DiffPlan, v, n)
    return real.(plan.ifft*((1im*plan.k).^n .* (plan.fft*v)))
end

function diff(d, plan::DiffPlan, v, n)
    d[:] = real.(plan.ifft*((1im*plan.k).^n .* (plan.fft*v)))
    nothing
end
end # module DiffFFT
The speed is in the same range as calling scipy’s diff from Julia. These are the initial conditions I used:
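using BenchmarkTools, FFTW

N = 2^12
period = 48π
# periodic grid: the right endpoint is dropped because it coincides with the left one
X = collect(range(-period/2, period/2, length=N+1))[1:N]
U = sech.(0.2 .* X).^2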
Taking the first derivative with both methods:
plan = DiffPlan(N, period)
@benchmark DiffFFT.diff(plan, U, 1)
# BenchmarkTools.Trial: 10000 samples with 1 evaluation.
# Range (min … max): 67.700 μs … 4.218 ms ┊ GC (min … max): 0.00% … 94.62%
# Time (median): 92.100 μs ┊ GC (median): 0.00%
# Time (mean ± σ): 147.815 μs ± 194.474 μs ┊ GC (mean ± σ): 10.39% ± 7.96%
# Memory estimate: 352.41 KiB, allocs estimate: 17.
fftpack = pyimport("scipy.fftpack")
@benchmark fftpack.diff(U, 1, period)
# BenchmarkTools.Trial: 10000 samples with 1 evaluation.
# Range (min … max): 84.600 μs … 11.058 ms ┊ GC (min … max): 0.00% … 51.68%
# Time (median): 117.300 μs ┊ GC (median): 0.00%
# Time (mean ± σ): 127.163 μs ± 251.314 μs ┊ GC (mean ± σ): 2.28% ± 1.15%
# Memory estimate: 33.67 KiB, allocs estimate: 42.
However, the FFT itself is roughly three times faster in Julia than in scipy, so I wonder whether my diff function could be made faster as well.
fftplan = plan_fft(U, flags=FFTW.PATIENT)
@benchmark fftplan*U
# BenchmarkTools.Trial: 10000 samples with 1 evaluation.
# Range (min … max): 21.600 μs … 4.617 ms ┊ GC (min … max): 0.00% … 97.69%
# Time (median): 30.400 μs ┊ GC (median): 0.00%
# Time (mean ± σ): 44.879 μs ± 119.376 μs ┊ GC (mean ± σ): 12.13% ± 5.00%
# Memory estimate: 128.09 KiB, allocs estimate: 4.
spfft = pyimport("scipy.fft")
@benchmark spfft.fft(U)
# BenchmarkTools.Trial: 10000 samples with 1 evaluation.
# Range (min … max): 66.700 μs … 12.680 ms ┊ GC (min … max): 0.00% … 17.22%
# Time (median): 108.700 μs ┊ GC (median): 0.00%
# Time (mean ± σ): 124.106 μs ± 295.472 μs ┊ GC (mean ± σ): 2.99% ± 1.32%
# Memory estimate: 65.48 KiB, allocs estimate: 35.
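Since U is real, one direction that might be worth trying is a real-input transform with preallocated buffers. The following is only a sketch under that assumption, not a measured improvement: `RealDiffPlan` and `spectral_diff!` are hypothetical names that are not part of the module above, and the plans use FFTW's default planning flags.

```julia
using FFTW
using LinearAlgebra: mul!

# Hypothetical real-input counterpart of DiffPlan (illustrative names).
# The plan types are type parameters, so every field has a concrete type.
struct RealDiffPlan{F,B}
    ik::Vector{ComplexF64}    # i*k for the N÷2+1 non-negative wave numbers
    buf::Vector{ComplexF64}   # work array holding the half spectrum
    fwd::F                    # real-to-complex plan
    bwd::B                    # complex-to-real plan (scaled inverse)
end

function RealDiffPlan(N::Integer, period::Real)
    x = zeros(Float64, N)                 # scratch input, only used for planning
    fwd = plan_rfft(x)
    buf = fwd * x                         # has length N÷2+1
    bwd = plan_irfft(buf, N)
    ik = im .* collect(rfftfreq(N, N)) .* (2π / period)
    return RealDiffPlan(ik, buf, fwd, bwd)
end

# Writes the n-th derivative of v into d, reusing the buffers stored in p.
function spectral_diff!(d::Vector{Float64}, p::RealDiffPlan, v::Vector{Float64}, n::Integer)
    mul!(p.buf, p.fwd, v)        # buf = rfft(v)
    p.buf .*= p.ik .^ n          # multiply each mode by (ik)^n
    mul!(d, p.bwd, p.buf)        # d = irfft(buf)
    return d
end

# Small check on a periodic grid: the derivative of sin should be cos.
N, period = 64, 2π
x = range(0, period, length=N+1)[1:N]    # endpoint excluded
v = sin.(x)
p = RealDiffPlan(N, period)
d = similar(v)
spectral_diff!(d, p, v, 1)
err = maximum(abs.(d .- cos.(x)))        # should be close to machine precision
```

The real transform only stores the non-negative wave numbers, so the spectrum is about half as long as with fft, and writing into `buf` and `d` avoids allocating new arrays on every call. Whether this is actually faster for N = 2^12 would need to be benchmarked.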
I am aware of the @code_warntype macro, and it does show some red Any types in my diff function that the compiler cannot infer, but I do not know enough Julia yet to fix them.
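One thing that can be checked here is whether the field types of a struct holding FFTW plans are concrete. The sketch below assumes that the red entries come from the plan fields; `LoosePlan` and `TypedPlan` are hypothetical stand-ins used only for illustration.

```julia
using FFTW

x = zeros(ComplexF64, 8)
p = plan_fft(x)

# FFTW.cFFTWPlan without its type parameters is not a concrete type,
# while the type of an actual plan object is.
loose_is_concrete = isconcretetype(FFTW.cFFTWPlan)
plan_is_concrete  = isconcretetype(typeof(p))

# Hypothetical struct with a field annotated like the one in DiffPlan.
struct LoosePlan
    fft::FFTW.cFFTWPlan
end

# Hypothetical parametric version: the field takes the concrete type of
# whatever plan is passed to the constructor.
struct TypedPlan{P}
    fft::P
end

lp = LoosePlan(p)
tp = TypedPlan(p)

loose_field_concrete = isconcretetype(fieldtype(typeof(lp), :fft))
typed_field_concrete = isconcretetype(fieldtype(typeof(tp), :fft))
```

If a field type is not concrete, the compiler cannot know which method `plan.fft * v` will call, and that kind of uncertainty is what @code_warntype highlights in red.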
The optimization is not crucial for my thesis, since the Python code works well. This is mostly out of curiosity and an opportunity for me to learn about Julia.