Hi everyone.
I wrote a small program to solve the 2D linear visco-elastic Stokes equations using finite differences and OrdinaryDiffEq, to test how well that would work. In this case I only have algebraic equations, plus some time-dependent parameters that I update between solve steps (I know I could use a callback function for that, but I am doing it manually for now to have a bit more control).
So my whole system is expressed as an `ODEProblem` with a mass matrix that is zero everywhere. I know that I could use LinearSolve directly for this problem with a simple time loop, and that works well, but I plan to add other elliptic equations later. The current bottleneck is the `init()` stage, which takes a bit less than 1 minute for a 40x40 problem and grows dramatically with increasing resolution. The solving stage is actually quite fast.
My discretisation function is already allocation-free and I am providing the Jacobian sparsity, so I have run out of ideas for cutting down the `init()` time. Is there something I could try? I have also tried solving for the initial conditions with LinearSolve.jl to improve the initial guess, but that didn't change the `init()` time at all, so I assume the time is spent on something else. Is that a known limitation at the moment? I have tried `Rodas4P()` and `FBDF()`, but the init time didn't really change.
Here is the code, in case that helps. It is a bit long, sorry!
using OrdinaryDiffEq, LinearAlgebra
using ForwardDiff, SparseArrays
using ComponentArrays:ComponentArray
using PreallocationTools
using SparseArrays
import SparseConnectivityTracer, ADTypes
using LinearSolve
using NonlinearSolve
"""
    total_strainrate!(εvol, εxx, εzz, εxz, εII, Vx, Vz, dx, dz)

Fill the strain-rate arrays in place from the staggered velocities `Vx` and `Vz`.

With `(nx, nz) = size(εxx)`, the indexing below reads `Vx[1:nx+1, 1:nz]`,
`Vz[1:nx, 1:nz+1]` and `εxz[1:nx+1, 1:nz+1]`; `εvol`, `εxx`, `εzz` and `εII`
are written on `1:nx × 1:nz`.

- `εvol`: velocity divergence.
- `εxx`, `εzz`: normal components minus one third of the divergence.
- `εxz`: shear component; only indices `2:nx × 2:nz` are computed, every other
  entry is set to zero.
- `εII`: `sqrt(εxx²/2 + εzz²/2 + mean of εxz² over the 4 surrounding entries)`.

Returns `nothing`.
"""
function total_strainrate!(εvol, εxx, εzz, εxz, εII, Vx, Vz, dx, dz)
    ncx, ncz = size(εxx)

    # Every entry of εxz is assigned before it is read in the εII loop; the
    # entries outside 2:ncx × 2:ncz keep this zero value.
    fill!(εxz, 0.0)

    # Divergence and the two normal deviatoric components.
    for jc in 1:ncz
        for ic in 1:ncx
            ∂x_vx = (Vx[ic + 1, jc] - Vx[ic, jc]) / dx
            ∂z_vz = (Vz[ic, jc + 1] - Vz[ic, jc]) / dz
            div_v = ∂x_vx + ∂z_vz
            εvol[ic, jc] = div_v
            εxx[ic, jc] = ∂x_vx - div_v / 3.0
            εzz[ic, jc] = ∂z_vz - div_v / 3.0
        end
    end

    # Shear component on the interior entries only.
    for jv in 2:ncz
        for iv in 2:ncx
            ∂z_vx = (Vx[iv, jv] - Vx[iv, jv - 1]) / dz
            ∂x_vz = (Vz[iv, jv] - Vz[iv - 1, jv]) / dx
            εxz[iv, jv] = 0.5 * (∂z_vx + ∂x_vz)
        end
    end

    # Second invariant: the four εxz entries around (ic, jc) are squared and
    # averaged before being combined with the normal components.
    for jc in 1:ncz
        for ic in 1:ncx
            shear_sq = 0.25 * (
                εxz[ic, jc]^2 + εxz[ic + 1, jc]^2 +
                εxz[ic, jc + 1]^2 + εxz[ic + 1, jc + 1]^2
            )
            εII[ic, jc] = sqrt(0.5 * εxx[ic, jc]^2 + 0.5 * εzz[ic, jc]^2 + shear_sq)
        end
    end

    return nothing
end
"""
    total_stresses!(τxx, τzz, τxz, τII, εxx, εzz, εxz,
                    τxx_old, τzz_old, τxz_old, η, G, Δt)

Update the stress arrays in place from the strain rates and the stored
previous stresses.

With `(nx, nz) = size(εxx)`, the `xx`/`zz` arrays and `τII` are written on
`1:nx × 1:nz`, while `η`, `G`, `τxz`, `τxz_old` and `εxz` are indexed up to
`(nx+1, nz+1)`. Only the first element of `Δt` is read.

Each component uses `τ = (τ_old * χ + 2η ε) / (1 + χ)` with `χ = η / (G Δt[1])`.
For `τxx`/`τzz`, `η` and `G` are first averaged over the four surrounding
entries; for `τxz` they are used directly. `τxz` is computed on `2:nx × 2:nz`
and set to zero elsewhere. `τII` is
`sqrt(τxx²/2 + τzz²/2 + mean of τxz² over the 4 surrounding entries)`.

Returns `nothing`.
"""
function total_stresses!(τxx, τzz, τxz, τII, εxx, εzz, εxz, τxx_old, τzz_old, τxz_old, η, G, Δt)
    ncx, ncz = size(εxx)
    dt = Δt[1]

    # Assign every entry of τxz up front so the τII loop never reads an entry
    # that has not been written; entries outside 2:ncx × 2:ncz stay zero.
    fill!(τxz, 0.0)

    # Normal components, with η and G averaged over the four surrounding entries.
    for jc in 1:ncz
        for ic in 1:ncx
            η_c = 0.25 * (η[ic, jc] + η[ic + 1, jc] + η[ic, jc + 1] + η[ic + 1, jc + 1])
            G_c = 0.25 * (G[ic, jc] + G[ic + 1, jc] + G[ic, jc + 1] + G[ic + 1, jc + 1])
            χ_c = η_c / (G_c * dt)
            relax = 1.0 / (1.0 + χ_c)
            τxx[ic, jc] = (τxx_old[ic, jc] * χ_c + 2.0 * η_c * εxx[ic, jc]) * relax
            τzz[ic, jc] = (τzz_old[ic, jc] * χ_c + 2.0 * η_c * εzz[ic, jc]) * relax
        end
    end

    # Shear component on the interior entries, using η and G at the same index.
    for jv in 2:ncz
        for iv in 2:ncx
            χ_v = η[iv, jv] / (G[iv, jv] * dt)
            relax = 1.0 / (1.0 + χ_v)
            τxz[iv, jv] = (τxz_old[iv, jv] * χ_v + 2.0 * η[iv, jv] * εxz[iv, jv]) * relax
        end
    end

    # Second invariant from the normal components and the averaged squared shear.
    for jc in 1:ncz
        for ic in 1:ncx
            shear_sq = 0.25 * (
                τxz[ic, jc]^2 + τxz[ic + 1, jc]^2 + τxz[ic, jc + 1]^2 + τxz[ic + 1, jc + 1]^2
            )
            τII[ic, jc] = sqrt(0.5 * τxx[ic, jc]^2 + 0.5 * τzz[ic, jc]^2 + shear_sq)
        end
    end

    return nothing
end
"""
    Res!(du, u, p, t)

In-place residual of the discretised system, written into `du`.

`u` and `du` are accessed through the fields `P`, `Vx` and `Vz`. The parameter
object `p` supplies the scratch caches (`cache_*`, fetched with `get_tmp` so
their element type follows `u`), the grid spacings `Δ`, coordinates `x`/`z`,
boundary values `BC`, material arrays `η`, `G`, `ρ`, gravity `g`, the stored
previous stresses `τ*_old`, and the one-element vectors `Δt` and `P_shift`.
`t` is not used.

Side effect: the strain-rate, stress and `η_eff` caches are overwritten on
every call. Returns `nothing`.
"""
function Res!(du, u, p, t)
    # Scratch arrays matching the element type of `u`.
    εvol = get_tmp(p.cache_εvol, u)
    εxx = get_tmp(p.cache_εxx, u)
    εzz = get_tmp(p.cache_εzz, u)
    εxz = get_tmp(p.cache_εxz, u)
    εII = get_tmp(p.cache_εII, u)
    τxx = get_tmp(p.cache_τxx, u)
    τzz = get_tmp(p.cache_τzz, u)
    τxz = get_tmp(p.cache_τxz, u)
    τII = get_tmp(p.cache_τII, u)
    η_eff = get_tmp(p.cache_η_eff, u)

    # Unknowns, their residual slots, and the parameters used in the loops.
    P, Vx, Vz = u.P, u.Vx, u.Vz
    res_P, res_Vx, res_Vz = du.P, du.Vx, du.Vz
    P_ref = p.P_shift[1]
    dx, dz = p.Δ
    G, ρ = p.G, p.ρ
    ncx, ncz = size(P)

    # Kinematics, then stresses.
    total_strainrate!(εvol, εxx, εzz, εxz, εII, Vx, Vz, dx, dz)
    total_stresses!(τxx, τzz, τxz, τII, εxx, εzz, εxz,
                    p.τxx_old, p.τzz_old, p.τxz_old, p.η, G, p.Δt)

    # Mass balance and the monitoring viscosity.
    for j in 1:ncz
        for i in 1:ncx
            G_c = 0.25 * (G[i, j] + G[i + 1, j] + G[i, j + 1] + G[i + 1, j + 1])
            # NOTE(review): the pressure term is multiplied by G_c here, whereas
            # the comment this replaces described the equation as
            # "div(V) + P/G_av = 0". The two agree only when G ≡ 1 — confirm
            # which form is intended.
            res_P[i, j] = εvol[i, j] + G_c * (P[i, j] - P_ref)
            # Not used by any residual; the 1e-18 keeps the ratio finite when
            # εII is zero.
            η_eff[i, j] = τII[i, j] / (2.0 * εII[i, j] + 1e-18)
        end
    end

    # x-momentum.
    for j in 1:ncz
        for i in 1:(ncx + 1)
            if i == 1 || i == ncx + 1
                # First and last i: Vx is pinned to εxx_bg * x.
                res_Vx[i, j] = Vx[i, j] - p.BC.εxx_bg * p.x[i]
            else
                # Interior: d(τxx - P)/dx + dτxz/dz
                ∂z_τxz = (τxz[i, j + 1] - τxz[i, j]) / dz
                σ_here = τxx[i, j] - (P[i, j] - P_ref)
                σ_prev = τxx[i - 1, j] - (P[i - 1, j] - P_ref)
                ∂x_σxx = (σ_here - σ_prev) / dx
                res_Vx[i, j] = ∂x_σxx + ∂z_τxz
            end
        end
    end

    # z-momentum.
    for j in 1:(ncz + 1)
        for i in 1:ncx
            if j == 1 || j == ncz + 1
                # First and last j: Vz is pinned to εzz_bg * z.
                res_Vz[i, j] = Vz[i, j] - p.BC.εzz_bg * p.z[j]
            else
                # Interior: d(τzz - P)/dz + dτxz/dx - ρ g
                ∂x_τxz = (τxz[i + 1, j] - τxz[i, j]) / dx
                σ_here = τzz[i, j] - (P[i, j] - P_ref)
                σ_prev = τzz[i, j - 1] - (P[i, j - 1] - P_ref)
                ∂z_σzz = (σ_here - σ_prev) / dz
                # ρ is averaged over the two entries adjacent in i.
                ρ_face = 0.5 * (ρ[i, j] + ρ[i + 1, j])
                res_Vz[i, j] = ∂z_σzz + ∂x_τxz - ρ_face * p.g
            end
        end
    end

    return nothing
end
# ---------------------------------------------------------------------------
# Problem setup, sparsity detection and manual time loop
# ---------------------------------------------------------------------------

# `mean` is needed for the pressure shift in the time loop; it lives in the
# Statistics standard library, which the top-of-file imports do not load.
using Statistics: mean

# Grid: Nx × Nz cells on [-0.5, 0.5] × [-1, 0].
Nx, Nz = 40, 40
dx, dz = 1.0 / Nx, 1.0 / Nz
Δ = (dx, dz)

# Boundary and physical constants. `P_shift` and `Δt` are one-element vectors
# so they can be mutated through the (immutable) parameter NamedTuple.
BC = (εxx_bg = 1.0, εzz_bg = -1.0)
P_shift = [0.0]
g = 1.0
Δt = [1e-3]

# Material fields, sized (Nx+1) × (Nz+1).
ρ = ones(Nx + 1, Nz + 1)
η = ones(Nx + 1, Nz + 1)
G = ones(Nx + 1, Nz + 1)

# History stresses read by `total_stresses!` through `p`.
τxx_old = zeros(Nx, Nz)
τzz_old = zeros(Nx, Nz)
τxz_old = zeros(Nx + 1, Nz + 1)

# Coordinates of the grid lines and of the cell centres.
x = range(-0.5, 0.5, Nx + 1)
z = range(-1.0, 0.0, Nz + 1)
xc = (x[2:end] .+ x[1:end-1]) ./ 2
zc = (z[2:end] .+ z[1:end-1]) ./ 2   # fixed: the midpoint needs the division by 2

# State vector. Every entry is assigned below, so it starts from zeros
# instead of random numbers.
u = ComponentArray(
    P = zeros(Nx, Nz),
    Vx = zeros(Nx + 1, Nz),
    Vz = zeros(Nx, Nz + 1),
)

# Initial guess consistent with the boundary conditions:
# Vx = εxx_bg * x along the first index, Vz = εzz_bg * z along the second.
for j in 1:Nz, i in 1:(Nx + 1)
    u.Vx[i, j] = BC.εxx_bg * x[i]
end
for j in 1:(Nz + 1), i in 1:Nx
    u.Vz[i, j] = BC.εzz_bg * z[j]
end
u.P .= 0.0
du = similar(u)

# Scratch caches. `DiffCache` is the current name of the deprecated
# `dualcache`; `get_tmp` in `Res!` picks the storage matching eltype(u).
AD_CHUNK = 12
cellcache() = DiffCache(zeros(Nx, Nz), AD_CHUNK)
vertexcache() = DiffCache(zeros(Nx + 1, Nz + 1), AD_CHUNK)

p = (;
    x, z,
    Δ, BC, P_shift, g, Δt, ρ, η, G,
    τxx_old, τzz_old, τxz_old,
    cache_εvol = cellcache(),
    cache_εxx = cellcache(),
    cache_εzz = cellcache(),
    cache_εxz = vertexcache(),
    cache_εII = cellcache(),
    cache_τxx = cellcache(),
    cache_τzz = cellcache(),
    cache_τxz = vertexcache(),
    cache_τII = cellcache(),
    cache_η_eff = cellcache(),
)

# Smoke test with plain Float64 before involving tracers or dual numbers.
Res!(du, u, p, 0.0)
println("Residual test passed with Float64")

# Jacobian sparsity pattern of the residual with respect to `u`.
detector = SparseConnectivityTracer.TracerSparsityDetector()
jac_sparsity = ADTypes.jacobian_sparsity(
    (du_vec, u_vec) -> Res!(du_vec, u_vec, p, 0.0),
    du,
    u,
    detector,
)
println("Jacobian sparsity pattern detected.")

n_total = length(u)
# All-zero diagonal mass matrix: every equation is algebraic.
M = Diagonal(zeros(n_total))
# Floating-point sparse matrix carrying the detected pattern.
jac_prototype = float.(jac_sparsity)

f = ODEFunction(Res!;
    mass_matrix = M,
    jac_prototype = jac_prototype,
)

t = (0.0, 0.4)
prob = ODEProblem(f, u, t, p, save_everystep = false)
@time integ = init(prob, Rodas5P())

while integ.t < t[end]
    # The Float64 caches hold whatever the most recent call to `Res!` wrote,
    # which need not have been made at the accepted state. Re-evaluate at
    # `integ.u` (still with the Δt and history used for that state) so the
    # stresses stored as history belong to the current solution.
    Res!(du, integ.u, integ.p, integ.t)
    integ.p.τxx_old .= get_tmp(integ.p.cache_τxx, integ.u)
    integ.p.τzz_old .= get_tmp(integ.p.cache_τzz, integ.u)
    integ.p.τxz_old .= get_tmp(integ.p.cache_τxz, integ.u)

    # Parameters for the next step.
    integ.p.Δt[1] = integ.dt
    integ.p.P_shift[1] = mean(integ.u.P)

    println("Current t = $(integ.t)")
    step!(integ)
end
