Stokes solver as an ODEProblem takes a long time to initialize

Hi everyone.

I wrote a small code to solve the 2D linear visco-elastic Stokes equations using finite differences and OrdinaryDiffEq, to test how that would work out. In this case I only have algebraic equations, plus some time-dependent parameters that I update between solve steps (I know I could use a callback function for that, but I am doing it manually for now to have a bit more control).

So my whole system is expressed as an ODEProblem with a mass matrix equal to 0 everywhere. I know that I could use LinearSolve directly for this problem with a simple time loop, and that works well, but I plan to add other elliptic equations later. The current bottleneck is the init() stage, which takes a bit less than 1 minute for a 40x40 problem and grows dramatically with increasing resolution. The solving stage is actually quite fast.

My discretisation function is already allocation-free and I am providing the Jacobian sparsity, so I have run out of ideas to cut down that time. Is there something I could try? I have also tried solving for the initial conditions with LinearSolve.jl to provide a better initial guess, but that didn’t change the init() time at all, so I assume the time is spent on something else. Is this a known limitation? I have tried Rodas4P() and FBDF(), but the init time didn’t really change.

Here is the code, if that can help. It is a bit long sorry!

using OrdinaryDiffEq, LinearAlgebra
using ForwardDiff, SparseArrays
using ComponentArrays: ComponentArray
using PreallocationTools
using Statistics: mean  # mean(integ.u.P) is used in the time loop
import SparseConnectivityTracer, ADTypes
using LinearSolve
using NonlinearSolve

# Strain-rate components and second invariant on the staggered grid
function total_strainrate!(εvol, εxx, εzz, εxz, εII, Vx, Vz, dx, dz)
    nx, nz = size(εxx)

    # 1. Initialize εxz boundary (Crucial for Tracers!)
    # Tracers can throw UndefRefError if an index is read before being assigned.
    εxz .= 0.0

    # 2. Diagonal components (Using simple loops is often safer for AD)
    for j in 1:nz, i in 1:nx
        dvx_dx = (Vx[i+1, j] - Vx[i, j]) / dx
        dvz_dz = (Vz[i, j+1] - Vz[i, j]) / dz

        εv = dvx_dx + dvz_dz
        εvol[i, j] = εv
        εxx[i, j]  = dvx_dx - εv / 3.0
        εzz[i, j]  = dvz_dz - εv / 3.0
    end

    # 3. Off-diagonal (On Nodes/Vertices)
    # We only compute internal nodes; boundaries are handled by BCs later
    for j in 2:nz, i in 2:nx
        dvx_dz = (Vx[i, j] - Vx[i, j-1]) / dz
        dvz_dx = (Vz[i, j] - Vz[i-1, j]) / dx
        εxz[i, j] = 0.5 * (dvx_dz + dvz_dx)
    end

    # 4. Second Invariant (εII) at Cell Centers
    for j in 1:nz, i in 1:nx
        # Average the 4 surrounding εxz nodes to the cell center
        # These indices map nodes (nx+1, nz+1) to centers (nx, nz)
        exz_sq_av = 0.25 * (
            εxz[i, j]^2 + εxz[i+1, j]^2 +
            εxz[i, j+1]^2 + εxz[i+1, j+1]^2
        )
        εII[i, j] = sqrt(0.5 * εxx[i, j]^2 + 0.5 * εzz[i, j]^2 + exz_sq_av)
    end

    return nothing
end

function total_stresses!(τxx, τzz, τxz, τII, εxx, εzz, εxz, τxx_old, τzz_old, τxz_old, η, G, Δt)
    nx, nz = size(εxx)

    # 1. CRITICAL: Initialize the entire τxz array.
    # This prevents UndefRefErrors when reading boundary nodes for τII.
    τxz .= 0.0

    # 2. Update τxx and τzz at cell centers
    for j in 1:nz, i in 1:nx
        # Average viscosity η and elastic modulus G from the 4 surrounding vertices to the cell center
        η_av = 0.25 * (η[i, j] + η[i+1, j] + η[i, j+1] + η[i+1, j+1])
        G_av = 0.25 * (G[i, j] + G[i+1, j] + G[i, j+1] + G[i+1, j+1])

        χ_av = η_av / (G_av * Δt[1])
        inv_1_χ = 1.0 / (1.0 + χ_av)

        τxx[i, j] = (τxx_old[i, j] * χ_av + 2.0 * η_av * εxx[i, j]) * inv_1_χ
        τzz[i, j] = (τzz_old[i, j] * χ_av + 2.0 * η_av * εzz[i, j]) * inv_1_χ
    end

    # 3. Update τxz at internal vertices
    for j in 2:nz, i in 2:nx
        χ = η[i, j] / (G[i, j] * Δt[1])
        inv_1_χ = 1.0 / (1.0 + χ)
        τxz[i, j] = (τxz_old[i, j] * χ + 2.0 * η[i, j] * εxz[i, j]) * inv_1_χ
    end

    # 4. Compute Second Invariant τII at cell centers
    for j in 1:nz, i in 1:nx
        # Now every index in τxz has been touched by the tracer,
        # so this line won't throw UndefRefError anymore.
        txz_sq_av = 0.25 * (τxz[i, j]^2 + τxz[i+1, j]^2 + τxz[i, j+1]^2 + τxz[i+1, j+1]^2)

        τII[i, j] = sqrt(0.5 * τxx[i, j]^2 + 0.5 * τzz[i, j]^2 + txz_sq_av)
    end

    return nothing
end

# Residual function (mass balance and momentum balance)
function Res!(du, u, p, t)
    # A. Retrieve caches
    εvol = get_tmp(p.cache_εvol, u)
    εxx  = get_tmp(p.cache_εxx, u)
    εzz  = get_tmp(p.cache_εzz, u)
    εxz  = get_tmp(p.cache_εxz, u)
    εII  = get_tmp(p.cache_εII, u)
    τxx  = get_tmp(p.cache_τxx, u)
    τzz  = get_tmp(p.cache_τzz, u)
    τxz  = get_tmp(p.cache_τxz, u)
    τII  = get_tmp(p.cache_τII, u)
    η_eff= get_tmp(p.cache_η_eff, u)

    # B. Unpack Variables
    P, Vx, Vz = u.P, u.Vx, u.Vz
    P_shift = p.P_shift[1] # Unpack the shift
    dx, dz = p.Δ
    nx, nz = size(P)

    # C. Physics Logic (The loop-based functions we just wrote)
    total_strainrate!(εvol, εxx, εzz, εxz, εII, Vx, Vz, dx, dz)
    total_stresses!(τxx, τzz, τxz, τII, εxx, εzz, εxz, p.τxx_old, p.τzz_old, p.τxz_old, p.η, p.G, p.Δt)

    # D. Residuals - Loop Based

    # 1. Mass Balance & Momentum
    for j in 1:nz, i in 1:nx
        # Mass balance residual as coded: div(V) + G_av * (P - P_shift)
        G_av = 0.25 * (p.G[i, j] + p.G[i+1, j] + p.G[i, j+1] + p.G[i+1, j+1])
        du.P[i, j] = εvol[i, j] + G_av * (P[i, j] - P_shift)

        # Effective Viscosity (optional, for monitoring)
        η_eff[i, j] = τII[i, j] / (2.0 * εII[i, j] + 1e-18)
    end

    # 2. X-Momentum Balance
    for j in 1:nz, i in 1:nx+1
        if i == 1
            du.Vx[i, j] = Vx[i, j] - p.BC.εxx_bg * p.x[i]
        elseif i == nx+1
            du.Vx[i, j] = Vx[i, j] - p.BC.εxx_bg * p.x[i]
        else
            # Internal nodes: -dP/dx + dτxx/dx + dτxz/dz = 0
            dτxz_dz = (τxz[i, j+1] - τxz[i, j]) / dz
            dτxx_P_dx = ( (τxx[i, j] - (P[i, j] - P_shift)) - (τxx[i-1, j] - (P[i-1, j] - P_shift)) ) / dx
            du.Vx[i, j] = dτxx_P_dx + dτxz_dz
        end
    end

    # 3. Z-Momentum Balance
    for j in 1:nz+1, i in 1:nx
        if j == 1
            du.Vz[i, j] = Vz[i, j] - p.BC.εzz_bg * p.z[j]
        elseif j == nz+1
            du.Vz[i, j] = Vz[i, j] - p.BC.εzz_bg * p.z[j]
        else
            # Internal nodes: -dP/dz + dτzz/dz + dτxz/dx - ρg = 0
            dτxz_dx = (τxz[i+1, j] - τxz[i, j]) / dx
            dτzz_P_dz = ( (τzz[i, j] - (P[i, j] - P_shift)) - (τzz[i, j-1] - (P[i, j-1] - P_shift)) ) / dz

            # Density averaging (lin_int replacement)
            ρ_av = 0.5 * (p.ρ[i, j] + p.ρ[i+1, j])

            du.Vz[i, j] = dτzz_P_dz + dτxz_dx - ρ_av * p.g
        end
    end

    return nothing
end

# Setup problem
Nx, Nz = 40, 40
dx, dz = 1.0/Nx, 1.0/Nz
Δ = (dx, dz)

# Boundary and Physical constants
BC = (εxx_bg = 1.0, εzz_bg = -1.0)
P_shift = [0.0]
g = 1.0
Δt = [1e-3]

# Material fields (Must be defined)
ρ   = ones(Nx+1, Nz+1)
η   = ones(Nx+1, Nz+1)
G   = ones(Nx+1, Nz+1)

# History fields (IMPORTANT: These must exist in 'p')
τxx_old = zeros(Nx, Nz)
τzz_old = zeros(Nx, Nz)
τxz_old = zeros(Nx+1, Nz+1)

# ComponentArray for state variables
u = ComponentArray(
    P  = rand(Nx, Nz),
    Vx = rand(Nx+1, Nz),
    Vz = rand(Nx, Nz+1)
)

x = range(-0.5, 0.5, Nx+1)
z = range(-1.0, 0.0, Nz+1)
xc = (x[2:end] .+ x[1:end-1]) ./ 2
zc = (z[2:end] .+ z[1:end-1]) ./ 2

# initial guess consistent with BCs
for (i, xv) in enumerate(x), (j, zv) in enumerate(zc); u.Vx[i, j] = BC.εxx_bg * xv; end
for (i, xv) in enumerate(xc), (j, zv) in enumerate(z); u.Vz[i, j] = BC.εzz_bg * zv; end
u.P .= 0.0

du = similar(u)

# Setup caches
AD_CHUNK = 12
p = (;
    x = range(-0.5, 0.5, Nx+1),
    z = range(-1.0, 0.0, Nz+1),
    Δ, BC, P_shift, g, Δt, ρ, η, G,
    τxx_old, τzz_old, τxz_old,
    cache_εvol = dualcache(zeros(Nx, Nz), AD_CHUNK),
    cache_εxx  = dualcache(zeros(Nx, Nz), AD_CHUNK),
    cache_εzz  = dualcache(zeros(Nx, Nz), AD_CHUNK),
    cache_εxz  = dualcache(zeros(Nx+1, Nz+1), AD_CHUNK),
    cache_εII  = dualcache(zeros(Nx, Nz), AD_CHUNK),
    cache_τxx  = dualcache(zeros(Nx, Nz), AD_CHUNK),
    cache_τzz  = dualcache(zeros(Nx, Nz), AD_CHUNK),
    cache_τxz  = dualcache(zeros(Nx+1, Nz+1), AD_CHUNK),
    cache_τII  = dualcache(zeros(Nx, Nz), AD_CHUNK),
    cache_η_eff= dualcache(zeros(Nx, Nz), AD_CHUNK)
)

# Test call with Float64 to ensure logic is sound
Res!(du, u, p, 0.0)
println("Residual test passed with Float64")

# Now run the Sparsity Detection

detector = SparseConnectivityTracer.TracerSparsityDetector()
jac_sparsity = ADTypes.jacobian_sparsity(
    (du_vec, u_vec) -> Res!(du_vec, u_vec, p, 0.0),
    du,
    u,
    detector
)
println("Jacobian sparsity pattern detected.")

n_total = length(u)

# Mass matrix that is zero everywhere: every equation is algebraic
M = Diagonal(zeros(n_total))

jac_prototype = float.(jac_sparsity)  # Bool sparsity pattern -> Float64 sparse matrix

f = ODEFunction(Res!;
    mass_matrix = M,
    jac_prototype = jac_prototype
)
t = (0.0, 0.4)

prob = ODEProblem(f, u, t, p, save_everystep = false)
@time integ = init(prob, Rodas5P())


while integ.t < t[end]
    # update Δt dynamically
    integ.p.Δt[1] = integ.dt  # update parameters for next step
    integ.p.τxx_old .= get_tmp(integ.p.cache_τxx, integ.u)
    integ.p.τzz_old .= get_tmp(integ.p.cache_τzz, integ.u)
    integ.p.τxz_old .= get_tmp(integ.p.cache_τxz, integ.u)
    integ.p.P_shift[1] = mean(integ.u.P) # update pressure shift

    println("Current t = $(integ.t)")
    step!(integ)
end

If it’s the identity, don’t pass it.

Where does the profile say it’s spending its time? Note Rosenbrock methods are going to be the wrong choice here, it will certainly be an FBDF type of problem.
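
A minimal sketch of how the init() cost could be split into compilation and run time and then profiled with the Profile standard library. The toy problem (`toy_rhs!`, `prob_toy`) is a hypothetical stand-in, not the Stokes system; in the real script the same calls would be applied to `prob`.

```julia
using OrdinaryDiffEq, Profile

# Hypothetical stand-in for Res!: a small linear system.
function toy_rhs!(du, u, p, t)
    du[1] = -p * u[1]
    du[2] = u[1] - u[2]
    return nothing
end

prob_toy = ODEProblem(toy_rhs!, [1.0, 0.0], (0.0, 1.0), 10.0)

# The first call includes compilation; the second call reuses compiled code.
t_first  = @elapsed init(prob_toy, FBDF())
t_second = @elapsed init(prob_toy, FBDF())
println("first init: $(t_first) s, second init: $(t_second) s")

# Sample the already compiled init call (repeated so that samples are collected)
# and print a flat report sorted by sample count.
Profile.clear()
@profile for _ in 1:200
    init(prob_toy, FBDF())
end
Profile.print(format = :flat, sortedby = :count, mincount = 5)
```

If the second call is much faster than the first, most of the reported time is compilation rather than work done inside init().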

Thx for the reply.

M is not the identity, it is zero everywhere. I am just creating a sparse matrix with 0 in the diagonal. I also tried with just an empty sparse matrix, but that didn’t change anything.

I did some profiling of the init() call, with FBDF() this time:

Sorry, it is a bit small :sweat_smile: It is not really helpful for me, but maybe it is for you.

So then you don’t have an ODE? Why not make a NonlinearProblem?

It works well with NonlinearProblem, yes, I have tried. There is a time component in my system, though, due to elasticity in the stresses. This is discretised internally using explicit Euler, which is what is commonly done. I wanted to use an ODEProblem to get the time-stepping algorithm because, even though the system is expressed as an algebraic problem, there is still a time loop. But I guess I will stay with NonlinearProblem and write a simple time-stepping algorithm.

I also would like to add the heat equation later and solve them together, so I will have a time derivative at that stage for this elliptic equation.
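
The NonlinearProblem plus manual time loop mentioned above could be sketched as follows. This is a toy, assuming a single 0D Maxwell element with one unknown stress and a fixed strain rate; the names `maxwell_residual!`, `maxwell_prob`, `run_steps` and `εdot` are hypothetical and not part of the Stokes code.

```julia
using NonlinearSolve

# Residual of the discretised law for one stress unknown τ = u[1]:
#   (1 + χ) τ - χ τ_old - 2 η ε̇ = 0,  with χ = η / (G Δt)
function maxwell_residual!(res, u, p)
    χ = p.η / (p.G * p.Δt[1])
    res[1] = (1.0 + χ) * u[1] - χ * p.τ_old[1] - 2.0 * p.η * p.εdot
    return nothing
end

# Δt and τ_old are stored in arrays so they can be updated between steps.
p_toy = (; η = 1.0, G = 1.0, εdot = 1.0, Δt = [0.1], τ_old = [0.0])
maxwell_prob = NonlinearProblem(maxwell_residual!, [0.0], p_toy)

# Manual time loop: solve the algebraic problem, then store the result as history.
function run_steps(prob, nsteps)
    τ = copy(prob.u0)
    for _ in 1:nsteps
        sol = solve(remake(prob; u0 = τ), NewtonRaphson())
        τ .= sol.u               # warm start for the next step
        prob.p.τ_old .= sol.u    # history term for the next step
    end
    return τ
end

τ_end = run_steps(maxwell_prob, 200)
println("τ after 200 steps: ", τ_end[1], ", viscous limit 2ηε̇: ", 2.0 * p_toy.η * p_toy.εdot)
```

The problem is built once and only `u0` and the history arrays change between steps, so no setup is repeated inside the loop.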

If the mass matrix is 0 then there’s no time component?

There is one, but it is hidden in the constitutive relationships between deviatoric strain and stresses:

# Excerpt from total_stresses! (update of τxx and τzz at cell centers)
for j in 1:nz, i in 1:nx
    η_av = 0.25 * (η[i, j] + η[i+1, j] + η[i, j+1] + η[i+1, j+1])
    G_av = 0.25 * (G[i, j] + G[i+1, j] + G[i, j+1] + G[i+1, j+1])

    χ_av = η_av / (G_av * Δt[1])
    inv_1_χ = 1.0 / (1.0 + χ_av)

    τxx[i, j] = (τxx_old[i, j] * χ_av + 2.0 * η_av * εxx[i, j]) * inv_1_χ
    τzz[i, j] = (τzz_old[i, j] * χ_av + 2.0 * η_av * εzz[i, j]) * inv_1_χ
end

There is this χ_av = η_av / (G_av * Δt[1]) that depends on Δt because of elasticity:

$$
\dot{\boldsymbol{\varepsilon}} = \frac{1}{2G} \frac{D \boldsymbol{\tau}}{D t} + \frac{1}{2\eta} \boldsymbol{\tau}
$$

I could also express these as ODEs and put them in the solver, but I believe that would just make the system bigger, whereas this part can be handled with a simple Euler step. This is what is commonly done.
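
Written out as a sketch, assuming the material derivative D/Dt is replaced by a first-order difference over Δt against the stored old stress (advection and rotation of the stress are assumed negligible here), the law above reduces to the update used in total_stresses!:

```latex
\begin{aligned}
\dot{\varepsilon} &= \frac{1}{2G}\,\frac{\tau - \tau^{\mathrm{old}}}{\Delta t} + \frac{1}{2\eta}\,\tau \\
2\eta\,\dot{\varepsilon} &= \chi\,\bigl(\tau - \tau^{\mathrm{old}}\bigr) + \tau,
  \qquad \chi = \frac{\eta}{G\,\Delta t} \\
\tau &= \frac{\chi\,\tau^{\mathrm{old}} + 2\eta\,\dot{\varepsilon}}{1 + \chi}
\end{aligned}
```

Component by component, this is the expression coded for τxx, τzz and τxz, with χ_av built from the averaged η and G at cell centers. For large Δt, χ goes to 0 and τ approaches the viscous stress 2ηε̇; for small Δt, χ grows and τ stays close to τ_old.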

The code works, so I believe the approach is correct. The issue is just that the cost of initialising the problem is currently too high for it to be worth it.