2D diffusion and Crank-Nicolson

I’m trying to solve 2D diffusion using the Crank-Nicolson method and seem to be screwing it up. For one, Ax = b should be solved with A\b, but that gives nonsensical results. Oddly, b\A seems to give something approximating what I’d expect. I don’t know whether I’m seeing the oscillations Crank-Nicolson is prone to (in which case: how do I get rid of them?), or whether I’m butchering this so badly that these answers come from some massive error in how I’ve implemented the theory. From iteration to iteration the solution seems to vacillate between the correct solution and infinitesimally small values (oscillations?), and it also tends to degenerate into matrices whose neighboring entries alternate in sign, which is obviously wrong for an approximation of smooth diffusion. I’m also pretty sure I butchered the no-flux boundary conditions I was going for :slight_smile:
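
For reference, Julia's backslash is not symmetric in its operands. `A \ b` solves `A * x = b`. With a vector on the left, `b \ A` is (in current LinearAlgebra) computed as `pinv(b) * A`, i.e. the least-squares `X` with `b * X ≈ A`, which works out to `(b' * A) / dot(b, b)`. A toy check with made-up numbers, assuming nothing beyond the standard library:

```julia
using LinearAlgebra

# Made-up 2×2 system, only to show what each operand order computes.
A = [4.0 1.0; 1.0 3.0]
b = [1.0, 2.0]

x = A \ b                  # matrix on the left: solves A * x = b
@assert A * x ≈ b          # x ≈ [1/11, 7/11]

y = b \ A                  # vector on the left: pinv(b) * A, a 1×2 row
@assert vec(y) ≈ vec((b' * A) ./ dot(b, b))
println("A \\ b = ", x, ",  b \\ A = ", vec(y))
```

Applied to the MWE, `vec(rhs(u, a_, c_)) \ bM` is therefore a row vector proportional to `bM' * b` divided by `dot(b, b)`, not the implicit Crank-Nicolson solve. Dividing by a norm that changes from frame to frame could plausibly account for frames alternating between reasonable and tiny values.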

Below is an MWE, visualized either as a high-res heatmap (Plots) or with UnicodePlots (convenient for a quick look).


using LinearAlgebra, SparseArrays


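const v0 = 0.0f0     # background value
const D = 0.01f0     # diffusion coefficient
const v1 = 100.0f0   # value inside the initial square

const h = 100        # grid points per side
const dh = 1/(h-1)   # grid spacing on the unit square [0, 1]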
u0 = fill(v0, (h, h));
u0[50:60,50:60] .= v1;   # a small square in the middle of the domain

dt = 0.001   # time step

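r  = D*dt/dh^2   # mesh ratio D*Δt/Δh^2
a  = 1 + 2r      # implicit (left-hand) diagonal: 1 + 4*(r/2)
c  = -r/2        # implicit off-diagonal
a_ = 1 - 2r      # explicit (right-hand) diagonal: 1 - 4*(r/2)
c_ = r/2         # explicit off-diagonal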

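A1 = fill(a, h-1)                # length h-1 vector of a
A2 = fill(c, h)                  # length h vector of c
C  = fill(c, h)
diag = Tridiagonal(A1, A2, A1)   # Tridiagonal(dl, d, du): sub-, main and super-diagonal
upper_lower = Diagonal(C)        # couples grid column j to columns j-1 and j+1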

# Block tri-diagonal matrix - sparse
μ = [spzeros(h, h) for i in 1:h, j in 1:h];   # h×h grid of h×h sparse blocks
μ[diagind(μ)] .= [diag];
μ[diagind(μ, -1)] .= [upper_lower];
μ[diagind(μ, 1)] .= [upper_lower'];
bM = reduce(vcat, [reduce(hcat, μ[i, :]) for i in 1:h]);

# Boundary conditions???
bM[2,1] = 2.0
bM[end-1,end] = 2.0
bM[1,2] = 2.0
bM[end,end-1] = 2.0

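# Explicit half of the update: a*u[i,j] + c*(sum of the neighbors that exist).
# At edges and corners the missing neighbors are simply left out of the sum.
function rhs(u, a, c)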
    n1, n2 = size(u)
    du = zeros(Float64, n1, n2)

    # internal nodes
    for j = 2:n2-1
        for i = 2:n1-1
            @inbounds  du[i,j] = a * u[i,j] + c * ( u[i+1,j] + u[i-1,j] + u[i,j+1] + u[i,j-1] )
        end
    end

    # left/right edges
    for i = 2:n1-1
        @inbounds du[i,1] = a * u[i,1] + c * ( u[i+1,1] + u[i-1,1] + u[i,2] )
        @inbounds du[i,n2] = a * u[i,n2] + c * ( u[i+1,n2] + u[i-1,n2] + u[i,n2-1] )
    end

    # top/bottom edges
    for j = 2:n2-1
        @inbounds du[1,j] = a * u[1,j] + c * ( u[1,j+1] + u[1,j-1] + u[2,j] )
        @inbounds du[n1,j] = a * u[n1,j] + c * ( u[n1,j+1] + u[n1,j-1] + u[n1-1,j] )
    end

    # corners
    @inbounds du[1,1] = a * u[1,1] + c * ( u[2,1] + u[1,2] )
    @inbounds du[n1,1] = a * u[n1,1] + c * ( u[n1-1,1] + u[n1,2] ) 
    @inbounds du[1,n2] = a * u[1,n2] + c * ( u[2,n2] + u[1,n2-1] ) 
    @inbounds du[n1,n2] = a * u[n1,n2] + c * ( u[n1-1,n2] + u[n1,n2-1] ) 

    return du
end

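du = Matrix{Float64}[]   # history of solution frames
push!(du, u0)
push!(du, reshape(vec(rhs(u0, a_, c_)) \ bM, (h,h)))
for i in range(1,100;step=dt)   # 99_001 steps; every frame is kept (about 8 GB at h = 100)
    push!(du, reshape(vec(rhs(du[end], a_, c_)) \ bM, (h,h)))
end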

#############################################################################################################

######
# High res plot
# using Plots

# ENV["GKSwstype"]="nul"

# n = size(du, 1)
# anim = @animate for t ∈ collect(range(2,n,step=2)) # 1:n#
#     heatmap(du[t], c = cgrad(:rainbow, scale=(0,100)), clim=(0,100));
# end

# mp4(anim, "./heatmap.mp4", fps=15)

##########################################################################################################################################
### Low res plot

using UnicodePlots

function move_up(s::AbstractString)
    move_up_n_lines(n) = "\u1b[$(n)F"
    # actually string_height - 1, but we're assuming cursor is on the last line
    string_height = length(collect(eachmatch(r"\n", s)))
    print(move_up_n_lines(string_height))
    nothing
end

function animate(frames; frame_delay = 0)
    print("\u001B[?25l") # hide cursor
    for frame in frames[1:end-1]
        print(frame)
        sleep(frame_delay)
        move_up(string(frame))
    end
    print(frames[end])
    print("\u001B[?25h") # visible cursor
    nothing
end

using ThreadsX
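# ThreadsX.map builds the frames in parallel; the original comprehension was evaluated serially before ThreadsX saw it
frames = ThreadsX.map(i -> heatmap(du[i], colorbar_lim=(0,100), ylabel=string(i)), range(1, length(du); step=100));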

animate(frames; frame_delay = 0)

#########################################################################################################################################
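
For comparison, here is a minimal Julia sketch of the scheme under assumptions of my own, not a drop-in fix of the MWE. It builds the implicit and explicit matrices with `kron` and `spdiagm` instead of assembling blocks, treats the no-flux boundary in a cell-centred way (a missing neighbor is dropped from the stencil and the diagonal shrinks to match, so every row of the Laplacian sums to zero and total mass is conserved), and solves each step as `A \ b` with a factorization computed once. The helper names `neumann_laplacian_1d`, `neumann_laplacian_2d` and `crank_nicolson` are hypothetical. A ghost-node treatment (u[0,j] = u[2,j]) is an equally valid alternative; in the MWE's notation it keeps the diagonal at a and doubles the coefficient of the interior neighbor to 2c in every boundary row. With the MWE's numbers, r = D*dt/dh^2 ≈ 0.098, so the amplification factor of the highest-frequency grid mode is at worst about (1 - 4r)/(1 + 4r) ≈ 0.44 > 0, and sign-alternating oscillations should not be expected from the method itself.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical helper: 1D second-difference matrix with zero-flux ends (cell-centred style).
# Interior rows are [1 -2 1]; the first and last rows are [-1 1] and [1 -1], so each
# row sums to zero and nothing leaves through the boundary.
function neumann_laplacian_1d(n)
    L = spdiagm(-1 => ones(n - 1), 0 => fill(-2.0, n), 1 => ones(n - 1))
    L[1, 1] = -1.0
    L[n, n] = -1.0
    return L
end

# Hypothetical helper: 2D five-point Laplacian acting on vec(u) of an n×n grid
# (column-major), assembled from the 1D operator with Kronecker products.
function neumann_laplacian_2d(n)
    L1 = neumann_laplacian_1d(n)
    Id = sparse(1.0I, n, n)
    return kron(Id, L1) + kron(L1, Id)
end

# Hypothetical driver: Crank-Nicolson for du/dt = D*Δu with r = D*dt/dh^2,
#   (I - r/2 L) u_new = (I + r/2 L) u_old.
# The left-hand matrix is factorized once and reused for every step.
function crank_nicolson(u0::AbstractMatrix, D, dt, dh, nsteps; save_every = 1)
    n = size(u0, 1)
    @assert size(u0, 2) == n "this sketch assumes a square grid"
    r = D * dt / dh^2
    L = neumann_laplacian_2d(n)
    Id = sparse(1.0I, n^2, n^2)
    Aimp = Id - (r / 2) * L            # implicit half (left-hand side)
    Aexp = Id + (r / 2) * L            # explicit half (right-hand side)
    F = cholesky(Symmetric(Aimp))      # Aimp is symmetric positive definite
    u = vec(Float64.(u0))
    frames = [Float64.(u0)]
    for k in 1:nsteps
        u = F \ (Aexp * u)             # matrix on the left: solves Aimp * u_new = Aexp * u_old
        if k % save_every == 0
            push!(frames, reshape(u, n, n))
        end
    end
    return frames
end

# Usage with the MWE's parameters (100×100 grid, D = 0.01, dt = 0.001), 1000 steps:
N = 100
dx = 1 / (N - 1)
v = zeros(N, N)
v[50:60, 50:60] .= 100.0
cn_frames = crank_nicolson(v, 0.01, 0.001, dx, 1000; save_every = 100)
println("total mass: ", sum(cn_frames[1]), " -> ", sum(cn_frames[end]))
```

Factorizing once with `cholesky` works because the implicit matrix never changes; with a non-symmetric boundary treatment such as the ghost-node one, `lu(Aimp)` can be used in the same place. Only every `save_every`-th frame is stored, which keeps memory bounded for long runs.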