How to debug an ODE model aborting midway during sensitivity computation

Hi there,

I have an ODE model that results from the spatial discretization of a PDE with reaction terms. The rate coefficient of the reaction is computed from an externally provided interpolation.

I can run the forward simulation / ODE solve just fine. However, when I try to compute sensitivities w.r.t. the interpolation coefficients, the solver aborts at varying time points (“dt was forced below floating point epsilon …”). These time points do not coincide with the interpolation nodes.

Sensitivities w.r.t. other system parameters work; only the interpolation coefficients cause these problems. So far, I have not been able to distill it down to a tractable MWE.

Do you have some ideas on how to debug this?

Which sensitivity method?

I’m using sensealg=ForwardDiffSensitivity() in the ODE solve() call, with ForwardDiff.gradient(loss, params) as the outermost call.
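
In other words, roughly this call pattern (schematic only, not a runnable reproducer; prob, save_times, and data are placeholders):

function loss(p)
    # re-parameterize the problem and solve with forward-mode sensitivities
    sol = solve(remake(prob; p = p), QNDF();
                saveat = save_times,
                sensealg = ForwardDiffSensitivity())
    return sum(abs2, Array(sol) .- data)
end

ForwardDiff.gradient(loss, params)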

Can I get a stack trace on that, or some code to call?

This is the smallest working example I could produce:

using DifferentialEquations, PolynomialBases
using ComponentArrays, PreallocationTools
using SparseArrays, Sparspak
using ForwardDiff

using SciMLSensitivity
using Random

function legendre_inv_mass(b::NodalBasis{T}) where {T}
    V = legendre_vandermonde(b)

    for n in axes(V, 2)
        V[:, n] .*= sqrt((2 * n - 1) / 2)
    end

    return V * V'
end

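# Bilinear interpolation of z on the (x_coords, y_coords) grid; returns zero outside the grid bounds.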
function bilinear_interpolate(z, x_coords, y_coords, x, y)
    if (x < x_coords[1]) || (x > x_coords[end]) || (y < y_coords[1]) || (y > y_coords[end])
        return zero(eltype(z))
    end

    i = searchsortedlast(x_coords, x)
    j = searchsortedlast(y_coords, y)

    i = clamp(i, 1, length(x_coords)-1)
    j = clamp(j, 1, length(y_coords)-1)

    x1, x2 = x_coords[i], x_coords[i+1]
    y1, y2 = y_coords[j], y_coords[j+1]

    wx = (x - x1) / (x2 - x1)
    wy = (y - y1) / (y2 - y1)

    return (1-wx)*(1-wy)*z[i,j] + wx*(1-wy)*z[i+1,j] + 
           (1-wx)*wy*z[i,j+1] + wx*wy*z[i+1,j+1]
end

struct TemporalSpatialInterpolation
    time_points::Vector{Float64}
    space_points::Vector{Float64}
end

function init(model::TemporalSpatialInterpolation, n_comp::Int)
    return zeros(Float64, n_comp * length(model.time_points) * length(model.space_points))
end

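# Fill dq with the interpolated rate coefficient for every component at time t and positions pos,
# reading the per-component coefficient grids from the parameter vector p.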
function rate!(dq::qType, model::TemporalSpatialInterpolation, n_comp::Int, p::pType, pos, t) where {qType, pType}
    np = length(model.time_points) * length(model.space_points)
    for i = 1:n_comp
        pp = view(p, (i-1) * np + 1:i * np)
        ppp = reshape(pp, (length(model.time_points), length(model.space_points)))
        for j in axes(dq, 2)
            dq[i, j] = bilinear_interpolate(ppp, model.time_points, model.space_points, t, pos[j])
        end
    end
    nothing
end

struct FlowModel{TransferType}
    n_comp::Int
    n_elements::Int
    n_points::Int
    n_dof_per_element::Int
    n_dof_per_phase::Int
    lgl_basis::LobattoLegendre{Float64}
    D::Array{Float64, 2}
    M_inv::Array{Float64, 2}
    transfer_model::TransferType

    function FlowModel(n_comp::Int, n_elements::Int, n_degree::Int, trans_mod::TM) where {TM}
        lgl_basis = LobattoLegendre(n_degree)

        n_points = n_degree + 1
        n_dof_per_element = n_points * n_comp
        n_dof_per_phase = n_elements * n_dof_per_element

        M_inv = legendre_inv_mass(lgl_basis)
        new{TM}(n_comp, n_elements, n_points, n_dof_per_element, n_dof_per_phase, lgl_basis, lgl_basis.D, M_inv, trans_mod)
    end
end

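# Index helpers mapping (phase, element, component) to ranges in the global state vector.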
function element_start_index(phase::Symbol, element::Int, component::Int, disc::FlowModel)
    n_points = disc.n_points
    n_dof_per_comp = disc.n_elements * n_points
    n_dof_per_phase = disc.n_dof_per_phase
    n_dof_per_element = disc.n_dof_per_element

    if phase == :phase1
        return (component - 1) * n_dof_per_comp + (element - 1) * n_points + 1
    elseif phase == :phase2
        ip = 2
    elseif phase == :phase3
        ip = 3
    end
    return (ip - 1) * n_dof_per_phase + (element - 1) * n_dof_per_element + (component - 1) * n_points + 1
end

function element_end_index(phase::Symbol, element::Int, component::Int, disc::FlowModel)
    n_points = disc.n_points
    n_dof_per_comp = disc.n_elements * n_points
    n_dof_per_phase = disc.n_dof_per_phase
    n_dof_per_element = disc.n_dof_per_element

    if phase == :phase1
        return (component - 1) * n_dof_per_comp + element * n_points
    elseif phase == :phase2
        ip = 2
    elseif phase == :phase3
        ip = 3
    end
    return (ip - 1) * n_dof_per_phase + (element - 1) * n_dof_per_element + component * n_points
end

function num_dofs(disc::FlowModel)
    return disc.n_points * disc.n_elements * disc.n_comp * 3
end

function outlet_index(model::FlowModel, comp::Int)
    n_elements = model.n_elements
    return element_end_index(:phase1, n_elements, comp, model)
end

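# Semi-discretized right-hand side: DG advection plus inter-phase transfer for phase 1,
# transfer and the interpolated rate terms for phases 2 and 3.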
function rhs!(dc::dcType, 
            c::cType, 
            params::pType, 
            t::tType,
            disc::FlowModel, 
            cache
            )  where {pType, tType<:Real, dcType<:AbstractVector{<:Real}, cType<:AbstractVector{<:Real}}

    inlet = [0.01, 0.02]
    flow_rate = 5.5e-8
    velocity = flow_rate / (π * (4e-3)^2 * 0.45)
    Δz_nondim = 1.0 / disc.n_elements
    Δz = 100 * 1e-3 / disc.n_elements
    two_over_Δz = 2.0 / Δz

    n_points = disc.n_points
    n_elements = disc.n_elements

    # Get some scratch memory
    TypeState = eltype(c)
    interp_buffer = get_tmp(cache[1], c)
    coord_buffer = cache[2]

    g_star_left::TypeState = 0.0
    g_star_right::TypeState = 0.0

    β_c = (1.0 - 0.45) / 0.45
    β_p = (1.0 - 0.75) / 0.75

    transfer_coeff = exp10(params.transfer_coeff)

    for comp = 1:disc.n_comp

        factor = β_c * 3.0 / (5e-5 / 2) * transfer_coeff

        for i = 1:n_elements
            idx_start = element_start_index(:phase1, i, comp, disc)
            idx_end = element_end_index(:phase1, i, comp, disc)
            idx_start_p2 = element_start_index(:phase2, i, comp, disc)
            idx_end_p2 = element_end_index(:phase2, i, comp, disc)
            local_p1 = view(c, idx_start:idx_end)
            local_p2 = view(c, idx_start_p2:idx_end_p2)

            if i == 1
                g_star_left = velocity * inlet[comp]
            else
                idx = element_end_index(:phase1, i - 1, comp, disc)
                g_star_left = velocity * c[idx]
            end

            g_star_right = velocity * local_p1[end]

            dc[idx_start:idx_end] .= -two_over_Δz .* velocity .* (disc.D * local_p1)
            dc[idx_start:idx_end] .+= two_over_Δz .* (disc.M_inv[:, 1] .* (-velocity * local_p1[1] + g_star_left) .+ disc.M_inv[:, end] .* (velocity * local_p1[end] - g_star_right))
            dc[idx_start:idx_end] .+= factor .* (local_p2 .- local_p1)
        end
    end

    factor2 = 3.0 / ((5e-5 / 2) * 0.75)
    for i = 1:n_elements
        idx_start_p2 = element_start_index(:phase2, i, 1, disc)
        idx_end_p2 = element_end_index(:phase2, i, disc.n_comp, disc)

        p2_element = reshape(view(c, idx_start_p2:idx_end_p2), n_points, disc.n_comp)

        coord_buffer .= (disc.lgl_basis.nodes .+ 1) .* Δz_nondim ./ 2 .+ (i-1) .* Δz_nondim
        rate!(interp_buffer, disc.transfer_model, disc.n_comp, params.interp_params, coord_buffer, t) 

        for j = 1:n_points
            for comp = 1:disc.n_comp
                idx_start_comp_p2 = element_start_index(:phase2, i, comp, disc) - 1
                idx_start_comp_p3 = element_start_index(:phase3, i, comp, disc) - 1

                idx_start_liquid = element_start_index(:phase1, i, comp, disc)
                idx_end_liquid = element_end_index(:phase1, i, comp, disc)
                loc_c = view(c, idx_start_liquid:idx_end_liquid)

                dc[idx_start_comp_p2 + j] = factor2 * transfer_coeff * (loc_c[j] - p2_element[j, comp]) - β_p * interp_buffer[comp, j]
                dc[idx_start_comp_p3 + j] = interp_buffer[comp, j]
            end
        end
    end
    nothing
end

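# Build the discretization, wrap the ODE solve in a loss function, and differentiate it with ForwardDiff.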
function grad_fail()
    n_comp = 2

    time_nodes = [36.0, 360.0, 450.0, 1870.0, 1900.0, 2100.0]

    interp = TemporalSpatialInterpolation(time_nodes, Float64[0.0, 1.0])

    initInterp = init(interp, n_comp)
    randn!(Xoshiro(12), initInterp)
    init_params = ComponentArray(interp_params=initInterp, transfer_coeff=-6.2)

    n_elements = 5
    n_degree = 3

    model = FlowModel(n_comp, n_elements, n_degree, interp)

    println(init_params)
    println("#Params: $(length(init_params))")

    caches = (
        DiffCache(Matrix{Float64}(undef, (model.n_comp, model.n_points)), model.n_comp * model.n_points; levels=2),
        Array{Float64}(undef, model.n_points)
        )

    time_points = collect(LinRange(0.0, 2200.0, 100))
    tend = time_points[end] + 100.0
    tspan = (0.0, tend)

    save_idx = [outlet_index(model, 1), outlet_index(model, 2)]
    
    # Just to provide some sparsity pattern
    jac0 = spzeros(Float64, num_dofs(model), num_dofs(model))
    for i in 1:num_dofs(model)
        for j in 1:num_dofs(model)
            jac0[i, j] = 1.0
        end
    end

    function loss(p)
        nd = num_dofs(model)
        u0 = zeros(eltype(p), nd)

        prob = ODEProblem{true, SciMLBase.FullSpecialize}(ODEFunction{true, SciMLBase.FullSpecialize}(
            (du, u, pp, t) -> let model = model, caches = caches
                rhs!(du, u, pp, t, model, caches)
            end; jac_prototype=jac0),
            u0, tspan, p)

        solver = (eltype(p) <: Float64) ? QNDF(autodiff=true) : QNDF(autodiff=true, linsolve=SparspakFactorization())
        sol = solve(prob,
                    solver,
                    abstol=1e-8,
                    reltol=1e-4,
                    save_start=false,
                    save_end=false,
                    save_idxs=save_idx,
                    saveat=time_points,
                    sensealg=ForwardDiffSensitivity())

        if SciMLBase.successful_retcode(sol.retcode)
            val = sum(sum(sol.u))
        else
            val = Inf
        end

        return val
    end

    loss(init_params)

    return ForwardDiff.gradient(loss, init_params)
end

Running grad_fail() produces several of these:

┌ Warning: At t=2252.041792732512, dt was forced below floating point epsilon 4.547473508864641e-13, and step error estimate = 6.280201993524826e-6. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of ForwardDiff.Dual{ForwardDiff.Tag{var"#loss#23"{SparseMatrixCSC{Float64, Int64}, Vector{Int64}, Tuple{Float64, Float64}, Tuple{DiffCache{Matrix{Float64}, Vector{Float64}}, Vector{Float64}}, Vector{Float64}, FlowModel{TemporalSpatialInterpolation}}, Float64}, Float64, 9}).
└ @ SciMLBase ~/.julia/packages/SciMLBase/m1Jrs/src/integrator_interface.jl:623

Julia version 1.10.4+0.aarch64.apple.darwin14 and package versions:

[b0b7db55] ComponentArrays v0.15.27
[0c46a032] DifferentialEquations v7.16.1
[f6369f11] ForwardDiff v0.10.38
[c74db56a] PolynomialBases v0.4.22
[d236fae5] PreallocationTools v0.4.27
[0bca4576] SciMLBase v2.87.0
[1ed8b502] SciMLSensitivity v7.79.0
[47a9eef4] SparseDiffTools v2.23.1
[e56a9233] Sparspak v0.3.11
[9a3f8284] Random
[2f01184e] SparseArrays v1.10.0

I think it is connected to the SparspakFactorization: it works when using the dense LU (but that is unbearably slow). Note that the “full” sparsity pattern in the snippet is only there to provoke the problem; the real code has a reasonable pattern.
@jClugstor can you take a look at the diff here? If the issue is isolated to SparspakFactorization, then it can probably be isolated to just LinearSolve’s ForwardDiff handling.
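
Something roughly along these lines might be a starting point for that isolation (a hypothetical sketch, the helper names are mine; it just pushes ForwardDiff Duals through a parameter-dependent sparse solve with SparspakFactorization and compares against a plain dense LU):

using LinearSolve, SparseArrays, Sparspak, ForwardDiff

# Small sparse system A(p) x = b(p), diagonally dominant so it is well conditioned.
function build_system(p, n = 20)
    T = eltype(p)
    A = spdiagm(0 => fill(T(4) + p[1], n),
                1 => fill(-one(T), n - 1),
                -1 => fill(-one(T), n - 1))
    b = fill(p[2], n)
    return A, b
end

sparspak_solve(p) = solve(LinearProblem(build_system(p)...), SparspakFactorization()).u

function dense_solve(p)
    A, b = build_system(p)
    return Matrix(A) \ b          # generic dense LU as a reference
end

J_sparspak = ForwardDiff.jacobian(sparspak_solve, [0.1, 1.0])
J_dense    = ForwardDiff.jacobian(dense_solve, [0.1, 1.0])
maximum(abs, J_sparspak .- J_dense)   # should be ≈ 0 if the ForwardDiff path is consistent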

I can’t reproduce the warnings with the latest versions of the packages. Any particular reason you’re on older versions of SciMLBase, SciMLSensitivity, SparseDiffTools, etc.?

I’ve had packages break interfaces or cause other problems after updating several times in the past, so I decided to freeze the package versions to get this project done.

Now, I’ve updated to the most recent Julia (1.11.6) and packages:

[b0b7db55] ComponentArrays v0.15.29
[0c46a032] DifferentialEquations v7.16.1
[f6369f11] ForwardDiff v1.2.0
[c74db56a] PolynomialBases v0.4.25
[d236fae5] PreallocationTools v0.4.34
[1ed8b502] SciMLSensitivity v7.89.0
[e56a9233] Sparspak v0.3.14
[9a3f8284] Random v1.11.0
[2f01184e] SparseArrays v1.11.0

Indeed, the problem is gone and the simulation runs.

However, when comparing a single-parameter sensitivity that works in both settings (i.e., w.r.t. transfer_coeff), the new packages run substantially slower.
With the old packages, I get:

0.613176 seconds (1.06 M allocations: 332.460 MiB, 1.50% gc time)

whereas the new packages yield

6.585257 seconds (12.83 M allocations: 29.857 GiB, 19.13% gc time)

Both measurements were taken after running @time grad_fail() multiple times.

Any idea what causes these massive allocations (30 GiB vs 0.32 GiB) in the new versions?

If it helps, poor man’s sampling profiling (i.e., repeatedly running the code, hitting Ctrl+C during execution, and looking at how often functions appear in the resulting stack traces) points to

  [5] partials_to_list
    @ ~/.julia/packages/LinearSolve/tOoFP/ext/LinearSolveForwardDiffExt.jl:254 [inlined]
  [6] xp_linsolve_rhs(uu::Vector{…}, ∂_A::SparseMatrixCSC{…}, ∂_b::Vector{…})
    @ LinearSolveForwardDiffExt ~/.julia/packages/LinearSolve/tOoFP/ext/LinearSolveForwardDiffExt.jl:88
  [7] linearsolve_forwarddiff_solve(::LinearSolveForwardDiffExt.DualLinearCache{…}, ::SparspakFactorization; kwargs::@Kwargs{…})
    @ LinearSolveForwardDiffExt ~/.julia/packages/LinearSolve/tOoFP/ext/LinearSolveForwardDiffExt.jl:68
  [8] linearsolve_forwarddiff_solve
    @ ~/.julia/packages/LinearSolve/tOoFP/ext/LinearSolveForwardDiffExt.jl:49 [inlined]
  [9] solve!(::LinearSolveForwardDiffExt.DualLinearCache{…}, ::SparspakFactorization; kwargs::@Kwargs{…})
    @ LinearSolveForwardDiffExt ~/.julia/packages/LinearSolve/tOoFP/ext/LinearSolveForwardDiffExt.jl
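
If a proper profile is more useful, I can collect one with the built-in Profile stdlib, roughly along these lines (a sketch; Profile.Allocs requires Julia ≥ 1.8, so it is available on 1.11):

using Profile

# CPU sampling profile (same idea as the Ctrl+C sampling, but automated)
grad_fail()                       # warm-up / compilation
@profile grad_fail()
Profile.print(mincount = 50)      # show only call sites with a meaningful sample count

# Allocation profile, to see which call sites produce the GiB-scale allocations
Profile.Allocs.clear()
Profile.Allocs.@profile sample_rate = 0.01 grad_fail()
allocs = Profile.Allocs.fetch().allocs
sort!(allocs; by = a -> a.size, rev = true)
first(allocs, 10)                 # largest sampled allocations, with their stack traces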

Yeah, that’s Jadon’s part. He’s currently optimizing it some more:

And we should add this as a test case.
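
For example, something roughly like this (just a sketch; loss and init_params as constructed in the MWE above, with FiniteDiff as the reference gradient):

using Test, FiniteDiff

@testset "ForwardDiff gradient through SparspakFactorization" begin
    # `loss` and `init_params` as in the MWE, assumed to be available at top level
    g_ad = ForwardDiff.gradient(loss, init_params)
    g_fd = FiniteDiff.finite_difference_gradient(loss, init_params)
    @test g_ad ≈ g_fd rtol = 1e-3
end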