FFT convolution for the heat equation

Hello,

I am looking at the heat equation $\partial_t u(x,t) = \alpha\, \partial_x^2 u(x,t)$ on $\mathbb{R}$ with initial condition $u(x,0) = f(x)$.

The analytical solution is given by $u(x,t) = (G(\cdot,t) \star f)(x) = \int_{\mathbb{R}} G(x-y,t)\, f(y)\, dy$, where $G(x,t) = \frac{1}{\sqrt{4\pi \alpha t}} \exp\left(-\frac{x^2}{4 \alpha t}\right)$ is the heat kernel and $\star$ denotes convolution in $x$.

I would like to compute $u(x,t)$ by evaluating this convolution with the FFT. For comparison, I also solved the heat equation in the spectral domain; see the note by Steven Johnson (https://math.mit.edu/~stevenj/fft-deriv.pdf).
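For reference, the identities the code relies on can be written out as a short derivation. This is a sketch of the standard argument, not taken from the linked note, and it assumes: the Fourier convention $\hat g(k) = \int_{\mathbb{R}} g(x)\, e^{-ikx}\, dx$ with angular wavenumber $k$ (the convention under which $\hat G(k,t) = e^{-\alpha k^2 t}$), a uniform grid $x_n = x_0 + n\,\Delta x$, $n = 0,\dots,N-1$, of length $L = N\Delta x$, FFTW's normalization (the inverse DFT carries the $1/N$), and that $f$ and $u$ are negligible near the ends of the grid so that periodic wrap-around can be ignored. The kernel samples $g_j$ and wavenumbers $k_j$ must be stored in the DFT's own ordering, with the origin at index $0$.

```latex
\begin{aligned}
&\partial_t \hat u(k,t) = -\alpha k^2\, \hat u(k,t)
  \quad\Longrightarrow\quad
  \hat u(k,t) = e^{-\alpha k^2 t}\, \hat f(k) = \hat G(k,t)\, \hat f(k), \\[4pt]
&u(x_n,t) = \int_{\mathbb{R}} G(x_n - y, t)\, f(y)\, dy
  \;\approx\; \Delta x \sum_{m=0}^{N-1} g_{(n-m) \bmod N}\, f(x_m), \\[4pt]
&g_j = \begin{cases}
  G(j\,\Delta x,\, t), & 0 \le j < N/2, \\
  G\big((j-N)\,\Delta x,\, t\big), & N/2 \le j < N,
\end{cases} \\[4pt]
&\Delta x \sum_{m=0}^{N-1} g_{(n-m) \bmod N}\, f(x_m)
  = \Delta x \Big[\operatorname{IDFT}\big(\operatorname{DFT}(g) \odot \operatorname{DFT}(f)\big)\Big]_n, \\[4pt]
&k_j = \begin{cases}
  2\pi j / L, & 0 \le j < N/2, \\
  2\pi (j-N) / L, & N/2 \le j < N,
\end{cases}
\qquad
u(x_n,t) \approx \Big[\operatorname{IDFT}\big(\hat G(k_j,t)\, \operatorname{DFT}(f)_j\big)\Big]_n .
\end{aligned}
```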

The current result of the convolution with FFT does not diffuse the initial condition. I would greatly appreciate your help.

using FFTW, OrdinaryDiffEq, Plots, LaTeXStrings

# Set heat problem
α = 1.0
tf = 5.0
L = 50.0
N = 256
Δx = L/N
xgrid = collect(Δx*(0:1:N-1));

# Initial function
f(x) = exp(-(x-L/2)^2)

# Heat kernel and its exact Fourier transform
G(x,t) = 1/sqrt(4*π*α*t)*exp(-(x)^2/(4*α*t)) 
Ĝ(k,t) = exp(-α*k^2*t)

# Discrete initial condition
fd = zeros(N)
for (i, xi) in enumerate(xgrid)
   fd[i] =  f(xi)    
end
f̂d = fft(fd);

# Discrete heat kernel at t=1
G1d = zeros(N)

for (i, xi) in enumerate(xgrid)
   G1d[i] =  G(xi, 1.0)    
end

Ĝ1d = fft(G1d);

kgrid = fftfreq(N, Δx);

Ĝ1_true = map(k-> Ĝ(k,1.0), kgrid);

# Compute solution by convolution
û1d = Ĝ1d .* f̂d;
û1_true = Ĝ1_true .* f̂d;


# Numerical solution of the heat equation in spectral domain
params = Dict("N" => N,
              "L" => L,
              "uhat" => zeros(ComplexF64, N),
              "duhat" => zeros(ComplexF64, N),
              "plan" => plan_fft(zeros(N)),
              "α" => α)

function heat_equation_spectral!(duhat, uhat, p, t)
    α = p["α"]
    N = p["N"]
    L = p["L"]
    
    for k=0:N-1
        if k <= N÷2
            duhat[k+1] = α*(-(2*π*k/L)^2)*uhat[k+1]
        elseif k > N÷2
            duhat[k+1] = α*(-(2*π/L*(k-N))^2)*uhat[k+1]
        end
    end
end

prob_spectral = ODEProblem((duhat,uhat,p,t) ->heat_equation_spectral!(duhat,uhat,params, t), 
                   f̂d, (0.0, tf));

sol_spectral = solve(prob_spectral, Tsit5());

# Comparison of the results
plot(xgrid,real(ifft(sol_spectral(0.0))), label = L"t = 0.0", linewidth= 3)
plot!(xgrid,real(ifft(sol_spectral(1.0))), label = L"t = 1.0", linewidth= 3)
plot!(xgrid, real(ifft(û1d)), label = "Convolution with discrete heat kernel", linewidth= 3)
plot!(xgrid, real(ifft(û1_true)), label = "Convolution with heat kernel", linewidth= 3)

Please note that if we make the grid and the initial condition symmetric about the origin, then the code produces good results.

Thank you @rafael.guerra for your help. I have modified the code such that f and G are defined on [0, L] and centered about L/2 (we want f and G to be close to zero at the boundaries of the domain). But I am still getting the wrong result. Could you share your implementation?

Also, on your graph, we expect the black curve to overlap with the red and yellow curves. I am not sure how to get the correct wavenumbers at which to sample $\hat{G}$.
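A minimal sketch of how the wavenumbers could be generated, assuming FFTW.jl is installed (it re-exports `fftfreq` and `fftshift` from AbstractFFTs). The second argument of `fftfreq(n, fs)` is the sampling rate `fs = 1/Δx`, and the result is in cycles per unit length, so the angular wavenumbers that $\hat G(k,t) = e^{-\alpha k^2 t}$ expects are `2π .* fftfreq(N, 1/Δx)`. Passing `Δx` as the second argument, as in the code above, yields wavenumbers no larger than `Δx/2 ≈ 0.098`, for which $\hat G \approx 1$; that would be consistent with the black curve showing almost no diffusion. The names `ν`, `k`, `k_explicit` and `damping_*` are illustrative.

```julia
using FFTW  # assumes FFTW.jl is installed; it re-exports fftfreq and fftshift

# Grid parameters from the question
L = 50.0
N = 256
Δx = L / N

# fftfreq(n, fs): fs is the sampling rate 1/Δx; the result is in cycles per unit
# length, in DFT order [0, 1, …, N/2-1, -N/2, …, -1] .* fs/N
ν = fftfreq(N, 1 / Δx)

# Ĝ(k,t) = exp(-α k² t) expects angular wavenumbers, hence the factor 2π
k = 2π .* ν

# The same wavenumbers built by hand, as in the final code later in the thread
k_explicit = fftshift((2π / L) .* collect(-N÷2:N÷2-1))

# Damping of the highest mode at t = 1 with α = 1
α, t = 1.0, 1.0
damping_wrong = exp(-α * maximum(abs, fftfreq(N, Δx))^2 * t)  # fs = Δx passed by mistake
damping_right = exp(-α * maximum(abs, k)^2 * t)

println("max |k| = ", maximum(abs, k), "   (π/Δx = ", π / Δx, ")")
println("damping of highest mode: wrong = ", damping_wrong, ", right = ", damping_right)
```

With the mistaken sampling rate every mode is multiplied by at least about 0.99, so the initial condition is essentially unchanged.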

using FFTW, OrdinaryDiffEq, Plots, LaTeXStrings

# Set heat problem
α = 1.0
tf = 5.0
L = 50.0
N = 256
Δx = L/N
xgrid = collect(Δx*(0:1:N-1));

# Initial function
f(x) = exp(-(x-L/2)^2)

# Heat kernel and its exact Fourier transform
G(x,t) = 1/sqrt(4*π*α*t)*exp(-(x)^2/(4*α*t)) 
Ĝ(k,t) = exp(-α*k^2*t)

# Discrete initial condition
fd = zeros(N)
for (i, xi) in enumerate(xgrid)
   fd[i] =  f(xi)    
end
f̂d = fft(fd);

# Discrete heat kernel at t=1
G1d = zeros(N)

for (i, xi) in enumerate(xgrid)
   G1d[i] =  G(xi-L/2, 1.0)    
end

Ĝ1d = fft(G1d);

kgrid = fftfreq(N, Δx);

Ĝ1_true = map(k-> Ĝ(k,1.0), kgrid);

# Compute solution by convolution
û1d = Ĝ1d .* f̂d;
û1_true = Ĝ1_true .* f̂d;


# Numerical solution of the heat equation in spectral domain
params = Dict("N" => N,
              "L" => L,
              "uhat" => zeros(ComplexF64, N),
              "duhat" => zeros(ComplexF64, N),
              "plan" => plan_fft(zeros(N)),
              "α" => α)

function heat_equation_spectral!(duhat, uhat, p, t)
    α = p["α"]
    N = p["N"]
    L = p["L"]
    
    for k=0:N-1
        if k <= N÷2
            duhat[k+1] = α*(-(2*π*k/L)^2)*uhat[k+1]
        elseif k > N÷2
            duhat[k+1] = α*(-(2*π/L*(k-N))^2)*uhat[k+1]
        end
    end
end

prob_spectral = ODEProblem((duhat,uhat,p,t) ->heat_equation_spectral!(duhat,uhat,params, t), 
                   f̂d, (0.0, tf));

sol_spectral = solve(prob_spectral, Tsit5());

# Comparison of the results
plot(xgrid,real(ifft(sol_spectral(0.0))), label = L"t = 0.0", linewidth= 3)
plot!(xgrid,real(ifft(sol_spectral(1.0))), label = L"t = 1.0", linewidth= 3, color = :orange)
plot!(xgrid, real(ifft(û1d)), label = "Convolution with discrete heat kernel", linewidth= 2, color = :red)
plot!(xgrid, real(ifft(û1_true)), label = "Convolution with heat kernel", linewidth= 2, color = :black)

As requested, please find the edited code below. The main change was to make the spatial grid and the initial condition symmetric about $x = 0$ (instead of $x = L/2$).
I don't have time right now to look into the issue with the black curve.

Edited code
using FFTW

# Set heat problem
α = 1.0
tf =  5.0
L = 20.0
N = 256
Δx = L/N
xgrid = Δx*(-N÷2:N÷2-1)

# Initial function
f(x) = exp(-x^2)

# Heat kernel and its exact Fourier transform
G(x,t) = 1/sqrt(4*π*α*t)*exp(-x^2/(4*α*t)) 
Ĝ(k,t) = exp(-α*k^2*t)

# Discrete initial condition
fd = zeros(N)
for (i, xi) in enumerate(xgrid)
   fd[i] =  f(xi)    
end
f̂d = fft(fd)*Δx

# Discrete heat kernel at t = td
td = 2.0
G1d = zeros(N)
for (i, xi) in enumerate(xgrid)
   G1d[i] =  G(xi, td)    
end
Ĝ1d = fft(G1d)*Δx

kgrid = fftfreq(N, Δx)

Ĝ1_true = Ĝ.(kgrid,td)

# Compute solution by convolution
û1d = Ĝ1d .* f̂d
û1_true = Ĝ1_true .* f̂d


# Numerical solution of the heat equation in spectral domain
using OrdinaryDiffEq

params = Dict("N" => N,
              "L" => L,
              "uhat" => zeros(ComplexF64, N),
              "duhat" => zeros(ComplexF64, N),
              "plan" => plan_fft(zeros(N)),
              "α" => α)

function heat_equation_spectral!(duhat, uhat, p, t)
    α = p["α"]
    N = p["N"]
    L = p["L"]
    
    for k=0:N-1
        if k <= N÷2
            duhat[k+1] = α*(-(2*π*k/L)^2)*uhat[k+1]
        elseif k > N÷2
            duhat[k+1] = α*(-(2*π/L*(k-N))^2)*uhat[k+1]
        end
    end
end

prob_spectral = ODEProblem((duhat,uhat,p,t) ->heat_equation_spectral!(duhat,uhat,params, t), f̂d, (0.0, tf));

sol_spectral = solve(prob_spectral, Tsit5());

# Comparison of the results
using Plots, LaTeXStrings; gr(dpi=600)
plot(xgrid,real(ifft(sol_spectral(0.0))), label = L"t = 0.0", lw=5)
plot!(xgrid,real(ifft(sol_spectral(td))), label = L"t = %$td", lw=5)
plot!(xgrid, real(fftshift(ifft((û1d)))), label = "Convolution with discrete heat kernel", lw=2, lc=:yellow)
plot!(xgrid, real(ifft(û1_true)), label = "Convolution with heat kernel", lw=2, lc=:black)
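The following sketch (assuming FFTW.jl is installed; the names `Gc`, `i0`, `u_shift_output` and `u_shift_kernel` are illustrative) shows why the `fftshift` applied to `ifft(û1d)` is needed. On the centered grid the kernel's peak $G(0,t)$ sits at index `N÷2 + 1`, whereas circular convolution through the DFT treats index 1 as the origin, so the result comes out circularly shifted by `N÷2`. Shifting the output with `fftshift`, or equivalently moving the kernel's origin to index 1 with `ifftshift` before transforming, removes that offset. Here a single factor `Δx` is applied so that the sum approximates the convolution integral.

```julia
using FFTW  # assumes FFTW.jl is installed

# Parameters from the edited code
α, td = 1.0, 2.0
L, N = 20.0, 256
Δx = L / N
xgrid = Δx .* (-N÷2:N÷2-1)   # centered grid: x = 0 is at index N÷2 + 1

G(x, t) = 1 / sqrt(4π * α * t) * exp(-x^2 / (4α * t))
f(x) = exp(-x^2)

Gc = G.(xgrid, td)   # kernel sampled on the centered grid
fd = f.(xgrid)

# The kernel's peak G(0, td) is at index N÷2 + 1, not at index 1 (the DFT origin)
i0 = argmax(Gc)

# Option 1 (as in the edited code): convolve, then undo the N÷2 circular shift
u_shift_output = fftshift(ifft(fft(Gc) .* fft(fd))) .* Δx

# Option 2: move the kernel's origin to index 1 first; the output needs no shift
u_shift_kernel = ifft(fft(ifftshift(Gc)) .* fft(fd)) .* Δx
```

For even `N`, `fftshift` and `ifftshift` are the same circular shift by `N÷2`, and circularly shifting the kernel shifts the circular convolution by the same amount, so both options agree to rounding error.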

Thanks a lot for your help. I found the solution to the second issue (the wavenumbers used to sample $\hat{G}$). It is still unclear to me why we need to use the function fftshift for the convolution with FFT. Also, the output of ifft should be scaled by $1/\Delta x$.

I attach the correct code:

using FFTW

# Set heat problem
α = 1.0
tf =  5.0
L = 20.0
N = 256
Δx = L/N
xgrid = Δx*(-N÷2:N÷2-1)

# Initial function
f(x) = exp(-x^2) + 1.0*exp(-(x-4)^2/2)

# Heat kernel and its exact Fourier transform
G(x,t) = 1/sqrt(4*π*α*t)*exp(-x^2/(4*α*t)) 
Ĝ(k,t) = exp(-α*k^2*t)

# Discrete initial condition
fd = zeros(N)
for (i, xi) in enumerate(xgrid)
   fd[i] =  f(xi)    
end
f̂d = fft(fd)*Δx

# Discrete heat kernel at t = td
td = 2.0
G1d = zeros(N)
for (i, xi) in enumerate(xgrid)
   G1d[i] =  G(xi, td)    
end
Ĝ1d = fft(G1d)*Δx

kgrid = (2*π/L)*collect(-N/2:1:N/2-1)
kgrid = fftshift(kgrid)

Ĝ1_true = Ĝ.(kgrid,td)

# Compute solution by convolution
û1d = Ĝ1d .* f̂d
û1_true = Ĝ1_true .* f̂d

u1d = fftshift(ifft(û1d))/Δx
u1_true = ifft(û1_true)/Δx

# Numerical solution of the heat equation in spectral domain
using OrdinaryDiffEq

params = Dict("N" => N,
              "L" => L,
              "uhat" => zeros(ComplexF64, N),
              "duhat" => zeros(ComplexF64, N),
              "plan" => plan_fft(zeros(N)),
              "α" => α)

function heat_equation_spectral!(duhat, uhat, p, t)
    α = p["α"]
    N = p["N"]
    L = p["L"]
    
    for k=0:N-1
        if k <= N÷2
            duhat[k+1] = α*(-(2*π*k/L)^2)*uhat[k+1]
        elseif k > N÷2
            duhat[k+1] = α*(-(2*π/L*(k-N))^2)*uhat[k+1]
        end
    end
end

prob_spectral = ODEProblem((duhat,uhat,p,t) ->heat_equation_spectral!(duhat,uhat,params, t), fft(f.(xgrid)), (0.0, tf));

sol_spectral = solve(prob_spectral, Tsit5());

# Comparison of the results
using Plots, LaTeXStrings; gr(dpi=600)
plot(xgrid, f.(xgrid), label = L"f(x)", lw=6)
plot!(xgrid,real(ifft(sol_spectral(0.0))), label = L"t = 0.0", lw=4)
plot!(xgrid,real(ifft(sol_spectral(td))), label = L"t = %$td", lw=6)
plot!(xgrid, real(u1d), label = "Convolution with discrete heat kernel", lw=4, lc=:yellow)
plot!(xgrid, real(u1_true), label = "Convolution with heat kernel", lw=2, lc=:black)
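As a sanity check on the scaling in the code above, one can compare against the closed-form solution for a single Gaussian initial condition $f(x) = e^{-x^2}$, namely $u(x,t) = (1+4\alpha t)^{-1/2} \exp\big(-x^2/(1+4\alpha t)\big)$. The sketch below assumes FFTW.jl is installed; it drops the second Gaussian from the final `f` only so that the exact solution stays simple, and it uses `2π .* fftfreq(N, 1/Δx)`, which gives the same values as the `fftshift`ed `kgrid` above. The names `Ghat`, `fhat`, `Ghat_d`, `u_exact` and `err_*` are illustrative. The remaining error should come mainly from periodic wrap-around at the ends of the grid.

```julia
using FFTW  # assumes FFTW.jl is installed

# Parameters as in the final code
α, td = 1.0, 2.0
L, N = 20.0, 256
Δx = L / N
xgrid = Δx .* (-N÷2:N÷2-1)

G(x, t) = 1 / sqrt(4π * α * t) * exp(-x^2 / (4α * t))
Ghat(k, t) = exp(-α * k^2 * t)
f(x) = exp(-x^2)   # single Gaussian, so the exact solution is known

# Exact solution of u_t = α u_xx on ℝ with u(x, 0) = exp(-x²)
u_exact(x, t) = exp(-x^2 / (1 + 4α * t)) / sqrt(1 + 4α * t)

fhat = fft(f.(xgrid)) .* Δx          # both transforms scaled by Δx, as in the final code
Ghat_d = fft(G.(xgrid, td)) .* Δx
kgrid = 2π .* fftfreq(N, 1 / Δx)      # same values as fftshift((2π/L)*(-N/2:N/2-1))

u1d = real(fftshift(ifft(Ghat_d .* fhat))) ./ Δx      # discrete kernel
u1_true = real(ifft(Ghat.(kgrid, td) .* fhat)) ./ Δx   # exact Fourier multiplier

uex = u_exact.(xgrid, td)
err_d = maximum(abs, u1d .- uex)
err_true = maximum(abs, u1_true .- uex)
println("max error, discrete kernel: ", err_d, "; exact multiplier: ", err_true)
```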

It is still unclear to me why we need to use the function fftshift for the convolution with FFT. Also the output of ifft should be scaled by $1/\Delta x$.

Both come down to general-case efficiency of the FFT algorithm. Normalizing would require an extra pass through the data array (one floating-point multiplication per element), and shifting would require extra data movement. Since users do not always need these steps, or can fold them into other operations, FFT and IFFT implementations typically leave them out.


The bfft function leaves out a 1/n normalization factor (which can indeed often be absorbed into other calculations), but ifft is normalized (at the cost of an extra pass over the array) so that ifft(fft(x)) ≈ x (i.e. it is a true inverse DFT).
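A quick check of this distinction, assuming FFTW.jl is installed (`bfft` is exported alongside `fft` and `ifft`); the data in `x` are arbitrary.

```julia
using FFTW  # assumes FFTW.jl is installed

x = [1.0, 2.0, 0.5, -3.0, 4.0, 0.0, 2.5, -1.0]   # arbitrary test data
n = length(x)
X = fft(x)

x_ifft = ifft(X)   # normalized inverse DFT: includes the 1/n factor
x_bfft = bfft(X)   # unnormalized backward DFT: no 1/n factor

# ifft is bfft followed by the extra pass that divides by n
println(real(x_ifft))
println(real(x_bfft) ./ n)
```

When the result is rescaled anyway (for example by a $\Delta x$ in a Riemann sum), using `bfft` and folding the $1/n$ into that scaling avoids the extra pass.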

fft computes a discrete Fourier transform (DFT). If you want to think of it as an approximation to a Fourier series you typically need some additional scaling, but that’s not part of the DFT. And the DFT origin is what it is; fftshift is mainly convenient for visualization, but numerical computations can always be arranged to use the “natural” DFT frequency ordering and symmetries without explicit shifts.
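Following this suggestion, here is a sketch in which the kernel is sampled directly in the natural DFT order, i.e. at signed distances $0, \Delta x, \dots, (N/2-1)\Delta x, -N\Delta x/2, \dots, -\Delta x$, so that its origin is at index 1 and no `fftshift` is needed. It assumes FFTW.jl is installed; the names `d`, `Gdft` and `u_exact` are illustrative, and the test function $f(x) = e^{-x^2}$ is chosen only because its exact solution is known. The grid for $f$ can be any uniform grid of length $L$; a single factor $\Delta x$ turns the circular sum into a Riemann sum for $\int G(x-y,t) f(y)\,dy$, which is why the earlier code needs $1/\Delta x$ after multiplying both transforms by $\Delta x$.

```julia
using FFTW  # assumes FFTW.jl is installed

α, t = 1.0, 2.0
L, N = 20.0, 256
Δx = L / N
xgrid = Δx .* (-N÷2:N÷2-1)   # grid for f; any uniform grid of length L works

G(x, t) = 1 / sqrt(4π * α * t) * exp(-x^2 / (4α * t))
f(x) = exp(-x^2)
u_exact(x, t) = exp(-x^2 / (1 + 4α * t)) / sqrt(1 + 4α * t)

# Signed distances in natural DFT order: [0, Δx, …, (N/2-1)Δx, -(N/2)Δx, …, -Δx].
# fftfreq(N) returns j/N in exactly this order, so multiplying by L gives j*Δx.
d = L .* fftfreq(N)

Gdft = G.(d, t)   # kernel with its origin (d = 0) at index 1: no fftshift needed
fd = f.(xgrid)

# One factor Δx turns the circular sum Σₘ G(xₙ - xₘ) f(xₘ) into a Riemann sum
u = real(ifft(fft(Gdft) .* fft(fd))) .* Δx

println("max error vs exact solution: ", maximum(abs, u .- u_exact.(xgrid, t)))
```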


FYI, this can be simplified to just

fd = f.(xgrid)
...
G1d = G.(xgrid, 1)
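A self-contained check that the broadcast form gives exactly the same array as the loop, using the definitions from the question (no packages needed; `fd_loop` is an illustrative name). The dot in `f.(xgrid)` calls `f` on each element and collects the results, and scalar arguments such as the time in `G.(xgrid, 1.0)` are passed unchanged to every call.

```julia
# Definitions copied from the question (no packages needed)
α = 1.0
L = 50.0
N = 256
Δx = L / N
xgrid = collect(Δx * (0:1:N-1))

f(x) = exp(-(x - L/2)^2)
G(x, t) = 1 / sqrt(4π * α * t) * exp(-x^2 / (4α * t))

# Loop version, as in the question
fd_loop = zeros(N)
for (i, xi) in enumerate(xgrid)
    fd_loop[i] = f(xi)
end

# Broadcast version: the dot calls f on each element and collects the results;
# scalar arguments such as the time 1.0 are passed unchanged to every call
fd = f.(xgrid)
G1d = G.(xgrid, 1.0)
```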