I played around with the agent a bit: with the latest package versions and the new VJP hooks, I got this down to 4 seconds:
| Config | Time | Allocs | GiB | GC% |
|---|---|---|---|---|
| OLD sparse VJP + EnzymeVJP | 12.2s | 31M | 7.0 | 6% |
| Dense-block VJP + ForwardDiff + vjp_p | 6.5s | 26M | 7.5 | 12% |
| Analytical VJP + vjp_p | 4.0s | 9.7M | 4.4 | 10% |
Fastest code (analytical VJP + `vjp_p`):
# Fully non-allocating VJP with analytical derivatives
# No ForwardDiff in the VJP hot path at all.
# Computes J^T*v and (∂f/∂p)^T*v directly via hand-coded derivatives of binding_flux.
import OrdinaryDiffEq as ODE
import SciMLSensitivity as SMS
using LinearAlgebra, LinearSolve
using PolynomialBases, ComponentArrays, SparseArrays
import ForwardDiff
import Zygote
"""
    legendre_inv_mass(b::NodalBasis)

Return the inverse nodal mass matrix `V * V'` of basis `b`.

`V` is the Legendre Vandermonde matrix of `b`. Column `n` is scaled by
`sqrt((2n - 1) / 2)`, which normalises the modal basis on [-1, 1]. This
assumes column `n` holds the degree-`n - 1` Legendre polynomial at the nodes,
as that factor implies.
"""
function legendre_inv_mass(b::NodalBasis)
    V = legendre_vandermonde(b)
    # One column per Legendre mode, so iterate columns (axes(V, 2)), not rows.
    # Scale in place through a view to avoid a temporary per column.
    for n in axes(V, 2)
        col = view(V, :, n)
        col .*= sqrt((2 * n - 1) / 2)
    end
    return V * V'
end
"""
    ModulatedBinding(nComp, nBindingComps)

Size-only descriptor of the binding isotherm evaluated by `binding_flux!` and
its analytical derivatives.

The isotherm parameters (`ka`, `kd`, `qmax`, `gamma`, `beta`, `salt_ref`) are
not stored here; they are passed separately as `p`.
"""
struct ModulatedBinding
    nComp::Int          # total components; the last one enters via salt_ref-relative terms
    nBindingComps::Int  # components 1:nBindingComps bind, the remainder do not
end
"""
    binding_flux!(iso, model::ModulatedBinding, cp, cs, p)

Evaluate the isotherm flux in place.

For every point `i` and binding component `j`:

    iso[i, j] = ka[j] * exp((cp[i, end] - salt_ref) * gamma[j]) * cp[i, j] * qmax[j] * qf[i]
              - kd[j] * (cp[i, end] / salt_ref)^beta[j] * cs[i, j]
    qf[i]     = 1 - Σ_{k ≤ nBindingComps} cs[i, k] / qmax[k]

`iso` must be at least `size(cp, 1) × model.nBindingComps`. `p` must expose
`ka`, `kd`, `qmax`, `gamma`, `beta` and `salt_ref`. The function works on
scalars only and does not allocate. It is called once per element from the
ODE right-hand side.
"""
function binding_flux!(iso, model::ModulatedBinding, cp, cs, p)
    n_points = size(cp, 1)
    n_bind = model.nBindingComps
    @inbounds for i in 1:n_points
        # Everything below up to the j-loop depends only on the point i.
        salt = cp[i, end]
        Δsalt = salt - p.salt_ref
        salt_ratio = salt / p.salt_ref
        # Start from `one` of the type of cs/qmax so the accumulator keeps a
        # single concrete type (e.g. stays Dual under ForwardDiff).
        q_free = one(typeof(one(eltype(cs)) / one(eltype(p.qmax))))
        for k in 1:n_bind
            q_free -= cs[i, k] / p.qmax[k]
        end
        for j in 1:n_bind
            ads = p.ka[j] * exp(Δsalt * p.gamma[j]) * cp[i, j] * p.qmax[j] * q_free
            des = p.kd[j] * salt_ratio^p.beta[j] * cs[i, j]
            iso[i, j] = ads - des
        end
    end
    return nothing
end
# Allocating, broadcast-based twin of `binding_flux!`. It returns a new
# `size(cp, 1) × nBindingComps` matrix instead of writing into a buffer. It is
# differentiated with ForwardDiff by the binding-model `jac!`, which runs both
# at setup (jac_prototype) and whenever the solver's `jac` callback fires.
function binding_flux(model::ModulatedBinding, cp::cpType, cs::csType, p::pType) where {cpType, csType, pType}
    bound = 1:model.nBindingComps
    as_row(v) = reshape(v, 1, :)   # parameter vector -> 1×n row for broadcasting
    @views begin
        cp_bound = cp[:, bound]
        cs_bound = cs[:, bound]
        salt = cp[:, end]
        q_free = 1.0 .- sum(cs_bound ./ as_row(p.qmax), dims = 2)
        ads = as_row(p.ka) .* exp.((salt .- p.salt_ref) .* as_row(p.gamma)) .* cp_bound .* as_row(p.qmax) .* q_free
        des = as_row(p.kd) .* (salt ./ p.salt_ref) .^ as_row(p.beta) .* cs_bound
    end
    return ads .- des
end
# =========================================================================
# Analytical VJP of binding_flux: computes J_cp^T * w and J_cs^T * w
# directly without forming the Jacobian matrix.
#
# iso[i,j] = ka[j]*E[i,j]*cp[i,j]*qmax[j]*qf[i] - kd[j]*D[i,j]*cs[i,j]
# where E[i,j] = exp((cp[i,end]-salt_ref)*gamma[j])
#       D[i,j] = (cp[i,end]/salt_ref)^beta[j]
#       qf[i]  = 1 - Σ_k cs[i,k]/qmax[k]
#
# Overwrites Jv_cp (n_comp*n_points) and Jv_cs (n_comp*n_points). Both are
# laid out like vec(cp) / vec(cs). Jv_cs entries of non-binding components
# stay zero.
# w is (n_bind*n_points): w[(j-1)*n_points + i] is the weight of iso[i,j].
#
# At cp[i,end] == 0, ∂D/∂cp[i,end] is evaluated in power-rule form,
# beta*(cp/salt_ref)^(beta-1)/salt_ref, rather than as beta*D/cp (0/0 = NaN).
# This form is finite only for beta >= 1.
# =========================================================================
function binding_flux_vjp_state!(Jv_cp, Jv_cs, w, model::ModulatedBinding, cp, cs, p)
    n_pts = size(cp, 1)
    n_bind = model.nBindingComps
    n_comp = model.nComp
    fill!(Jv_cp, zero(eltype(Jv_cp)))
    fill!(Jv_cs, zero(eltype(Jv_cs)))
    salt_offset = (n_comp - 1) * n_pts   # Jv_cp slots of the last column of cp
    @inbounds for i in 1:n_pts
        salt = cp[i, n_comp]
        Δsalt = salt - p.salt_ref
        salt_ratio = salt / p.salt_ref
        # q_free for this point; generic start keeps the accumulator type fixed
        qf = one(typeof(one(eltype(cs)) / one(eltype(p.qmax))))
        for k in 1:n_bind
            qf -= cs[i, k] / p.qmax[k]
        end
        for j in 1:n_bind
            w_ij = w[(j - 1) * n_pts + i]   # weight for output iso[i,j]
            E_ij = exp(Δsalt * p.gamma[j])
            D_ij = salt_ratio^p.beta[j]
            ads_base = p.ka[j] * E_ij       # common factor of the adsorption terms
            # ∂iso[i,j]/∂cp[i,j] = ka[j]*E[i,j]*qmax[j]*qf[i]
            Jv_cp[(j - 1) * n_pts + i] += ads_base * p.qmax[j] * qf * w_ij
            # ∂des[i,j]/∂cp[i,end] = kd[j]*beta[j]*(cp[i,end]/salt_ref)^(beta[j]-1)/salt_ref*cs[i,j]
            ddes_dsalt = if iszero(salt)
                p.kd[j] * p.beta[j] * salt_ratio^(p.beta[j] - 1) / p.salt_ref * cs[i, j]
            else
                # Same quantity, reusing D_ij instead of a second power.
                p.kd[j] * p.beta[j] * D_ij / salt * cs[i, j]
            end
            # ∂iso[i,j]/∂cp[i,end] = ka[j]*gamma[j]*E[i,j]*cp[i,j]*qmax[j]*qf[i] - ∂des/∂cp[i,end]
            diso_dsalt = ads_base * p.gamma[j] * cp[i, j] * p.qmax[j] * qf - ddes_dsalt
            Jv_cp[salt_offset + i] += diso_dsalt * w_ij
            # ∂iso[i,j]/∂cs[i,m]:
            #   for all binding m: -ka[j]*E[i,j]*cp[i,j]*qmax[j]/qmax[m]
            #   additionally if m == j: -kd[j]*D[i,j]
            common_cs = -ads_base * cp[i, j] * p.qmax[j]
            for m in 1:n_bind
                diso_dcs_m = common_cs / p.qmax[m]
                if m == j
                    diso_dcs_m -= p.kd[j] * D_ij
                end
                Jv_cs[(m - 1) * n_pts + i] += diso_dcs_m * w_ij
            end
        end
    end
    return nothing
end
# =========================================================================
# Analytical VJP of binding_flux w.r.t. parameters: computes J_p^T * w
# directly without forming the Jacobian matrix.
#
# Parameters: ka(n_bind), kd(n_bind), qmax(n_bind), gamma(n_bind), beta(n_bind), salt_ref(1)
#
# Jv_p must allow property access/assignment for those six fields, e.g. a
# ComponentArray shaped like `p`. It is zeroed, then accumulated over all
# points. w[(j-1)*n_points + i] is the weight of iso[i,j].
#
# Degenerate inputs are handled without NaN:
#   * ka[j] == 0:       ∂iso/∂ka[j] is formed directly (not as ads/ka = 0/0).
#   * cp[i,end] == 0:   ∂iso/∂beta[j] uses its limit 0 (valid for beta > 0)
#                       instead of log(0)*0 = NaN.
# =========================================================================
function binding_flux_vjp_params!(Jv_p, w, model::ModulatedBinding, cp, cs, p)
    n_pts = size(cp, 1)
    n_bind = model.nBindingComps
    n_comp = model.nComp
    fill!(Jv_p, zero(eltype(Jv_p)))
    @inbounds for i in 1:n_pts
        salt = cp[i, n_comp]
        Δsalt = salt - p.salt_ref
        salt_ratio = salt / p.salt_ref
        qf = one(typeof(one(eltype(cs)) / one(eltype(p.qmax))))
        for k in 1:n_bind
            qf -= cs[i, k] / p.qmax[k]
        end
        for j in 1:n_bind
            w_ij = w[(j - 1) * n_pts + i]
            E_ij = exp(Δsalt * p.gamma[j])
            D_ij = salt_ratio^p.beta[j]
            ads_full = p.ka[j] * E_ij * cp[i, j] * p.qmax[j] * qf
            # ∂iso/∂ka[j] = E[i,j]*cp[i,j]*qmax[j]*qf[i]
            Jv_p.ka[j] += (E_ij * cp[i, j] * p.qmax[j] * qf) * w_ij
            # ∂iso/∂kd[j] = -D[i,j]*cs[i,j]
            Jv_p.kd[j] += (-D_ij * cs[i, j]) * w_ij
            # ∂iso/∂qmax[m] (qf depends on every qmax[m]):
            #   for all binding m: ka[j]*E*cp[i,j]*qmax[j]*cs[i,m]/qmax[m]^2
            #   additionally if m == j: ka[j]*E*cp[i,j]*qf
            kEc = p.ka[j] * E_ij * cp[i, j]
            for m in 1:n_bind
                diso_dqmax_m = kEc * p.qmax[j] * cs[i, m] / p.qmax[m]^2
                if m == j
                    diso_dqmax_m += kEc * qf
                end
                Jv_p.qmax[m] += diso_dqmax_m * w_ij
            end
            # ∂iso/∂gamma[j] = (cp[i,end]-salt_ref) * ads
            Jv_p.gamma[j] += ads_full * Δsalt * w_ij
            # ∂iso/∂beta[j] = -kd[j]*log(cp[i,end]/salt_ref)*D[i,j]*cs[i,j]
            dbeta = iszero(salt) ? zero(D_ij * cs[i, j]) :
                -p.kd[j] * log(salt_ratio) * D_ij * cs[i, j]
            Jv_p.beta[j] += dbeta * w_ij
            # ∂iso/∂salt_ref = -gamma[j]*ads + kd[j]*beta[j]*D[i,j]/salt_ref*cs[i,j]
            Jv_p.salt_ref += (-ads_full * p.gamma[j] +
                p.kd[j] * p.beta[j] * D_ij / p.salt_ref * cs[i, j]) * w_ij
        end
    end
    return nothing
end
"""
    AxialFlowModel{BindModel}

Geometry, material data and discretisation of the axial-flow model assembled
by `jac!` and evaluated by `rhs_jac_noalloc!`.

The model is constructed positionally in `setup_problem`, so the field order
below is part of the interface.
"""
struct AxialFlowModel{BindModel}
    n_comp::Int                           # number of components
    cross_section_area::Float64
    length::Float64                       # axial length, split into n_elements equal elements
    film_diffusion_coeff::Vector{Float64} # one coefficient per component
    col_porosity::Float64
    par_porosity::Float64
    par_radius::Float64
    binding_model::BindModel
    n_elements::Int                       # number of axial elements
    n_points::Int                         # nodes per element (degree + 1 in setup_problem)
    n_dof_per_element::Int                # n_points * n_comp per element of a particle/solid phase
    n_dof_per_phase::Int                  # n_elements * n_dof_per_element
    lgl_basis::LobattoLegendre{Float64}
    D::Matrix{Float64}                    # derivative matrix (lgl_basis.D in setup_problem)
    M_inv::Matrix{Float64}                # inverse mass matrix (legendre_inv_mass in setup_problem)
end
"""
    element_start_index(phase, element, component, disc::AxialFlowModel)

Return the global index of the first node of `component` within `element`
for `phase`, one of `:liquid`, `:particle` or `:solid`.

State layout:
  * liquid: component-major (all elements of component 1, then component 2, ...);
  * particle / solid: blocks starting at offsets `n_dof_per_phase` and
    `2 * n_dof_per_phase`, each element-major (all components of element 1,
    then element 2, ...).

Each (phase, element, component) slice holds `disc.n_points` contiguous entries.

Throws `ArgumentError` for any other `phase`.
"""
function element_start_index(phase::Symbol, element::Int, component::Int, disc::AxialFlowModel)
    if phase == :liquid
        return (component - 1) * disc.n_elements * disc.n_points + (element - 1) * disc.n_points + 1
    elseif phase == :particle
        ip = 2
    elseif phase == :solid
        ip = 3
    else
        # Without this branch `ip` would be unbound and fail with an opaque UndefVarError.
        throw(ArgumentError("unknown phase :$phase (expected :liquid, :particle or :solid)"))
    end
    return (ip - 1) * disc.n_dof_per_phase + (element - 1) * disc.n_dof_per_element +
           (component - 1) * disc.n_points + 1
end
"""
    element_end_index(phase, element, component, disc::AxialFlowModel)

Return the global index of the last node of `component` within `element` for
`phase` (`:liquid`, `:particle` or `:solid`).

Every (phase, element, component) slice is `disc.n_points` contiguous entries,
so this is `element_start_index(...) + disc.n_points - 1`.

Throws `ArgumentError` for any other `phase`.
"""
function element_end_index(phase::Symbol, element::Int, component::Int, disc::AxialFlowModel)
    phase in (:liquid, :particle, :solid) ||
        throw(ArgumentError("unknown phase :$phase (expected :liquid, :particle or :solid)"))
    return element_start_index(phase, element, component, disc) + disc.n_points - 1
end
"""
    rhs_jac_noalloc!(dc, c, bind_params, jac_wo_isotherm, flow_rate, t, model,
                     inlet_buf, iso_buf, load_len, elute_len, load_conc,
                     salt_eq, salt_start)

ODE right-hand side, written into `dc`. Steps:

1. Choose the inlet vector for time `t`:
   * `t < load_len`: `load_conc`, with the last entry `salt_eq`;
   * then until `load_len + elute_len`: zeros, with the last entry `salt_start`;
   * afterwards: all zeros.
2. Compute `dc = jac_wo_isotherm * c` (the precomputed linear part).
3. Add the inlet term to the first liquid element of every component.
4. Per element, evaluate the isotherm into `iso_buf` and apply it:
   * binding components: particle entries get `-β_p * iso` and solid entries
     are set to `iso`;
   * non-binding components: solid entries are set to zero.

`inlet_buf` (length `model.n_comp`) and `iso_buf`
(`n_points × nBindingComps`) are scratch buffers and are overwritten.
"""
function rhs_jac_noalloc!(dc, c, bind_params, jac_wo_isotherm, flow_rate, t,
        model::AxialFlowModel{BindModel}, inlet_buf, iso_buf,
        load_len, elute_len, load_conc, salt_eq, salt_start) where {BindModel}
    n_feed = length(load_conc)
    if t < load_len
        copyto!(inlet_buf, 1, load_conc, 1, n_feed)
        inlet_buf[end] = salt_eq
    elseif t < load_len + elute_len
        fill!(view(inlet_buf, 1:n_feed), 0.0)
        inlet_buf[end] = salt_start
    else
        fill!(inlet_buf, 0.0)
    end

    Δz = model.length / model.n_elements
    inflow_scale = 2.0 / Δz
    u_int = flow_rate / (model.cross_section_area * model.col_porosity)

    mul!(dc, jac_wo_isotherm, c)

    # Inlet contribution: column 1 of M_inv spread over the first liquid element.
    for comp in 1:model.n_comp
        first_dof = element_start_index(:liquid, 1, comp, model)
        last_dof = element_end_index(:liquid, 1, comp, model)
        c_in = inlet_buf[comp]
        for (k, dof) in enumerate(first_dof:last_dof)
            dc[dof] += inflow_scale * model.M_inv[k, 1] * u_int * c_in
        end
    end

    β_p = (1.0 - model.par_porosity) / model.par_porosity
    n_pts = model.n_points
    n_bind = model.binding_model.nBindingComps
    for e in 1:model.n_elements
        p_range = element_start_index(:particle, e, 1, model):element_end_index(:particle, e, model.n_comp, model)
        s_range = element_start_index(:solid, e, 1, model):element_end_index(:solid, e, model.n_comp, model)
        cp_e = reshape(view(c, p_range), n_pts, model.n_comp)
        cs_e = reshape(view(c, s_range), n_pts, model.n_comp)
        binding_flux!(iso_buf, model.binding_model, cp_e, cs_e, bind_params)
        for comp in 1:model.n_comp
            p0 = element_start_index(:particle, e, comp, model) - 1
            s0 = element_start_index(:solid, e, comp, model) - 1
            if comp > n_bind
                # Non-binding component: its solid-phase state is frozen.
                @inbounds for k in 1:n_pts
                    dc[s0 + k] = 0.0
                end
                continue
            end
            @inbounds for k in 1:n_pts
                flux = iso_buf[k, comp]
                dc[p0 + k] -= β_p * flux
                dc[s0 + k] = flux
            end
        end
    end
    return nothing
end
# Isotherm Jacobian block of one element, obtained with ForwardDiff on the
# allocating `binding_flux`. It writes the nBindingComps*n_points rows
# starting at `row_offset`:
#   * ∂iso/∂cp into the columns starting at `idx_start_liquid`;
#   * ∂iso/∂cs into the columns starting at `idx_start_solid`.
# It is called from the full `jac!` both at setup (jac_prototype) and through
# the `jac` callback given to the solver, so it does run during integration.
function jac!(J::JType, model::ModulatedBinding, p::pType, row_offset,
        idx_start_liquid, idx_start_solid, cp, cs) where {JType, pType}
    n_pts = size(cp, 1)
    n_c = model.nComp
    rows = row_offset:(row_offset + model.nBindingComps * n_pts - 1)
    cp_cols = idx_start_liquid:(idx_start_liquid + n_c * n_pts - 1)
    cs_cols = idx_start_solid:(idx_start_solid + n_c * n_pts - 1)
    flux_of_cp = x -> vec(binding_flux(model, reshape(x, n_pts, n_c), cs, p))
    flux_of_cs = x -> vec(binding_flux(model, cp, reshape(x, n_pts, n_c), p))
    ForwardDiff.jacobian!(view(J, rows, cp_cols), flux_of_cp, vec(cp))
    ForwardDiff.jacobian!(view(J, rows, cs_cols), flux_of_cs, vec(cs))
    return nothing
end
# Assemble the Jacobian of rhs_jac_noalloc! into J, overwriting it.
#   with_isotherm = false: only the linear part.
#     * Liquid rows: the velocity-scaled derivative operator D, the element-
#       interface terms, and the liquid<->particle coupling.
#     * Particle rows: the particle<->liquid coupling.
#     * Solid rows stay zero.
#     `c` is read only in the isotherm branch, so this result does not depend
#     on the state.
#   with_isotherm = true: additionally, per element, the isotherm block from
#     the binding-model jac!, mirrored into the particle rows.
# `t` is not used. Called at setup (jac_wo_iso, jac0) and by the `ode_jac`
# closure that setup_problem passes as the ODEFunction's `jac`, so the
# implicit solver may call it during integration.
function jac!(J::JType, c::cType, bind_params::pType, flow_rate, t,
model::AxialFlowModel{BindModel}, with_isotherm::Bool) where {JType, cType, pType, BindModel}
velocity = flow_rate / (model.cross_section_area * model.col_porosity)
n_points = model.n_points
n_elements = model.n_elements
Δz = model.length / n_elements
two_over_Δz = 2.0 / Δz
β_c = (1.0 - model.col_porosity) / model.col_porosity
β_p = (1.0 - model.par_porosity) / model.par_porosity
# Start from zero; every contribution below is written explicitly.
J .= 0.0
@views begin
# ---- Liquid rows ----
for comp in 1:model.n_comp
bulk_pore_factor = β_c * 3.0 / model.par_radius * model.film_diffusion_coeff[comp]
for i in 1:n_elements
cur_elem_start = element_start_index(:liquid, i, comp, model)
cur_elem_end = element_end_index(:liquid, i, comp, model)
idx_start_cp = element_start_index(:particle, i, comp, model)
# Diagonal block of element i: -velocity * (2/Δz) * D.
J[cur_elem_start:cur_elem_end, cur_elem_start:cur_elem_end] .= model.D
J[cur_elem_start:cur_elem_end, cur_elem_start:cur_elem_end] .*= two_over_Δz .* (-velocity)
# Node-wise coupling to the particle node at the same position.
for j in 0:(n_points - 1)
J[cur_elem_start + j, cur_elem_start + j] -= bulk_pore_factor
J[cur_elem_start + j, idx_start_cp + j] = bulk_pore_factor
end
# Interface terms built from M_inv columns:
#   * the element's first node gets -v*(2/Δz)*M_inv[:, 1];
#   * for i > 1, the previous element's last node gets +v*(2/Δz)*M_inv[:, 1].
# For i == 1 the latter coupling is the inlet term that rhs_jac_noalloc!
# adds explicitly.
J[cur_elem_start:cur_elem_end, cur_elem_start] .+= -velocity * two_over_Δz * model.M_inv[:, 1]
# NOTE(review): this `.+=` on the last-node column and the `.-=` of the same
# quantity after the `if` below cancel (up to rounding), so together they
# contribute nothing. Confirm this is intentional.
J[cur_elem_start:cur_elem_end, cur_elem_end] .+= velocity * two_over_Δz * model.M_inv[:, end]
if i > 1
J[cur_elem_start:cur_elem_end, element_end_index(:liquid, i - 1, comp, model)] .+= velocity * two_over_Δz * model.M_inv[:, 1]
end
J[cur_elem_start:cur_elem_end, cur_elem_end] .-= velocity * two_over_Δz * model.M_inv[:, end]
end
end
# ---- Particle rows ----
# Each particle node couples to itself (-) and to the liquid node at the same
# position (+), scaled by par_factor * film_diffusion_coeff[comp].
par_factor = 3.0 / (model.par_radius * model.par_porosity)
for comp in 1:model.n_comp
for i in 1:n_elements
idx_start_comp_particle = element_start_index(:particle, i, comp, model)
idx_start_comp_liquid = element_start_index(:liquid, i, comp, model)
for j in 0:(n_points - 1)
J[idx_start_comp_particle + j, idx_start_comp_particle + j] = -par_factor * model.film_diffusion_coeff[comp]
J[idx_start_comp_particle + j, idx_start_comp_liquid + j] = par_factor * model.film_diffusion_coeff[comp]
end
end
end
# ---- Isotherm (state-dependent) ----
if with_isotherm
for i in 1:n_elements
# NOTE: despite the name, idx_start_liquid/idx_end_liquid span the PARTICLE
# block of element i (all components).
idx_start_liquid = element_start_index(:particle, i, 1, model)
idx_end_liquid = element_end_index(:particle, i, model.n_comp, model)
idx_start_solid = element_start_index(:solid, i, 1, model)
idx_end_solid = element_end_index(:solid, i, model.n_comp, model)
c_element = reshape(view(c, idx_start_liquid:idx_end_liquid), n_points, model.n_comp)
cs_element = reshape(view(c, idx_start_solid:idx_end_solid), n_points, model.n_comp)
# Solid rows of the binding components <- ∂iso/∂(cp, cs).
jac!(J, model.binding_model, bind_params, idx_start_solid, idx_start_liquid, idx_start_solid, c_element, cs_element)
# Particle rows -= β_p * solid rows, mirroring `dc_particle -= β_p * iso`
# in rhs_jac_noalloc!. `β_p * J[...]` allocates a temporary.
J[idx_start_liquid:idx_end_liquid, :] .-= β_p * J[idx_start_solid:idx_end_solid, :]
end
end
end
return nothing
end
"""
    init_state!(u0, conc_liquid, conc_solid, model::AxialFlowModel)

Write a spatially uniform initial state into `u0`.

For every element and component `comp`:
  * the liquid and particle nodes are set to `conc_liquid[comp]`;
  * the solid nodes are set to `conc_solid[comp]`.
"""
function init_state!(u0, conc_liquid, conc_solid, model::AxialFlowModel)
    for e in 1:model.n_elements, comp in 1:model.n_comp
        targets = ((:liquid, conc_liquid[comp]),
                   (:particle, conc_liquid[comp]),
                   (:solid, conc_solid[comp]))
        for (phase, value) in targets
            span = element_start_index(phase, e, comp, model):element_end_index(phase, e, comp, model)
            view(u0, span) .= value
        end
    end
    return nothing
end
# Build the benchmark problem: 3 elements of degree 1 and 4 components
# (ModulatedBinding(4, 3): the first 3 bind), the isotherm parameters, and an
# ODEFunction that carries hand-written `jac`, `jvp`, `vjp` and `vjp_p`
# callbacks.
# Returns (prob, bind_params, save_idxs, t_stop):
#   prob        - ODEProblem over tspan. It is built without parameters, so
#                 callers supply bind_params via remake (as loss_fn does).
#   bind_params - ComponentArray with fields ka, kd, qmax, gamma, beta, salt_ref.
#   save_idxs   - per component, the last liquid node of the last element
#                 (presumably the column outlet; confirm).
#   t_stop      - the inlet switching times, intended as tstops.
# NOTE(review): the jac/vjp/jvp/vjp_p closures mutate buffers captured from
# this scope, so concurrent calls on the same problem (e.g. threaded
# ensembles) would race.
function setup_problem()
n_elements = 3
n_degree = 1
bind_params = ComponentArray(
ka = [4.0, 5.5, 3.0] .* 1e-2,
kd = [3.2, 22.0, 9.3] .* 1e-3,
qmax = [3.0, 2.0, 6.5] .* 10.0,
gamma = [-1.0, -0.5, 0.2],
beta = [0.91, 0.82, 1.3],
salt_ref = 1.0
)
lgl_basis = LobattoLegendre(n_degree)
M_inv = legendre_inv_mass(lgl_basis)
# Positional AxialFlowModel fields:
#   n_comp=4, cross_section_area, length, film_diffusion_coeff (per component),
#   col_porosity, par_porosity, par_radius, binding_model, n_elements,
#   n_points, n_dof_per_element, n_dof_per_phase, lgl_basis, D, M_inv
model = AxialFlowModel(
4, 1.0 / 0.37, 0.014, fill(6.9e-6, 4), 0.37, 0.75, 45e-6,
ModulatedBinding(4, 3),
n_elements, n_degree + 1, (n_degree + 1) * 4, n_elements * (n_degree + 1) * 4, lgl_basis, lgl_basis.D, M_inv
)
save_idxs = [element_end_index(:liquid, n_elements, i, model) for i in 1:model.n_comp]
# Inlet schedule consumed by rhs_jac_noalloc!: load_conc with salt_eq until
# load_len, then salt_start until load_len + elute_len.
salt_eq = 1.0
salt_start = 1.5
load_len = 10.0
elute_len = 1410.0
load_conc = [1.0, 1.0, 1.0]
t_stop = [load_len, load_len + elute_len]
tspan = (0.0, load_len + elute_len)
# 3 phases (liquid, particle, solid) × n_points × n_elements × n_comp.
num_dofs = model.n_points * model.n_elements * model.n_comp * 3
u0 = zeros(Float64, num_dofs)
# Zero everywhere except the last component, which starts at salt_eq.
init_state!(u0, [0.0, 0.0, 0.0, salt_eq], [0.0, 0.0, 0.0, salt_eq], model)
flow_rate = 3.45 / 60 / 100
inlet_buf = zeros(Float64, model.n_comp)
iso_buf = zeros(Float64, model.n_points, model.binding_model.nBindingComps)
# Linear (isotherm-free) Jacobian. jac! reads the state only in its isotherm
# branch, so the random state has no effect here. Its transpose is stored
# once for vjp_fn.
jac_wo_iso = spzeros(length(u0), length(u0))
jac!(jac_wo_iso, rand(Float64, length(u0)), bind_params, flow_rate, 0.0, model, false)
jac_wo_iso_T = sparse(jac_wo_iso')
# Full Jacobian at a random state, used as jac_prototype (sparsity pattern).
# Presumably the state is random so that no isotherm entry is exactly zero
# and missing from the pattern.
jac0 = spzeros(length(u0), length(u0))
jac!(jac0, rand(Float64, length(u0)), bind_params, flow_rate, 0.0, model, true)
# `jac` callback: rebuilds the full Jacobian, including the ForwardDiff
# isotherm blocks, whenever the solver asks for it.
ode_jac = let model = model
function (J, u, p, t)
jac!(J, u, p, flow_rate, t, model, true)
return nothing
end
end
n_pts = model.n_points
n_cmp = model.n_comp
n_bnd = model.binding_model.nBindingComps
n_bnd_pts = n_bnd * n_pts
n_cmp_pts = n_cmp * n_pts
β_p = (1.0 - model.par_porosity) / model.par_porosity
# Pre-allocate VJP buffers
w_buf = zeros(n_bnd_pts)
Jv_cp_buf = zeros(n_cmp_pts)
Jv_cs_buf = zeros(n_cmp_pts)
# vjp: Jv = J(u, p)^T * v, with J = jac_wo_iso + the isotherm part.
# rhs_jac_noalloc! sets the binding solid rows to iso (jac_wo_iso is zero
# there) and subtracts β_p * iso from the binding particle rows. Hence, with
# C = ∂iso/∂(cp, cs) per element:
#   J^T v = jac_wo_iso^T v + C^T (v_solid - β_p * v_particle).
vjp_fn = let jac_wo_iso_T = jac_wo_iso_T, w_buf = w_buf,
Jv_cp_buf = Jv_cp_buf, Jv_cs_buf = Jv_cs_buf,
model = model, β_p = β_p,
n_pts = n_pts, n_cmp = n_cmp, n_bnd_pts = n_bnd_pts
function (Jv, v, u, p, t)
# Constant part
mul!(Jv, jac_wo_iso_T, v)
# Isotherm correction per element (analytical, zero-allocation)
for i in 1:model.n_elements
cp_s = element_start_index(:particle, i, 1, model)
cp_e = element_end_index(:particle, i, n_cmp, model)
cs_s = element_start_index(:solid, i, 1, model)
cs_e = element_end_index(:solid, i, n_cmp, model)
# w = v[solid_binding] - β_p * v[particle_binding]
@inbounds for k in 1:n_bnd_pts
w_buf[k] = v[cs_s + k - 1] - β_p * v[cp_s + k - 1]
end
cp_mat = reshape(view(u, cp_s:cp_e), n_pts, n_cmp)
cs_mat = reshape(view(u, cs_s:cs_e), n_pts, n_cmp)
# Analytical J_cp^T * w and J_cs^T * w
binding_flux_vjp_state!(Jv_cp_buf, Jv_cs_buf, w_buf,
model.binding_model, cp_mat, cs_mat, p)
# Accumulate into Jv
@inbounds for k in 1:(n_cmp * n_pts)
Jv[cp_s + k - 1] += Jv_cp_buf[k]
Jv[cs_s + k - 1] += Jv_cs_buf[k]
end
end
return nothing
end
end
# vjp_p: out = (∂f/∂p)^T λ. In rhs_jac_noalloc!, p enters only through
# binding_flux!, so the same per-element weight w = λ_solid - β_p * λ_particle
# is pushed through binding_flux_vjp_params! and the results are summed.
# Pre-allocate vjp_p buffers
w_p_buf = zeros(n_bnd_pts)
Jv_p_elem = ComponentArray(
ka = zeros(n_bnd), kd = zeros(n_bnd), qmax = zeros(n_bnd),
gamma = zeros(n_bnd), beta = zeros(n_bnd), salt_ref = 0.0)
vjp_p_fn = let w_p_buf = w_p_buf, Jv_p_elem = Jv_p_elem,
model = model, β_p = β_p,
n_pts = n_pts, n_cmp = n_cmp, n_bnd_pts = n_bnd_pts
function (out, λ, y, p, t)
fill!(out, 0.0)
for i in 1:model.n_elements
cp_s = element_start_index(:particle, i, 1, model)
cp_e = element_end_index(:particle, i, n_cmp, model)
cs_s = element_start_index(:solid, i, 1, model)
cs_e = element_end_index(:solid, i, n_cmp, model)
@inbounds for k in 1:n_bnd_pts
w_p_buf[k] = λ[cs_s + k - 1] - β_p * λ[cp_s + k - 1]
end
cp_mat = reshape(view(y, cp_s:cp_e), n_pts, n_cmp)
cs_mat = reshape(view(y, cs_s:cs_e), n_pts, n_cmp)
# Analytical (∂binding_flux/∂p)^T * w
binding_flux_vjp_params!(Jv_p_elem, w_p_buf,
model.binding_model, cp_mat, cs_mat, p)
out .+= Jv_p_elem
end
return out
end
end
# JVP: J * v = jac_wo_iso * v + C_iso * v (analytical)
iso_effect_buf = zeros(n_bnd_pts)
jvp_fn = let jac_wo_iso = jac_wo_iso, model = model, β_p = β_p,
iso_effect_buf = iso_effect_buf,
n_pts = n_pts, n_cmp = n_cmp, n_bnd_pts = n_bnd_pts
function (Jv, v, u, p, t)
mul!(Jv, jac_wo_iso, v)
# Forward counterpart of binding_flux_vjp_state!: per element,
# iso_effect = J_cp*v_cp + J_cs*v_cs is evaluated analytically point by
# point, without forming the Jacobian blocks.
# NOTE(review): the salt term divides by cp_mat[ii, n_cmp], which is 0/0
# at a salt value of exactly zero. Confirm that state is never reached.
for i in 1:model.n_elements
cp_s = element_start_index(:particle, i, 1, model)
cp_e = element_end_index(:particle, i, n_cmp, model)
cs_s = element_start_index(:solid, i, 1, model)
cs_e = element_end_index(:solid, i, n_cmp, model)
cp_mat = reshape(view(u, cp_s:cp_e), n_pts, n_cmp)
cs_mat = reshape(view(u, cs_s:cs_e), n_pts, n_cmp)
v_cp = view(v, cp_s:cp_e)
v_cs = view(v, cs_s:cs_e)
n_bind = model.binding_model.nBindingComps
fill!(iso_effect_buf, 0.0)
@inbounds for ii in 1:n_pts
qf = 1.0
for k in 1:n_bind
qf -= cs_mat[ii, k] / p.qmax[k]
end
for j in 1:n_bind
E_ij = exp((cp_mat[ii, n_cmp] - p.salt_ref) * p.gamma[j])
D_ij = (cp_mat[ii, n_cmp] / p.salt_ref)^p.beta[j]
ads_base = p.ka[j] * E_ij
# Contribution from v_cp
eff = ads_base * p.qmax[j] * qf * v_cp[(j - 1) * n_pts + ii]
eff += (ads_base * p.gamma[j] * cp_mat[ii, j] * p.qmax[j] * qf -
p.kd[j] * p.beta[j] * D_ij / cp_mat[ii, n_cmp] * cs_mat[ii, j]) *
v_cp[(n_cmp - 1) * n_pts + ii]
# Contribution from v_cs
common_cs = -ads_base * cp_mat[ii, j] * p.qmax[j]
for m in 1:n_bind
diso_dcs_m = common_cs / p.qmax[m]
if m == j
diso_dcs_m -= p.kd[j] * D_ij
end
eff += diso_dcs_m * v_cs[(m - 1) * n_pts + ii]
end
iso_effect_buf[(j - 1) * n_pts + ii] = eff
end
end
# solid_binding += iso_effect, particle_binding -= β_p * iso_effect
@inbounds for k in 1:n_bnd_pts
Jv[cs_s + k - 1] += iso_effect_buf[k]
Jv[cp_s + k - 1] -= β_p * iso_effect_buf[k]
end
end
return nothing
end
end
# ODE right-hand side plus all derivative hooks; jac0 fixes the sparsity pattern.
fun = ODE.ODEFunction(
(du, u, pp, t) -> rhs_jac_noalloc!(du, u, pp, jac_wo_iso, flow_rate, t, model,
inlet_buf, iso_buf, load_len, elute_len, load_conc, salt_eq, salt_start);
jac_prototype = jac0, vjp = vjp_fn, jvp = jvp_fn, jac = ode_jac,
vjp_p = vjp_p_fn)
prob = ODE.ODEProblem(fun, u0, tspan)
return prob, bind_params, save_idxs, t_stop
end
prob, bind_params, save_idxs, t_stop = setup_problem()
# =====================================================================
# Benchmark: two forward solves (the first includes compilation), then
# three reverse-mode gradients (again, the first includes compilation).
# =====================================================================
foreach(println, ("=" ^ 70, "Test: FBDF + GaussAdjoint + analytical VJP + vjp_p (zero alloc)", "=" ^ 70))
alg = ODE.FBDF(autodiff = false, linsolve = UMFPACKFactorization())
sense_alg = SMS.GaussAdjoint(autojacvec = SMS.EnzymeVJP())
# Loss: over every saved time point, sum the saved values of all components
# except the last one in save_idxs.
loss_fn = let prob = prob, save_idxs = save_idxs, sense_alg = sense_alg, t_stop = t_stop, alg = alg
    function (p)
        sol = ODE.solve(ODE.remake(prob, p = p), alg;
            abstol = 1e-8, reltol = 1e-6, tstops = t_stop, saveat = 1.0,
            maxiters = 1e6, sensealg = sense_alg)
        total = 0.0
        for k in eachindex(sol.t)
            @views total += sum(sol.u[k][save_idxs[1:(end - 1)]])
        end
        return total
    end
end
println("\nForward solve:")
@time loss_val = loss_fn(bind_params)
@time loss_val = loss_fn(bind_params)
println(" Loss value: $loss_val")
println("\nGradient — 1st call (compile):")
@time grad1 = Zygote.gradient(loss_fn, bind_params)
println("\nGradient — 2nd call:")
@time grad2 = Zygote.gradient(loss_fn, bind_params)
println("\nGradient — 3rd call:")
@time grad3 = Zygote.gradient(loss_fn, bind_params)
println(" Gradient: $(grad3[1])")
# Correctness check: compare the last gradient, field by field, with stored
# reference values (recorded from an earlier ForwardDiff-based VJP run).
println("\n--- Correctness check ---")
expected = (ka = [-148.98160180462295, 4.682207747927423, -272.7606465430993],
    kd = [1839.0723272591213, -9.160074574142724, 883.2396698989843],
    qmax = [-0.2185663494129382, -0.0016972984600429248, -0.13072785804121395],
    gamma = [-2.716768296358113, 0.0697274702316863, -3.8895439826353875],
    beta = [2.383488922767005, -0.07674961659760442, 3.325315684679378],
    salt_ref = -20.062437960598114)
max_rel = maximum(keys(expected)) do field
    ref = getfield(expected, field)
    got = getproperty(grad3[1], field)
    maximum(abs.(ref .- got) ./ max.(abs.(ref), 1e-15))
end
println(" Max relative difference vs reference: $max_rel")
println(max_rel < 1e-4 ? " PASSED ✓" : " FAILED ✗")