Dear all,
I am trying to write a basic script that extracts the wavenumbers of a 2D signal (defined in 'spatial coordinates') using FFTW.
To test the script, I define a signal whose wavenumbers in x and y are known in advance. I then perform a spectral analysis and expect it to return those predefined wavenumbers (one per spatial dimension).
So far my script fails to do this. Does anyone see an obvious error or misuse of FFTW?
Thanks and cheers!
```julia
using FFTW, Plots

# Define the 2D signal with known wavenumbers
Lx, Ly = 10, 10 # Domain size in x and y directions
Nx, Ny = 256, 256 # Number of samples in x and y directions
dx, dy = Lx / Nx, Ly / Ny # Spatial step sizes
# Coordinate arrays
x, y = LinRange(0, Lx, Nx), LinRange(0, Ly, Ny)
# Define the wavenumbers in the x and y directions
kx, ky = 2 * π * 10 / Lx, 2 * π * 2 / Ly
# Create a single plane wave cos(kx*x + ky*y) in the 2D space
signal = cos.(kx * x .+ ky * y')
# Compute the 2D FFT of the signal
fft_signal = fft(signal)
# Frequency bins in x and y directions
fx = fftfreq(Nx, dx)
fy = fftfreq(Ny, dy)
# Shift the FFT (move zero frequency to the center)
fft_shifted = fftshift(fft_signal)
# Magnitude
magnitude = abs.(fft_shifted)
# Find dominant frequencies
max_idx = argmax(magnitude)
# Convert Cartesian index to row and column
max_row, max_col = Tuple(max_idx)
# Get the corresponding (fx, fy) of the peak (frequencies)
peak_fx = fx[max_col]
peak_fy = fy[max_row]
# Wavenumber in x and y
detected_kx = 2 * π * peak_fx
detected_ky = 2 * π * peak_fy
# Print detected and input wavenumbers
println("Detected wavenumber in x: $detected_kx rad/unit")
println("Detected wavenumber in y: $detected_ky rad/unit")
println("Input wavenumber in x: $kx rad/unit")
println("Input wavenumber in y: $ky rad/unit")
p1 = heatmap(x, y, signal', xlabel="x", ylabel="y", title="Original 2D Signal", aspect_ratio=1)
p2 = heatmap(sort(fx), sort(fy), magnitude, xlabel="Frequency in x", ylabel="Frequency in y", title="2D Magnitude Spectrum of FFT", aspect_ratio=1)
plot(p1, p2)
```
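The issue is in the way the frequency bins are computed. I believe there are two mistakes. First, the second argument of `fftfreq` is the sampling rate `1/dx`, not the step `dx`. Second, the frequency vectors must be `fftshift`ed in the same way as the spectrum. The following should fix it:
```julia
# Frequency bins in x and y directions (sampling rate 1/dx, shifted like the spectrum)
fx = fftshift(fftfreq(Nx, 1/dx))
fy = fftshift(fftfreq(Ny, 1/dy))
```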
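For reference, the corrected bins are multiples of the inverse domain length. This worked equation assumes even $N$ and $L = N\,\Delta x$, as in the thread. Under those assumptions, `fftshift(fftfreq(N, 1/Δx))` returns

$$
f_m = \frac{m}{N\,\Delta x} = \frac{m}{L}, \qquad m = -\frac{N}{2}, \dots, \frac{N}{2} - 1,
\qquad k_m = 2\pi f_m = \frac{2\pi m}{L}.
$$

With $L_x = L_y = 10$ and $N = 256$, the bin width is $1/L = 0.1$ cycles/unit. So $k_x = 2\pi \cdot 10 / L_x$ lands exactly on $m = 10$ ($f = 1$ cycle/unit), and $k_y = 2\pi \cdot 2 / L_y$ lands on $m = 2$ ($f = 0.2$ cycles/unit).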
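You also need the factor 2π to convert the peak frequencies (cycles/unit) into wavenumbers (rad/unit). Apply it exactly once, to `peak_fx` and `peak_fy`:
```julia
# Print detected wavenumbers: 2π converts cycles/unit into rad/unit
println("Detected wavenumber in x: $(2*pi*peak_fx) rad/unit")
println("Detected wavenumber in y: $(2*pi*peak_fy) rad/unit")
```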
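Here is a minimal 1D sketch of both points. It assumes FFTW.jl is installed; FFTW.jl re-exports `fftfreq` and `fftshift` from AbstractFFTs.jl. The names `N`, `f0`, `s`, `S` and `detected_f` are illustrative and do not come from the thread. The frequency vector is built from the sampling rate `1/dx` and shifted exactly like the spectrum, and the 2π factor is applied once.

```julia
using FFTW

Lx, N = 10.0, 256
dx = Lx / N
x = (0:N-1) .* dx                 # periodic grid: x = Lx (same point as x = 0) is not sampled
f0 = 1.0                          # input frequency in cycles/unit, i.e. k0 = 2π f0 rad/unit
s = cos.(2π * f0 .* x)

f = fftshift(fftfreq(N, 1 / dx))  # second argument is the sampling rate 1/dx, not dx
S = abs.(fftshift(fft(s)))        # shifted the same way as f, so S[i] belongs to f[i]

i = argmax(S)                     # a real cosine has peaks at ±f0; argmax returns one of them
detected_f = abs(f[i])            # cycles/unit
detected_k = 2π * detected_f      # rad/unit: the 2π factor is applied exactly once
println("detected k = $detected_k rad/unit, input k = $(2π * f0) rad/unit")
```

Taking `abs` of the peak frequency sidesteps the ± ambiguity of a real cosine in 1D. In 2D, the sign of the whole wavevector is ambiguous in the same way.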
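Also, your grid `x, y` does not actually have the spacing `dx, dy` that you compute, as you can confirm with `diff(x)`. `LinRange(0, Lx, Nx)` includes both endpoints, so its step is `Lx/(Nx-1)`, not `Lx/Nx`.
I recommend you do:
```julia
x, y = LinRange(0, Lx, Nx+1)[1:Nx], LinRange(0, Ly, Ny+1)[1:Ny]
```
The resulting grid spacing then matches your:
```julia
dx, dy = Lx / Nx, Ly / Ny
```
The grid also has better endpoint properties for a periodic signal, because the point `x = Lx`, which duplicates `x = 0`, is no longer sampled.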
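The spacing problem is easy to check numerically. This small sketch uses plain Julia with no packages. The `spacing_error` helper is hypothetical and written only for this check. It compares three grids: the endpoint-inclusive grid from the question, the grid recommended above, and the cell-centred grid used in the updated code below.

```julia
Lx, Nx = 10.0, 256
dx = Lx / Nx

# Hypothetical helper: largest deviation of the grid steps from the intended dx
spacing_error(x) = maximum(abs.(diff(collect(x)) .- dx))

x_question = LinRange(0, Lx, Nx)                  # includes both endpoints: step Lx/(Nx-1)
x_endpoint = LinRange(0, Lx, Nx + 1)[1:Nx]        # drops x = Lx, which duplicates x = 0
x_centred  = LinRange(dx / 2, Lx - dx / 2, Nx)    # cell-centred samples, step Lx/Nx

for (name, grid) in (("question", x_question), ("endpoint dropped", x_endpoint), ("cell-centred", x_centred))
    println(rpad(name, 18), " max |diff(x) - dx| = ", spacing_error(grid))
end
```

For the first grid the error is `Lx/(Nx-1) - Lx/Nx`, about 1.5e-4 here. The other two agree with `dx` up to rounding, so either works for a periodic test signal.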
Thanks a lot! This is all working now. Here’s the updated code:
```julia
using FFTW, Plots

# Define the 2D signal with known wavenumbers
Lx, Ly = 10, 10 # Domain size in x and y directions (in arbitrary units)
Nx, Ny = 256, 256 # Number of samples in x and y directions
dx, dy = Lx / Nx, Ly / Ny # Spatial step sizes
# Cell-centred 1D grids in x and y: spacing is exactly dx, dy and x = Lx is not duplicated
x, y = LinRange(0+dx/2, Lx-dx/2, Nx), LinRange(0+dy/2, Ly-dy/2, Ny)
# Define the wavenumbers in the x and y directions
kx, ky = 2 * π * 10 / Lx, 2 * π * 2 / Ly
# Create a single plane wave cos(kx*x + ky*y); dimension 1 is x, dimension 2 is y
signal = cos.(kx * x .+ ky * y')
# Compute the 2D FFT of the signal using fft
fft_signal = fft(signal)
# Frequency bins (cycles/unit) in x and y: fftfreq takes the sampling rate 1/dx
fx = fftshift(fftfreq(Nx, 1/dx))
fy = fftshift(fftfreq(Ny, 1/dy))
# Shift the FFT output the same way (move zero frequency to the center)
fft_shifted = fftshift(fft_signal)
# Compute the magnitude of the 2D FFT
magnitude = abs.(fft_shifted)
# Find dominant frequencies
max_idx = argmax(magnitude)
# Row index ↔ x (dimension 1), column index ↔ y (dimension 2)
max_row, max_col = Tuple(max_idx)
# Get the corresponding (fx, fy) of the peak (frequencies)
peak_fx = fx[max_row]
peak_fy = fy[max_col]
# Wavenumber k = 2 * π * frequency
detected_kx = 2 * π * peak_fx # Wavenumber in the x-direction (rad/unit)
detected_ky = 2 * π * peak_fy # Wavenumber in the y-direction (rad/unit)
# Print detected and input wavenumbers
println("Detected wavenumber in x: $(detected_kx) rad/unit")
println("Detected wavenumber in y: $(detected_ky) rad/unit")
println("Input wavenumber in x: $kx rad/unit")
println("Input wavenumber in y: $ky rad/unit")
p1 = heatmap(x, y, signal', xlabel="x", ylabel="y", title="Original 2D Signal", aspect_ratio=1)
# Plots expects z with rows ↔ y, so transpose magnitude (rows ↔ x); fx and fy are already sorted
p2 = heatmap(fx, fy, magnitude', xlabel="Frequency in x (cycles/unit)", ylabel="Frequency in y (cycles/unit)", title="2D Magnitude Spectrum of FFT", aspect_ratio=1)
plot(p1, p2)
```
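One caveat with this approach is that a real cosine has two equally strong peaks, at (fx, fy) and (−fx, −fy). Which one `argmax` returns depends on floating-point rounding, so the detected wavenumbers can come out with both signs flipped. Also note that the row index maps to x because `signal` is built as `cos.(kx * x .+ ky * y')`. That is why the working version uses `fx[max_row]`.

Below is a minimal sketch of one way to remove the sign ambiguity. It is an alternative swapped in here, not what was used above. It uses `rfft`, which keeps only fx ≥ 0 along the first dimension, together with `rfftfreq`, both available through FFTW.jl. It assumes the x-component of the wave is neither 0 nor the Nyquist frequency. Otherwise both peaks fall in the kept half-plane.

```julia
using FFTW

# Same test signal as in the thread: cell-centred periodic grid, known wavenumbers
Lx, Ly = 10.0, 10.0
Nx, Ny = 256, 256
dx, dy = Lx / Nx, Ly / Ny
x = LinRange(dx / 2, Lx - dx / 2, Nx)
y = LinRange(dy / 2, Ly - dy / 2, Ny)
kx, ky = 2π * 10 / Lx, 2π * 2 / Ly
signal = cos.(kx .* x .+ ky .* y')       # dimension 1 ↔ x, dimension 2 ↔ y

# rfft keeps only fx ≥ 0 along dimension 1, so the real cosine shows one peak, not two
S = abs.(rfft(signal))                   # size (Nx ÷ 2 + 1, Ny)
fx = rfftfreq(Nx, 1 / dx)                # matches dimension 1 of the rfft output
fy = fftfreq(Ny, 1 / dy)                 # matches dimension 2 (unshifted ordering)

i, j = Tuple(argmax(S))                  # i indexes fx, j indexes fy
detected_kx = 2π * fx[i]
detected_ky = 2π * fy[j]
println("kx: detected $detected_kx, input $kx (rad/unit)")
println("ky: detected $detected_ky, input $ky (rad/unit)")
```

No `fftshift` is needed here, because `argmax` does not care about ordering as long as the frequency vectors match the unshifted layout of `rfft`. For plotting, you would still shift the second dimension.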