Finding mutations that Zygote complains about

Hey,

Apparently, I have some hidden mutations in my code that I cannot find. Unfortunately, Zygote easily finds them and complains about them.
I get call stacks 40+ layers deep, the last of my own functions being near the top (layer 6):

  [1] error(s::String)
    @ Base ./error.jl:44
  [2] _throw_mutation_error(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/lib/array.jl:70
  [3] (::Zygote.var"#713#714"{Vector{Float64}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/lib/array.jl:85
  [4] (::Zygote.var"#715#716"{Zygote.var"#713#714"{Vector{Float64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
  [5] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:100
  [6] rhs_jac!
    @ ~/Projects/NODE/src/Model.jl:441 [inlined]
  [7] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
  [8] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:100
  [9] (::Zygote.Pullback{Tuple{…}, Any})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:100
 [10] #_vecjacobian!##16
    @ ~/.julia/packages/SciMLSensitivity/3s0Zy/src/derivative_wrappers.jl:659 [inlined]
 [11] (::Zygote.Pullback{Tuple{…}, Any})(Δ::SubArray{Float64, 1, Vector{…}, Tuple{…}, true})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#pullback##0#pullback##1"{Zygote.Pullback{…}})(Δ::SubArray{Float64, 1, Vector{…}, Tuple{…}, true})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:97
 [13] _vecjacobian!(dλ::SubArray{…}, y::Vector{…}, λ::SubArray{…}, p::ComponentArrays.ComponentVector{…}, t::Float64, S::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{…}, isautojacvec::SciMLSensitivity.ZygoteVJP, dgrad::SubArray{…}, dy::Nothing, W::Nothing)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/3s0Zy/src/derivative_wrappers.jl:630
 [14] #vecjacobian!#19
    @ ~/.julia/packages/SciMLSensitivity/3s0Zy/src/derivative_wrappers.jl:257 [inlined]
 [15] vecjacobian!
    @ ~/.julia/packages/SciMLSensitivity/3s0Zy/src/derivative_wrappers.jl:252 [inlined]
 [16] (::SciMLSensitivity.ODEInterpolatingAdjointSensitivityFunction{…})(du::Vector{…}, u::Vector{…}, p::ComponentArrays.ComponentVector{…}, t::Float64)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/3s0Zy/src/interpolating_adjoint.jl:153
...

Background: This is part of a Neural ODE project, where I try to use adjoints for gradient computation.

How can I systematically debug / work through the code to find those mutations?

Thanks,
Neodym

It’s this function right here: rhs_jac! (frame [6] in your stack trace, Model.jl:441). You can pull it out and call the vjp on it directly to reproduce the error outside of the SciMLSensitivity context, which makes it easier to debug.

But, I’d recommend just trying MooncakeVJP or EnzymeVJP these days. That should just be faster too. Are you directly setting ZygoteVJP?
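
To make the "pull it out and call the vjp directly" suggestion concrete, here is a minimal sketch. It does not use the real rhs_jac!; toy_rhs!, toy_rhs and toy_rhs_oop are hypothetical stand-ins, and the only assumption is that Zygote is installed. The idea is to wrap the in-place function so that it returns its output, call Zygote.pullback on the wrapper, and then shrink the wrapped code until the mutation error disappears.

```julia
import Zygote

# Hypothetical in-place right-hand side, standing in for the real `rhs_jac!`.
function toy_rhs!(du, u, p)
    du .= p .* u          # in-place write: the kind of mutation Zygote rejects
    return nothing
end

# Out-of-place wrapper, so the result is a return value Zygote can pull back through.
function toy_rhs(u, p)
    du = zero(u)
    toy_rhs!(du, u, p)
    return du
end

u = [1.0, 2.0]
p = [3.0, 4.0]
λ = [1.0, 1.0]            # adjoint (cotangent) vector, same shape as du

# Call the vjp directly and capture the error instead of letting it propagate.
mutation_error = try
    _, back = Zygote.pullback(toy_rhs, u, p)
    back(λ)
    nothing
catch err
    err
end
println(mutation_error)

# The same computation written without mutation differentiates fine.
toy_rhs_oop(u, p) = p .* u
_, back_oop = Zygote.pullback(toy_rhs_oop, u, p)
ū, p̄ = back_oop(λ)
```

With a reproducer like this, the stack trace is only a few frames deep, and you can replace one block of the function at a time with a non-mutating version until the call to back succeeds.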

I’ve tried MooncakeVJP, but this gives me a stacktrace without any of my functions after the adjoint has started:

ERROR: MethodError: no method matching length(::Mooncake.Tangent{@NamedTuple{data::Vector{Float64}, axes::Mooncake.NoTangent}})
The function `length` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  length(::Compiler.DFSTree)
   @ Base ../usr/share/julia/Compiler/src/ssair/domtree.jl:121
  length(::Distributions.VonMisesFisherSampler)
   @ Distributions ~/.julia/packages/Distributions/xMnxM/src/samplers/vonmisesfisher.jl:34
  length(::Markdown.MD)
   @ Markdown ~/.julia/juliaup/julia-1.12.5+0.aarch64.apple.darwin14/Julia-1.12.app/Contents/Resources/julia/share/julia/stdlib/v1.12/Markdown/src/parse/parse.jl:35
  ...

Stacktrace:
  [1] _similar_shape(itr::Mooncake.Tangent{@NamedTuple{data::Vector{Float64}, axes::Mooncake.NoTangent}}, ::Base.HasLength)
    @ Base ./array.jl:657
  [2] _collect(cont::UnitRange{Int64}, itr::Mooncake.Tangent{@NamedTuple{…}}, ::Base.HasEltype, isz::Base.HasLength)
    @ Base ./array.jl:734
  [3] collect(itr::Mooncake.Tangent{@NamedTuple{data::Vector{Float64}, axes::Mooncake.NoTangent}})
    @ Base ./array.jl:728
  [4] broadcastable(x::Mooncake.Tangent{@NamedTuple{data::Vector{Float64}, axes::Mooncake.NoTangent}})
    @ Base.Broadcast ./broadcast.jl:733
  [5] broadcasted
    @ ./broadcast.jl:1345 [inlined]
  [6] vec_pjac!(out::ComponentVector{…}, λ::Vector{…}, y::Vector{…}, t::Float64, S::SciMLSensitivity.AdjointSensitivityIntegrand{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/3s0Zy/src/quadrature_adjoint.jl:332
  [7] (::SciMLSensitivity.AdjointSensitivityIntegrand{…})(t::Float64)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/3s0Zy/src/quadrature_adjoint.jl:399
  [8] evalrule(f::SciMLSensitivity.AdjointSensitivityIntegrand{…}, a::Float64, b::Float64, x::Vector{…}, w::Vector{…}, wg::Vector{…}, nrm::typeof(LinearAlgebra.norm))
    @ QuadGK ~/.julia/packages/QuadGK/7rND3/src/evalrule.jl:25
  [9] #do_quadgk##4
    @ ~/.julia/packages/QuadGK/7rND3/src/adapt.jl:54 [inlined]
 [10] ntuple
    @ ./ntuple.jl:50 [inlined]
 [11] do_quadgk(f::SciMLSensitivity.AdjointSensitivityIntegrand{…}, s::Tuple{…}, n::Int64, atol::Float64, rtol::Float64, maxevals::Int64, nrm::typeof(LinearAlgebra.norm), _segbuf::Nothing, eval_segbuf::Nothing)
    @ QuadGK ~/.julia/packages/QuadGK/7rND3/src/adapt.jl:52
 [12] #28
    @ ~/.julia/packages/QuadGK/7rND3/src/api.jl:83 [inlined]
 [13] handle_infinities(workfunc::QuadGK.var"#28#29"{…}, f::SciMLSensitivity.AdjointSensitivityIntegrand{…}, s::Tuple{…})
    @ QuadGK ~/.julia/packages/QuadGK/7rND3/src/adapt.jl:189
 [14] #quadgk#26
    @ ~/.julia/packages/QuadGK/7rND3/src/api.jl:82 [inlined]
 [15] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::SciMLSensitivity.QuadratureAdjoint{…}, alg::OrdinaryDiffEqBDF.QNDF{…}; t::StepRangeLen{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, no_start::Bool, abstol::Float64, reltol::Float64, callback::Nothing, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/3s0Zy/src/quadrature_adjoint.jl:477
...

After fixing a Union type, EnzymeVJP surprisingly works.
Thanks for the suggestion!

However, the sensitivities take quite a long time. This is actually a spatial discretization of a PDE, which results in a Neural ODE. Whereas the forward simulation takes 0.2 - 0.6 s (depending on the discretization level), QuadratureAdjoint with EnzymeVJP needs 90 s for just 16 parameters on the coarsest discretization.

Does Enzyme exploit the sparsity of the PDE discretization? In my case, a lot of the Jacobian is unaffected by the state (the spatial scheme is mostly linear). I’ve tried providing a custom vjp in the ODEFunction, but it seems to be ignored.

Do you have ideas / suggestions to speed this up?

Thanks,
Neodym
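
For readers following along, this is roughly what "exploiting the sparsity" would mean for a hand-written vjp. It is a sketch on a made-up problem, not the model from this thread: toy_f, g, toy_vjp! and the tridiagonal matrix A are all hypothetical. The right-hand side is split into a sparse linear part A*u and a pointwise nonlinear part g(u), so the vector-Jacobian product is a sparse transposed mat-vec plus an elementwise term, and no dense Jacobian is formed.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical split right-hand side f(u) = A*u + g(u) with a sparse linear part.
g(u) = -u .^ 2                      # pointwise nonlinearity, its Jacobian is diagonal
toy_f(A, u) = A * u + g(u)

# Computes out = (∂f/∂u)' * λ without forming a dense Jacobian.
function toy_vjp!(out, A, λ, u)
    mul!(out, A', λ)                # sparse transposed mat-vec for the linear part
    out .+= (-2 .* u) .* λ          # diagonal Jacobian of g, applied elementwise
    return out
end

n = 6
A = spdiagm(-1 => fill(1.0, n - 1), 0 => fill(-2.0, n))
u = collect(range(0.1, 0.6; length=n))
λ = ones(n)

out = toy_vjp!(zeros(n), A, λ, u)

# Dense reference, only to check the sketch on this small example.
J_dense = Matrix(A) + Diagonal(-2 .* u)
println(maximum(abs, out - J_dense' * λ))
```

Whether an AD backend recovers this structure on its own depends on how the right-hand side is written, which is why a manual vjp can be attractive when most of the Jacobian is constant.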

It should. Can I get some code to look at?

The Mooncake error looks like something that could be fixed with the new “friendly tangents” mode introduced in v0.5?

What is this new “friendly tangents” mode?

This is a (highly) simplified version of the code I’m using, but it is still slow.

# ] add OrdinaryDiffEq, SciMLSensitivity, LinearAlgebra, PolynomialBases, ComponentArrays, SparseArrays, ForwardDiff, Zygote, Mooncake, Enzyme, LinearSolve
import OrdinaryDiffEq as ODE
import SciMLSensitivity as SMS
using LinearAlgebra, LinearSolve
using PolynomialBases, ComponentArrays, SparseArrays
import ForwardDiff
import Zygote

function legendre_inv_mass(b::NodalBasis{T}) where {T}
    V = legendre_vandermonde(b)
    for n in axes(V, 2)  # n indexes columns of V (V is square here, so the range is unchanged)
        V[:, n] .= @view(V[:, n]) * sqrt((2*n - 1) / 2)
    end
    return V * V'
end

struct AxialFlowModel{BindModel}
    n_comp::Int
    cross_section_area::Float64
    length::Float64
    film_diffusion_coeff::Vector{Float64}
    col_porosity::Float64
    par_porosity::Float64
    par_radius::Float64
    binding_model::BindModel

    n_elements::Int
    n_points::Int
    n_dof_per_element::Int
    n_dof_per_phase::Int
    lgl_basis::LobattoLegendre{Float64}
    D::Array{Float64, 2}
    M_inv::Array{Float64, 2}
end

struct ModulatedBinding
    nComp::Int
    nBindingComps::Int
end

function binding_flux(model::ModulatedBinding, cp::cpType, cs::csType, p::pType) where {cpType, csType, pType}
    # cp, cs = [points, comp]
    # res = [points, bind_comp]
    @views begin
        q_free = 1.0 .- sum(cs[:, 1:model.nBindingComps] ./ reshape(p.qmax, 1, :), dims=2)
        ads = reshape(p.ka, 1, :) .* exp.((cp[:, end] .- p.salt_ref) .* reshape(p.gamma, 1, :)) .* cp[:, 1:model.nBindingComps] .* reshape(p.qmax, 1, :) .* q_free
        des = reshape(p.kd, 1, :) .* (cp[:, end] ./ p.salt_ref).^reshape(p.beta, 1, :) .* cs[:, 1:model.nBindingComps]
        return ads .- des
    end
end

function jac!(J::JType, model::ModulatedBinding, p::pType, row_offset, idx_start_liquid, idx_start_solid, cp, cs) where {JType, pType}
    # cp, cs = [points, comp]
    n_points = size(cp, 1)
    n_bind = model.nBindingComps
    n_comp = model.nComp
    ForwardDiff.jacobian!(view(J, row_offset:row_offset+n_bind*n_points-1, idx_start_liquid:idx_start_liquid+n_comp*n_points-1), cp -> reshape(binding_flux(model, reshape(cp, n_points, n_comp), cs, p), :), reshape(cp, :))
    ForwardDiff.jacobian!(view(J, row_offset:row_offset+n_bind*n_points-1, idx_start_solid:idx_start_solid+n_comp*n_points-1), cs -> reshape(binding_flux(model, cp, reshape(cs, n_points, n_comp), p), :), reshape(cs, :))
    nothing
end

function element_start_index(phase::Symbol, element::Int, component::Int, disc::AxialFlowModel)
    if phase == :liquid
        return (component - 1) * disc.n_elements * disc.n_points + (element - 1) * disc.n_points + 1
    elseif phase == :particle
        ip = 2
    elseif phase == :solid
        ip = 3
    end
    return (ip - 1) * disc.n_dof_per_phase + (element - 1) * disc.n_dof_per_element + (component - 1) * disc.n_points + 1
end

function element_end_index(phase::Symbol, element::Int, component::Int, disc::AxialFlowModel)
    if phase == :liquid
        return (component - 1) * disc.n_elements * disc.n_points + element * disc.n_points
    elseif phase == :particle
        ip = 2
    elseif phase == :solid
        ip = 3
    end
    return (ip - 1) * disc.n_dof_per_phase + (element - 1) * disc.n_dof_per_element + component * disc.n_points
end

function rhs_jac!(dc::dcType, 
            c::cType, 
            bind_params::pType, 
            jac_wo_isotherm::JacType,
            flow_rate::flowType,
            t::tType,
            model::AxialFlowModel{BindModel}, 
            inlet::Vector{Float64},
            )  where {pType, JacType, tType<:Real, dcType, cType, flowType, BindModel}

    Δz = model.length / model.n_elements
    two_over_Δz = 2.0 / Δz

    velocity = flow_rate / (model.cross_section_area * model.col_porosity)

    n_points = model.n_points
    n_elements = model.n_elements

    mul!(dc, jac_wo_isotherm, c)
    #dc .= jac_wo_isotherm * c
    #dc[:] = jac_wo_isotherm * c

    @views for comp = 1:model.n_comp
        idx_start = element_start_index(:liquid, 1, comp, model)
        idx_end = element_end_index(:liquid, 1, comp, model)
        dc[idx_start:idx_end] .+= two_over_Δz .* model.M_inv[:, 1] .* velocity .* inlet[comp]
    end

    β_p = (1.0 - model.par_porosity) / model.par_porosity
    for i = 1:n_elements
        idx_start_particle = element_start_index(:particle, i, 1, model)
        idx_end_particle = element_end_index(:particle, i, model.n_comp, model)
        idx_start_solid = element_start_index(:solid, i, 1, model)
        idx_end_solid = element_end_index(:solid, i, model.n_comp, model)

        cp_element = reshape(view(c, idx_start_particle:idx_end_particle), n_points, model.n_comp)
        cs_element = reshape(view(c, idx_start_solid:idx_end_solid), n_points, model.n_comp)

        iso = binding_flux(model.binding_model, cp_element, cs_element, bind_params)

        for comp = 1:model.n_comp
            idx_start_comp_particle = element_start_index(:particle, i, comp, model) - 1
            idx_start_comp_solid = element_start_index(:solid, i, comp, model) - 1

            for j = 1:n_points
                loc_iso = comp > model.binding_model.nBindingComps ? 0.0 : iso[j, comp]

                dc[idx_start_comp_particle + j] -= β_p * loc_iso
                dc[idx_start_comp_solid + j] = loc_iso
            end
        end
    end
    nothing
end

function jac!(J::JType, c::cType, bind_params::pType, flow_rate, t, model::AxialFlowModel{BindModel}, with_isotherm::Bool) where {JType, cType, pType, BindModel}
    velocity = flow_rate / (model.cross_section_area * model.col_porosity)
    n_points = model.n_points
    n_elements = model.n_elements
    Δz = model.length / n_elements
    two_over_Δz = 2.0 / Δz

    β_c = (1.0 - model.col_porosity) / model.col_porosity
    β_p = (1.0 - model.par_porosity) / model.par_porosity

    J .= 0.0

    @views begin
        for comp = 1:model.n_comp
            bulk_pore_factor = β_c * 3.0 / model.par_radius * model.film_diffusion_coeff[comp]
            for i = 1:n_elements
                cur_elem_start = element_start_index(:liquid, i, comp, model)
                cur_elem_end = element_end_index(:liquid, i, comp, model)
                idx_start_cp = element_start_index(:particle, i, comp, model)

                J[cur_elem_start:cur_elem_end, cur_elem_start:cur_elem_end] .= model.D
                J[cur_elem_start:cur_elem_end, cur_elem_start:cur_elem_end] .*= two_over_Δz .* (-velocity)

                for j = 0:n_points-1
                    J[cur_elem_start + j, cur_elem_start + j] -= bulk_pore_factor
                    J[cur_elem_start + j, idx_start_cp + j] = bulk_pore_factor
                end

                J[cur_elem_start:cur_elem_end, cur_elem_start] .+= -velocity * two_over_Δz * model.M_inv[:, 1]
                J[cur_elem_start:cur_elem_end, cur_elem_end] .+= velocity * two_over_Δz * model.M_inv[:, end]

                if i > 1
                    J[cur_elem_start:cur_elem_end, element_end_index(:liquid, i-1, comp, model)] .+= velocity * two_over_Δz * model.M_inv[:, 1]
                end

                J[cur_elem_start:cur_elem_end, cur_elem_end] .-= velocity * two_over_Δz * model.M_inv[:, end]
            end
        end

        par_factor = 3.0 / (model.par_radius * model.par_porosity)
        for comp = 1:model.n_comp
            for i = 1:n_elements
                idx_start_comp_particle = element_start_index(:particle, i, comp, model)
                idx_start_comp_liquid = element_start_index(:liquid, i, comp, model)

                for j = 0:n_points-1
                    J[idx_start_comp_particle + j, idx_start_comp_particle + j] = -par_factor * model.film_diffusion_coeff[comp]
                    J[idx_start_comp_particle + j, idx_start_comp_liquid + j] = par_factor * model.film_diffusion_coeff[comp]
                end
            end
        end

        if with_isotherm
            for i = 1:n_elements
                idx_start_liquid = element_start_index(:particle, i, 1, model)
                idx_end_liquid = element_end_index(:particle, i, model.n_comp, model)
                idx_start_solid = element_start_index(:solid, i, 1, model)
                idx_end_solid = element_end_index(:solid, i, model.n_comp, model)

                c_element = reshape(view(c, idx_start_liquid:idx_end_liquid), n_points, model.n_comp)
                cs_element = reshape(view(c, idx_start_solid:idx_end_solid), n_points, model.n_comp)

                jac!(J, model.binding_model, bind_params, idx_start_solid, idx_start_liquid, idx_start_solid, c_element, cs_element)
                J[idx_start_liquid:idx_end_liquid, :] .-= β_p * J[idx_start_solid:idx_end_solid, :]
            end
        end
    end
    nothing
end

function jac_update_iso!(J::JType, c::cType, bind_params::pType, model::AxialFlowModel{BindModel}) where {JType, cType, pType, BindModel}
    n_points = model.n_points
    β_p = (1.0 - model.par_porosity) / model.par_porosity
    @views begin
        par_factor = 3.0 / (model.par_radius * model.par_porosity)

        for i = 1:model.n_elements
            idx_start_liquid = element_start_index(:particle, i, 1, model)
            idx_end_liquid = element_end_index(:particle, i, model.n_comp, model)
            idx_start_solid = element_start_index(:solid, i, 1, model)
            idx_end_solid = element_end_index(:solid, i, model.n_comp, model)

            c_element = reshape(view(c, idx_start_liquid:idx_end_liquid), n_points, model.n_comp)
            cs_element = reshape(view(c, idx_start_solid:idx_end_solid), n_points, model.n_comp)

            jac!(J, model.binding_model, bind_params, idx_start_solid, idx_start_liquid, idx_start_solid, c_element, cs_element)
            @views J[idx_start_liquid:idx_end_liquid, idx_start_liquid:idx_end_solid] .= -β_p .* J[idx_start_solid:idx_end_solid, idx_start_liquid:idx_end_solid]
            for comp = 1:model.n_comp
                idx_start_comp_particle = element_start_index(:particle, i, comp, model)
                for j = 0:n_points-1
                    J[idx_start_comp_particle + j, idx_start_comp_particle + j] -= par_factor * model.film_diffusion_coeff[comp]
                end
            end
        end
    end
    nothing
end

function init_state!(u0, conc_liquid, conc_solid, model::AxialFlowModel)
    n_elements = model.n_elements
    n_comp = model.n_comp

    for i = 1:n_elements
        for comp = 1:n_comp
            idx_start_liquid = element_start_index(:liquid, i, comp, model)
            idx_end_liquid = element_end_index(:liquid, i, comp, model)
            idx_start_particle = element_start_index(:particle, i, comp, model)
            idx_end_particle = element_end_index(:particle, i, comp, model)
            idx_start_solid = element_start_index(:solid, i, comp, model)
            idx_end_solid = element_end_index(:solid, i, comp, model)

            u0[idx_start_liquid:idx_end_liquid] .= conc_liquid[comp]
            u0[idx_start_particle:idx_end_particle] .= conc_liquid[comp]
            u0[idx_start_solid:idx_end_solid] .= conc_solid[comp]
        end
    end
    nothing
end

function loss_grad()
    n_elements = 3 # Should be >= 10
    n_degree = 1 # Should be >= 3

    bind_params = ComponentArray(
        ka = [4.0, 5.5, 3.0] .* 1e-2,
        kd = [3.2, 22.0, 9.3] .* 1e-3,
        qmax = [3.0, 2.0, 6.5] .* 10.0,
        gamma = [-1.0, -0.5, 0.2],
        beta = [0.91, 0.82, 1.3],
        salt_ref = 1.0
    )

    lgl_basis = LobattoLegendre(n_degree)
    M_inv = legendre_inv_mass(lgl_basis)
    model = AxialFlowModel(
        4, 1.0 / 0.37, 0.014, fill(6.9e-6, 4), 0.37, 0.75, 45e-6,
        ModulatedBinding(4, 3),
        n_elements, n_degree + 1, (n_degree + 1) * 4, n_elements * (n_degree + 1) * 4, lgl_basis, lgl_basis.D, M_inv
    )

    save_idxs = [element_end_index(:liquid, n_elements, i, model) for i in 1:model.n_comp]

    salt_eq = 1.0
    salt_start = 1.5
    load_len = 10.0
    elute_len = 1410.0
    load_conc = [1.0, 1.0, 1.0]
    t_stop = [load_len, load_len + elute_len]

    inlet = let salt_eq = salt_eq, salt_start = salt_start, load_len = load_len, elute_len = elute_len, load_conc = load_conc
        function(t)
            if t < load_len
                return [load_conc..., salt_eq]
            elseif t < load_len + elute_len
                return [0.0, 0.0, 0.0, salt_start]
            end
            return zeros(Float64, 4)
        end
    end

    tspan = (0.0, load_len + elute_len)
    num_dofs = model.n_points * model.n_elements * model.n_comp * 3
    u0 = zeros(Float64, num_dofs)
    init_state!(u0, [0.0, 0.0, 0.0, salt_eq], [0.0, 0.0, 0.0, salt_eq], model)

    flow_rate = 3.45 / 60 / 100

    jac_cache = spzeros(length(u0), length(u0))
    jac!(jac_cache, rand(Float64, length(u0)), bind_params, flow_rate, 0.0, model, true)

    ode_jac = let model = model
        function(J, u, p, t)
            jac!(J, u, p, flow_rate, t, model, true)
            return nothing
        end
    end

    vjp = let jac_cache = jac_cache, model = model
        function(Jv, v, u, p, t)
            jac_update_iso!(jac_cache, u, p, model)
            Jv .= jac_cache' * v
            return nothing
        end
    end

    jvp = let jac_cache = jac_cache, model = model
        function(Jv, v, u, p, t)
            jac_update_iso!(jac_cache, u, p, model)
            Jv .= jac_cache * v
            return nothing
        end
    end

    jac_wo_iso = spzeros(length(u0), length(u0))
    jac!(jac_wo_iso, rand(Float64, length(u0)), bind_params, flow_rate, 0.0, model, false)
    jac0 = spzeros(length(u0), length(u0))
    jac!(jac0, rand(Float64, length(u0)), bind_params, flow_rate, 0.0, model, true)

    # Why is jac_prototype needed for the adjoint sensitivity analysis to work correctly?
    # We already supply vjp and the actual Jacobian
    fun = ODE.ODEFunction((du, u, pp, t) -> rhs_jac!(du, u, pp, jac_wo_iso, flow_rate, t, model, inlet(t)); jac_prototype=jac0, vjp=vjp, jvp=jvp, jac=ode_jac)
    prob = ODE.ODEProblem(fun, u0, tspan)

    sense_alg = SMS.QuadratureAdjoint(autojacvec=SMS.EnzymeVJP(), autodiff=true)

    loss = let prob = prob, save_idxs = save_idxs, sense_alg = sense_alg, t_stop = t_stop
        function(p)
            new_prob = ODE.remake(prob, p=p)
            sol = SMS.solve(new_prob, ODE.QNDF(autodiff=false, linsolve=UMFPACKFactorization()), abstol=1e-8, reltol=1e-6, tstops=t_stop, saveat=1.0, maxiters=1e6, sensealg = sense_alg)
            L = 0.0
            for i in eachindex(sol.t)
                @views L += sum(sol.u[i][save_idxs[1:end-1]])
            end
            return L
        end
    end

    return Zygote.gradient(loss, bind_params)
end

The entry point is loss_grad(); the discretization is controlled by n_elements and n_degree.

The loss is fast:

@time loss(bind_params)
0.006014 seconds (39.96 k allocations: 19.951 MiB)

Gradient is not:

@time Zygote.gradient(loss, bind_params)
41.118779 seconds (325.92 M allocations: 42.292 GiB, 18.20% gc time)

Thanks for looking into this!

It’s a way for Mooncake to return derivatives that imitate the structure of the primal object, instead of being wrapped inside a Mooncake.Tangent.

What’s the time on a second run? The first gradient call will include compile time.

That was the second run, unfortunately. The first run is:

@time loss(bind_params)
4.096212 seconds (10.05 M allocations: 525.757 MiB, 99.78% compilation time)

@time Zygote.gradient(loss, bind_params)
220.033784 seconds (671.59 M allocations: 58.943 GiB, 8.36% gc time, 70.30% compilation time: <1% of which was recompilation)
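
As a side note on the first-run versus second-run distinction: it can be checked on any function by timing the same call twice in a fresh session. The sketch below uses a hypothetical toy_work function and Base's @elapsed. It only illustrates the measurement and says nothing about the numbers above.

```julia
# Hypothetical workload; any function that has not been compiled yet will do.
toy_work(x) = sum(abs2, x) / length(x)

x = rand(10_000)

t_first = @elapsed toy_work(x)    # includes JIT compilation for Vector{Float64}
t_second = @elapsed toy_work(x)   # runs the already compiled code

println("first call:  ", t_first, " s")
println("second call: ", t_second, " s")
```

@time reports the same split directly through its "% compilation time" annotation, as in the first-run output above.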

Ah, alas. Skimming your code quickly, the amount of allocations and algebra makes me think that Reactant might give you a boost.

@ChrisRackauckas how would one do that here, just use ReactantVJP?

It still needs a bit of work. Specifically, Reactant doesn’t support it directly:

so I have to work around it, which I am doing in:

So hopefully next week ReactantVJP will work on this.

But this case also has manual vjps, so I’m getting that ready in:

and am using this as the test case there.

FBDF is much faster here with the recent improvements. Update to the latest OrdinaryDiffEq and SciMLSensitivity and try again now with FBDF as the solver.
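
Switching solvers only changes the algorithm argument passed to solve. A minimal sketch on a hypothetical scalar decay problem (decay! is not from the thread; the tolerances are copied from the script above) could look like this, assuming OrdinaryDiffEq is installed:

```julia
import OrdinaryDiffEq as ODE

# Hypothetical linear test problem du/dt = -p * u, standing in for the real model.
function decay!(du, u, p, t)
    du .= -p .* u
    return nothing
end

prob = ODE.ODEProblem(decay!, [1.0], (0.0, 1.0), [2.0])

# Same problem and tolerances, two different BDF-type solvers.
sol_qndf = ODE.solve(prob, ODE.QNDF(); abstol=1e-8, reltol=1e-6)
sol_fbdf = ODE.solve(prob, ODE.FBDF(); abstol=1e-8, reltol=1e-6)

println(sol_qndf.u[end][1], "  ", sol_fbdf.u[end][1], "  exact: ", exp(-2.0))
```

In the script from this thread, the corresponding change would be replacing ODE.QNDF(...) with ODE.FBDF(...) in the solve call inside loss, keeping the other keyword arguments as they are.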

Update on Reactant support: I think the pull request "Fix _parentsmatch for TracedRArray vs regular Array" (EnzymeAD/Reactant.jl #2570, by ChrisRackauckas-Claude) is the last piece, @wsmoses.

PR is already approved, just waiting for CI.

With the new versions of Reactant.jl and SciMLSensitivity.jl, the code can work with ReactantVJP. However, it needs a bit more vectorization, since the code was written in a naturally scalar style, and so:

# Test ReactantVJP specifically
# Requires: using Reactant loaded

import OrdinaryDiffEq as ODE
import SciMLSensitivity as SMS
using LinearAlgebra, LinearSolve
using PolynomialBases, ComponentArrays, SparseArrays
import ForwardDiff
import Zygote
using Reactant

function legendre_inv_mass(b::NodalBasis{T}) where {T}
    V = legendre_vandermonde(b)
    for n in axes(V, 2)  # n indexes columns of V (V is square here, so the range is unchanged)
        V[:, n] .= @view(V[:, n]) * sqrt((2*n - 1) / 2)
    end
    return V * V'
end

struct ModulatedBinding
    nComp::Int
    nBindingComps::Int
end

function binding_flux(model::ModulatedBinding, cp::cpType, cs::csType, p::pType) where {cpType, csType, pType}
    @views begin
        q_free = 1.0 .- sum(cs[:, 1:model.nBindingComps] ./ reshape(p.qmax, 1, :), dims=2)
        ads = reshape(p.ka, 1, :) .* exp.((cp[:, end] .- p.salt_ref) .* reshape(p.gamma, 1, :)) .* cp[:, 1:model.nBindingComps] .* reshape(p.qmax, 1, :) .* q_free
        des = reshape(p.kd, 1, :) .* (cp[:, end] ./ p.salt_ref).^reshape(p.beta, 1, :) .* cs[:, 1:model.nBindingComps]
        return ads .- des
    end
end

struct AxialFlowModel{BindModel}
    n_comp::Int
    cross_section_area::Float64
    length::Float64
    film_diffusion_coeff::Vector{Float64}
    col_porosity::Float64
    par_porosity::Float64
    par_radius::Float64
    binding_model::BindModel
    n_elements::Int
    n_points::Int
    n_dof_per_element::Int
    n_dof_per_phase::Int
    lgl_basis::LobattoLegendre{Float64}
    D::Array{Float64, 2}
    M_inv::Array{Float64, 2}
end

function element_start_index(phase::Symbol, element::Int, component::Int, disc::AxialFlowModel)
    if phase == :liquid
        return (component - 1) * disc.n_elements * disc.n_points + (element - 1) * disc.n_points + 1
    elseif phase == :particle
        ip = 2
    elseif phase == :solid
        ip = 3
    end
    return (ip - 1) * disc.n_dof_per_phase + (element - 1) * disc.n_dof_per_element + (component - 1) * disc.n_points + 1
end

function element_end_index(phase::Symbol, element::Int, component::Int, disc::AxialFlowModel)
    if phase == :liquid
        return (component - 1) * disc.n_elements * disc.n_points + element * disc.n_points
    elseif phase == :particle
        ip = 2
    elseif phase == :solid
        ip = 3
    end
    return (ip - 1) * disc.n_dof_per_phase + (element - 1) * disc.n_dof_per_element + component * disc.n_points
end

function rhs_jac!(dc::dcType, c::cType, bind_params::pType,
                jac_wo_isotherm::JacType, flow_rate::flowType, t::tType,
                model::AxialFlowModel{BindModel}, inlet::AbstractVector) where {pType, JacType, tType, dcType, cType, flowType, BindModel}
    Δz = model.length / model.n_elements
    two_over_Δz = 2.0 / Δz
    velocity = flow_rate / (model.cross_section_area * model.col_porosity)
    n_points = model.n_points
    n_elements = model.n_elements
    mul!(dc, jac_wo_isotherm, c)
    @views for comp = 1:model.n_comp
        idx_start = element_start_index(:liquid, 1, comp, model)
        idx_end = element_end_index(:liquid, 1, comp, model)
        dc[idx_start:idx_end] .+= two_over_Δz .* model.M_inv[:, 1] .* velocity .* inlet[comp:comp]
    end
    β_p = (1.0 - model.par_porosity) / model.par_porosity
    for i = 1:n_elements
        idx_start_particle = element_start_index(:particle, i, 1, model)
        idx_end_particle = element_end_index(:particle, i, model.n_comp, model)
        idx_start_solid = element_start_index(:solid, i, 1, model)
        idx_end_solid = element_end_index(:solid, i, model.n_comp, model)
        cp_element = reshape(view(c, idx_start_particle:idx_end_particle), n_points, model.n_comp)
        cs_element = reshape(view(c, idx_start_solid:idx_end_solid), n_points, model.n_comp)
        iso = binding_flux(model.binding_model, cp_element, cs_element, bind_params)
        for comp = 1:model.n_comp
            p_range = element_start_index(:particle, i, comp, model):element_end_index(:particle, i, comp, model)
            s_range = element_start_index(:solid, i, comp, model):element_end_index(:solid, i, comp, model)
            if comp <= model.binding_model.nBindingComps
                iso_col = @view iso[:, comp]
                dc[p_range] .-= β_p .* iso_col
                dc[s_range] .= iso_col
            else
                dc[s_range] .= 0.0
            end
        end
    end
    nothing
end

function jac!(J::JType, model::ModulatedBinding, p::pType, row_offset,
              idx_start_liquid, idx_start_solid, cp, cs) where {JType, pType}
    n_points = size(cp, 1)
    n_bind = model.nBindingComps
    n_comp = model.nComp
    ForwardDiff.jacobian!(view(J, row_offset:row_offset+n_bind*n_points-1,
        idx_start_liquid:idx_start_liquid+n_comp*n_points-1),
        cp -> reshape(binding_flux(model, reshape(cp, n_points, n_comp), cs, p), :),
        reshape(cp, :))
    ForwardDiff.jacobian!(view(J, row_offset:row_offset+n_bind*n_points-1,
        idx_start_solid:idx_start_solid+n_comp*n_points-1),
        cs -> reshape(binding_flux(model, cp, reshape(cs, n_points, n_comp), p), :),
        reshape(cs, :))
    nothing
end

function jac!(J::JType, c::cType, bind_params::pType, flow_rate, t,
              model::AxialFlowModel{BindModel}, with_isotherm::Bool) where {JType, cType, pType, BindModel}
    velocity = flow_rate / (model.cross_section_area * model.col_porosity)
    n_points = model.n_points
    n_elements = model.n_elements
    Δz = model.length / n_elements
    two_over_Δz = 2.0 / Δz
    β_c = (1.0 - model.col_porosity) / model.col_porosity
    β_p = (1.0 - model.par_porosity) / model.par_porosity
    J .= 0.0
    @views begin
        for comp = 1:model.n_comp
            bulk_pore_factor = β_c * 3.0 / model.par_radius * model.film_diffusion_coeff[comp]
            for i = 1:n_elements
                cur_elem_start = element_start_index(:liquid, i, comp, model)
                cur_elem_end = element_end_index(:liquid, i, comp, model)
                idx_start_cp = element_start_index(:particle, i, comp, model)
                J[cur_elem_start:cur_elem_end, cur_elem_start:cur_elem_end] .= model.D
                J[cur_elem_start:cur_elem_end, cur_elem_start:cur_elem_end] .*= two_over_Δz .* (-velocity)
                for j = 0:n_points-1
                    J[cur_elem_start + j, cur_elem_start + j] -= bulk_pore_factor
                    J[cur_elem_start + j, idx_start_cp + j] = bulk_pore_factor
                end
                J[cur_elem_start:cur_elem_end, cur_elem_start] .+= -velocity * two_over_Δz * model.M_inv[:, 1]
                J[cur_elem_start:cur_elem_end, cur_elem_end] .+= velocity * two_over_Δz * model.M_inv[:, end]
                if i > 1
                    J[cur_elem_start:cur_elem_end, element_end_index(:liquid, i-1, comp, model)] .+= velocity * two_over_Δz * model.M_inv[:, 1]
                end
                J[cur_elem_start:cur_elem_end, cur_elem_end] .-= velocity * two_over_Δz * model.M_inv[:, end]
            end
        end
        par_factor = 3.0 / (model.par_radius * model.par_porosity)
        for comp = 1:model.n_comp
            for i = 1:n_elements
                idx_start_comp_particle = element_start_index(:particle, i, comp, model)
                idx_start_comp_liquid = element_start_index(:liquid, i, comp, model)
                for j = 0:n_points-1
                    J[idx_start_comp_particle + j, idx_start_comp_particle + j] = -par_factor * model.film_diffusion_coeff[comp]
                    J[idx_start_comp_particle + j, idx_start_comp_liquid + j] = par_factor * model.film_diffusion_coeff[comp]
                end
            end
        end
        if with_isotherm
            for i = 1:n_elements
                idx_start_liquid = element_start_index(:particle, i, 1, model)
                idx_end_liquid = element_end_index(:particle, i, model.n_comp, model)
                idx_start_solid = element_start_index(:solid, i, 1, model)
                idx_end_solid = element_end_index(:solid, i, model.n_comp, model)
                c_element = reshape(view(c, idx_start_liquid:idx_end_liquid), n_points, model.n_comp)
                cs_element = reshape(view(c, idx_start_solid:idx_end_solid), n_points, model.n_comp)
                jac!(J, model.binding_model, bind_params, idx_start_solid, idx_start_liquid, idx_start_solid, c_element, cs_element)
                J[idx_start_liquid:idx_end_liquid, :] .-= β_p * J[idx_start_solid:idx_end_solid, :]
            end
        end
    end
    nothing
end

function jac_update_iso!(J::JType, c::cType, bind_params::pType,
                         model::AxialFlowModel{BindModel}) where {JType, cType, pType, BindModel}
    n_points = model.n_points
    β_p = (1.0 - model.par_porosity) / model.par_porosity
    @views begin
        par_factor = 3.0 / (model.par_radius * model.par_porosity)
        for i = 1:model.n_elements
            idx_start_liquid = element_start_index(:particle, i, 1, model)
            idx_end_liquid = element_end_index(:particle, i, model.n_comp, model)
            idx_start_solid = element_start_index(:solid, i, 1, model)
            idx_end_solid = element_end_index(:solid, i, model.n_comp, model)
            c_element = reshape(view(c, idx_start_liquid:idx_end_liquid), n_points, model.n_comp)
            cs_element = reshape(view(c, idx_start_solid:idx_end_solid), n_points, model.n_comp)
            jac!(J, model.binding_model, bind_params, idx_start_solid, idx_start_liquid, idx_start_solid, c_element, cs_element)
            J[idx_start_liquid:idx_end_liquid, idx_start_liquid:idx_end_solid] .= -β_p .* J[idx_start_solid:idx_end_solid, idx_start_liquid:idx_end_solid]
            for comp = 1:model.n_comp
                idx_start_comp_particle = element_start_index(:particle, i, comp, model)
                for j = 0:n_points-1
                    J[idx_start_comp_particle + j, idx_start_comp_particle + j] -= par_factor * model.film_diffusion_coeff[comp]
                end
            end
        end
    end
    nothing
end

function init_state!(u0, conc_liquid, conc_solid, model::AxialFlowModel)
    n_elements = model.n_elements
    n_comp = model.n_comp
    for i = 1:n_elements
        for comp = 1:n_comp
            idx_start_liquid = element_start_index(:liquid, i, comp, model)
            idx_end_liquid = element_end_index(:liquid, i, comp, model)
            idx_start_particle = element_start_index(:particle, i, comp, model)
            idx_end_particle = element_end_index(:particle, i, comp, model)
            idx_start_solid = element_start_index(:solid, i, comp, model)
            idx_end_solid = element_end_index(:solid, i, comp, model)
            u0[idx_start_liquid:idx_end_liquid] .= conc_liquid[comp]
            u0[idx_start_particle:idx_end_particle] .= conc_liquid[comp]
            u0[idx_start_solid:idx_end_solid] .= conc_solid[comp]
        end
    end
    nothing
end

function setup_problem()
    n_elements = 3
    n_degree = 1
    bind_params = ComponentArray(
        ka = [4.0, 5.5, 3.0] .* 1e-2,
        kd = [3.2, 22.0, 9.3] .* 1e-3,
        qmax = [3.0, 2.0, 6.5] .* 10.0,
        gamma = [-1.0, -0.5, 0.2],
        beta = [0.91, 0.82, 1.3],
        salt_ref = 1.0
    )
    lgl_basis = LobattoLegendre(n_degree)
    M_inv = legendre_inv_mass(lgl_basis)
    model = AxialFlowModel(
        4, 1.0 / 0.37, 0.014, fill(6.9e-6, 4), 0.37, 0.75, 45e-6,
        ModulatedBinding(4, 3),
        n_elements, n_degree + 1, (n_degree + 1) * 4, n_elements * (n_degree + 1) * 4, lgl_basis, lgl_basis.D, M_inv
    )
    save_idxs = [element_end_index(:liquid, n_elements, i, model) for i in 1:model.n_comp]
    salt_eq = 1.0
    salt_start = 1.5
    load_len = 10.0
    elute_len = 1410.0
    load_conc = [1.0, 1.0, 1.0]
    t_stop = [load_len, load_len + elute_len]
    inlet = let salt_eq = salt_eq, salt_start = salt_start, load_len = load_len,
                elute_len = elute_len, load_conc = load_conc
        function(t)
            in_load = t < load_len
            in_elute = t < load_len + elute_len
            load_vals = [load_conc..., salt_eq]
            elute_vals = [0.0, 0.0, 0.0, salt_start]
            zero_vals = zeros(Float64, 4)
            # Nest ifelse: if in_load → load_vals, else if in_elute → elute_vals, else zero_vals
            return ifelse.(in_load, load_vals, ifelse.(in_elute, elute_vals, zero_vals))
        end
    end
    tspan = (0.0, load_len + elute_len)
    num_dofs = model.n_points * model.n_elements * model.n_comp * 3
    u0 = zeros(Float64, num_dofs)
    init_state!(u0, [0.0, 0.0, 0.0, salt_eq], [0.0, 0.0, 0.0, salt_eq], model)
    flow_rate = 3.45 / 60 / 100
    jac_cache = spzeros(length(u0), length(u0))
    jac!(jac_cache, rand(Float64, length(u0)), bind_params, flow_rate, 0.0, model, true)
    ode_jac = let model = model
        function(J, u, p, t)
            jac!(J, u, p, flow_rate, t, model, true)
            return nothing
        end
    end
    vjp_fn = let jac_cache = jac_cache, model = model
        function(Jv, v, u, p, t)
            jac_update_iso!(jac_cache, u, p, model)
            Jv .= jac_cache' * v
            return nothing
        end
    end
    jvp_fn = let jac_cache = jac_cache, model = model
        function(Jv, v, u, p, t)
            jac_update_iso!(jac_cache, u, p, model)
            Jv .= jac_cache * v
            return nothing
        end
    end
    jac_wo_iso = spzeros(length(u0), length(u0))
    jac!(jac_wo_iso, rand(Float64, length(u0)), bind_params, flow_rate, 0.0, model, false)
    jac0 = spzeros(length(u0), length(u0))
    jac!(jac0, rand(Float64, length(u0)), bind_params, flow_rate, 0.0, model, true)
    fun = ODE.ODEFunction((du, u, pp, t) -> rhs_jac!(du, u, pp, jac_wo_iso, flow_rate, t, model, inlet(t));
                          jac_prototype=jac0, vjp=vjp_fn, jvp=jvp_fn, jac=ode_jac)
    prob = ODE.ODEProblem(fun, u0, tspan)
    return prob, bind_params, save_idxs, t_stop
end

prob, bind_params, save_idxs, t_stop = setup_problem()

# =====================================================================
# Test ReactantVJP with FBDF
# =====================================================================
println("=" ^ 70)
println("Test: FBDF + ReactantVJP (QuadratureAdjoint)")
println("=" ^ 70)

sense_alg_reactant = SMS.QuadratureAdjoint(autojacvec=SMS.ReactantVJP(allow_scalar=true), autodiff=true)

loss_reactant = let prob = prob, save_idxs = save_idxs, sense_alg = sense_alg_reactant, t_stop = t_stop
    function(p)
        new_prob = ODE.remake(prob, p=p)
        sol = ODE.solve(new_prob, ODE.FBDF(autodiff=false, linsolve=UMFPACKFactorization()),
                       abstol=1e-8, reltol=1e-6, tstops=t_stop, saveat=1.0, maxiters=1e6, sensealg=sense_alg)
        L = 0.0
        for i in eachindex(sol.t)
            @views L += sum(sol.u[i][save_idxs[1:end-1]])
        end
        return L
    end
end

println("\nForward solve (FBDF):")
@time loss_val = loss_reactant(bind_params)
@time loss_val = loss_reactant(bind_params)
println("  Loss value: $loss_val")

println("\nGradient (FBDF + ReactantVJP) — first call (includes compilation):")
@time grad_reactant = Zygote.gradient(loss_reactant, bind_params)
println("\nGradient (FBDF + ReactantVJP) — second call:")
@time grad_reactant = Zygote.gradient(loss_reactant, bind_params)
println("  Gradient: $(grad_reactant[1])")

Which results in:

| Configuration | Time | Allocations | Memory |
|---|---|---|---|
| ReactantVJP + FBDF | 21.8 s | 28.68 M | 7.1 GiB |
| EnzymeVJP + FBDF | 12.8 s | 28.78 M | 7.4 GiB |
| EnzymeVJP + QNDF (original Discourse) | 95.2 s | 175.59 M | 37.9 GiB |

So switching to FBDF with EnzymeVJP is the winner: 12.8 s, versus 21.8 s for ReactantVJP and 95.2 s for the original QNDF setup.

However, I suspect this is because Reactant has its own JIT, and so we probably need to cache this JIT’d kernel in order to optimize it… let me see how much that matters.

Arguably Reactant.jl should be doing this:

but that only brings us to 19 seconds, so pure Enzyme is still faster here.

I played around with the agent a bit, and it seems I got this down to 4 seconds with the latest versions by using the new vjp hooks:

| Config | Time | Allocs | GiB | GC% |
|---|---|---|---|---|
| OLD sparse VJP + EnzymeVJP | 12.2 s | 31 M | 7.0 | 6% |
| Dense-block VJP + ForwardDiff + vjp_p | 6.5 s | 26 M | 7.5 | 12% |
| Analytical VJP + vjp_p | 4.0 s | 9.7 M | 4.4 | 10% |
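
For readers unfamiliar with the idea behind the "Analytical VJP" row: a VJP hook computes J^T*v directly from hand-coded derivatives instead of building J first. The following is only a minimal sketch on a toy two-state system, not the model from this thread. The names `toy_flux`, `toy_rhs`, `toy_vjp!` and `fd_jacobian` are made up for illustration; only the `(Jv, v, u, p, t)` argument order is taken from the `vjp_fn` closures in this thread. The hand-coded result is compared against a finite-difference Jacobian.

```julia
# Toy single-component flux (hypothetical, loosely shaped like an
# adsorption/desorption term):
#   f(cp, cs) = ka*cp*qmax*(1 - cs/qmax) - kd*cs
toy_flux(cp, cs, p) = p.ka * cp * p.qmax * (1 - cs / p.qmax) - p.kd * cs

# Toy right-hand side with state u = [cp, cs]: du = [-f, f]
toy_rhs(u, p) = (f = toy_flux(u[1], u[2], p); [-f, f])

# Hand-coded VJP: writes J^T * v into Jv without forming J.
function toy_vjp!(Jv, v, u, p, t)
    cp, cs = u[1], u[2]
    df_dcp = p.ka * p.qmax * (1 - cs / p.qmax)  # ∂f/∂cp
    df_dcs = -p.ka * cp - p.kd                  # ∂f/∂cs
    w = v[2] - v[1]        # combined weight on f, since du = [-f, f]
    Jv[1] = df_dcp * w
    Jv[2] = df_dcs * w
    return nothing
end

# Central finite-difference Jacobian, used here only as a reference.
function fd_jacobian(f, u; h = 1e-6)
    n = length(u)
    J = zeros(length(f(u)), n)
    for k in 1:n
        e = zeros(n)
        e[k] = h
        J[:, k] = (f(u .+ e) .- f(u .- e)) ./ (2h)
    end
    return J
end

p = (ka = 0.04, kd = 0.0032, qmax = 30.0)
u = [0.7, 2.5]
v = [0.3, -1.2]

Jv = zeros(2)
toy_vjp!(Jv, v, u, p, 0.0)

J_fd = fd_jacobian(x -> toy_rhs(x, p), u)
err = maximum(abs.(Jv .- J_fd' * v))
println("max abs difference vs finite differences: ", err)
```

The same kind of check can be applied to any hand-written VJP before handing it to the solver: pick a random `u` and `v`, and compare against a reference Jacobian built by finite differences or ForwardDiff.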

Fastest code:

# Fully non-allocating VJP with analytical derivatives
# No ForwardDiff in the VJP hot path at all.
# Computes J^T*v and (∂f/∂p)^T*v directly via hand-coded derivatives of binding_flux.

import OrdinaryDiffEq as ODE
import SciMLSensitivity as SMS
using LinearAlgebra, LinearSolve
using PolynomialBases, ComponentArrays, SparseArrays
import ForwardDiff
import Zygote

function legendre_inv_mass(b::NodalBasis{T}) where {T}
    V = legendre_vandermonde(b)
    for n in axes(V, 1)
        V[:, n] .= @view(V[:, n]) * sqrt((2 * n - 1) / 2)
    end
    return V * V'
end

struct ModulatedBinding
    nComp::Int
    nBindingComps::Int
end

function binding_flux!(iso, model::ModulatedBinding, cp, cs, p)
    n_points = size(cp, 1)
    n_bind = model.nBindingComps
    @inbounds for j in 1:n_bind
        for i in 1:n_points
            q_free = 1.0
            for k in 1:n_bind
                q_free -= cs[i, k] / p.qmax[k]
            end
            ads = p.ka[j] * exp((cp[i, end] - p.salt_ref) * p.gamma[j]) *
                  cp[i, j] * p.qmax[j] * q_free
            des = p.kd[j] * (cp[i, end] / p.salt_ref)^p.beta[j] * cs[i, j]
            iso[i, j] = ads - des
        end
    end
    return nothing
end

# Allocating version only used for Jacobian prototype computation (setup, not hot path)
function binding_flux(model::ModulatedBinding, cp::cpType, cs::csType, p::pType) where {cpType, csType, pType}
    @views begin
        q_free = 1.0 .- sum(cs[:, 1:model.nBindingComps] ./ reshape(p.qmax, 1, :), dims = 2)
        ads = reshape(p.ka, 1, :) .* exp.((cp[:, end] .- p.salt_ref) .* reshape(p.gamma, 1, :)) .* cp[:, 1:model.nBindingComps] .* reshape(p.qmax, 1, :) .* q_free
        des = reshape(p.kd, 1, :) .* (cp[:, end] ./ p.salt_ref) .^ reshape(p.beta, 1, :) .* cs[:, 1:model.nBindingComps]
        return ads .- des
    end
end

# =========================================================================
# Analytical VJP of binding_flux: computes J_cp^T * w and J_cs^T * w
# directly without forming the Jacobian matrix.
#
# iso[i,j] = ka[j]*E[i,j]*cp[i,j]*qmax[j]*qf[i] - kd[j]*D[i,j]*cs[i,j]
# where E[i,j] = exp((cp[i,end]-salt_ref)*gamma[j])
#       D[i,j] = (cp[i,end]/salt_ref)^beta[j]
#       qf[i]  = 1 - Σ_k cs[i,k]/qmax[k]
#
# Writes into Jv_cp (n_comp*n_points) and Jv_cs (n_comp*n_points).
# w is (n_bind*n_points) = the combined weight vector.
# =========================================================================
function binding_flux_vjp_state!(Jv_cp, Jv_cs, w, model::ModulatedBinding, cp, cs, p)
    n_pts = size(cp, 1)
    n_bind = model.nBindingComps
    n_comp = model.nComp

    fill!(Jv_cp, 0.0)
    fill!(Jv_cs, 0.0)

    @inbounds for i in 1:n_pts
        # Precompute q_free for this point
        qf = 1.0
        for k in 1:n_bind
            qf -= cs[i, k] / p.qmax[k]
        end

        for j in 1:n_bind
            w_ij = w[(j - 1) * n_pts + i]  # weight for output iso[i,j]

            E_ij = exp((cp[i, n_comp] - p.salt_ref) * p.gamma[j])
            D_ij = (cp[i, n_comp] / p.salt_ref)^p.beta[j]

            ads_base = p.ka[j] * E_ij  # common factor for adsorption terms

            # ∂iso[i,j]/∂cp[i,j] = ka[j]*E[i,j]*qmax[j]*qf[i]
            Jv_cp[(j - 1) * n_pts + i] += ads_base * p.qmax[j] * qf * w_ij

            # ∂iso[i,j]/∂cp[i,end] = ka[j]*gamma[j]*E[i,j]*cp[i,j]*qmax[j]*qf[i]
            #                       - kd[j]*beta[j]*(cp[i,end]/salt_ref)^(beta[j]-1)/salt_ref*cs[i,j]
            diso_dcpend = ads_base * p.gamma[j] * cp[i, j] * p.qmax[j] * qf -
                          p.kd[j] * p.beta[j] * D_ij / cp[i, n_comp] * cs[i, j]
            Jv_cp[(n_comp - 1) * n_pts + i] += diso_dcpend * w_ij

            # ∂iso[i,j]/∂cs[i,m]:
            #   for all binding m: -ka[j]*E[i,j]*cp[i,j]*qmax[j]/qmax[m]
            #   additionally if m==j: -kd[j]*D[i,j]
            common_cs = -ads_base * cp[i, j] * p.qmax[j]
            for m in 1:n_bind
                diso_dcs_m = common_cs / p.qmax[m]
                if m == j
                    diso_dcs_m -= p.kd[j] * D_ij
                end
                Jv_cs[(m - 1) * n_pts + i] += diso_dcs_m * w_ij
            end
        end
    end
    return nothing
end

# =========================================================================
# Analytical VJP of binding_flux w.r.t. parameters: computes J_p^T * w
# directly without forming the Jacobian matrix.
#
# Parameters: ka(n_bind), kd(n_bind), qmax(n_bind), gamma(n_bind), beta(n_bind), salt_ref(1)
# =========================================================================
function binding_flux_vjp_params!(Jv_p, w, model::ModulatedBinding, cp, cs, p)
    n_pts = size(cp, 1)
    n_bind = model.nBindingComps
    n_comp = model.nComp

    fill!(Jv_p, 0.0)

    @inbounds for i in 1:n_pts
        qf = 1.0
        for k in 1:n_bind
            qf -= cs[i, k] / p.qmax[k]
        end

        for j in 1:n_bind
            w_ij = w[(j - 1) * n_pts + i]

            E_ij = exp((cp[i, n_comp] - p.salt_ref) * p.gamma[j])
            D_ij = (cp[i, n_comp] / p.salt_ref)^p.beta[j]
            ads_full = p.ka[j] * E_ij * cp[i, j] * p.qmax[j] * qf

            # ∂iso/∂ka[j] = E[i,j]*cp[i,j]*qmax[j]*qf[i]
            Jv_p.ka[j] += (ads_full / p.ka[j]) * w_ij

            # ∂iso/∂kd[j] = -D[i,j]*cs[i,j]
            Jv_p.kd[j] += (-D_ij * cs[i, j]) * w_ij

            # ∂iso/∂qmax[m]:
            #   for all binding m: ka[j]*E*cp[i,j]*qmax[j]*cs[i,m]/qmax[m]^2
            #   additionally if m==j: + ka[j]*E*cp[i,j]*qf[i]
            for m in 1:n_bind
                diso_dqmax_m = p.ka[j] * E_ij * cp[i, j] * p.qmax[j] * cs[i, m] / p.qmax[m]^2
                if m == j
                    diso_dqmax_m += p.ka[j] * E_ij * cp[i, j] * qf
                end
                Jv_p.qmax[m] += diso_dqmax_m * w_ij
            end

            # ∂iso/∂gamma[j] = ka[j]*(cp[i,end]-salt_ref)*E[i,j]*cp[i,j]*qmax[j]*qf[i]
            Jv_p.gamma[j] += ads_full * (cp[i, n_comp] - p.salt_ref) * w_ij

            # ∂iso/∂beta[j] = -kd[j]*log(cp[i,end]/salt_ref)*D[i,j]*cs[i,j]
            Jv_p.beta[j] += (-p.kd[j] * log(cp[i, n_comp] / p.salt_ref) * D_ij * cs[i, j]) * w_ij

            # ∂iso/∂salt_ref:
            #   -ka[j]*gamma[j]*E[i,j]*cp[i,j]*qmax[j]*qf[i]
            #   +kd[j]*beta[j]*D[i,j]/salt_ref*cs[i,j]
            Jv_p.salt_ref += (-ads_full * p.gamma[j] +
                              p.kd[j] * p.beta[j] * D_ij / p.salt_ref * cs[i, j]) * w_ij
        end
    end
    return nothing
end

struct AxialFlowModel{BindModel}
    n_comp::Int
    cross_section_area::Float64
    length::Float64
    film_diffusion_coeff::Vector{Float64}
    col_porosity::Float64
    par_porosity::Float64
    par_radius::Float64
    binding_model::BindModel
    n_elements::Int
    n_points::Int
    n_dof_per_element::Int
    n_dof_per_phase::Int
    lgl_basis::LobattoLegendre{Float64}
    D::Array{Float64, 2}
    M_inv::Array{Float64, 2}
end

function element_start_index(phase::Symbol, element::Int, component::Int, disc::AxialFlowModel)
    if phase == :liquid
        return (component - 1) * disc.n_elements * disc.n_points + (element - 1) * disc.n_points + 1
    elseif phase == :particle
        ip = 2
    elseif phase == :solid
        ip = 3
    end
    return (ip - 1) * disc.n_dof_per_phase + (element - 1) * disc.n_dof_per_element + (component - 1) * disc.n_points + 1
end

function element_end_index(phase::Symbol, element::Int, component::Int, disc::AxialFlowModel)
    if phase == :liquid
        return (component - 1) * disc.n_elements * disc.n_points + element * disc.n_points
    elseif phase == :particle
        ip = 2
    elseif phase == :solid
        ip = 3
    end
    return (ip - 1) * disc.n_dof_per_phase + (element - 1) * disc.n_dof_per_element + component * disc.n_points
end

function rhs_jac_noalloc!(dc, c, bind_params, jac_wo_isotherm, flow_rate, t,
        model::AxialFlowModel{BindModel}, inlet_buf, iso_buf,
        load_len, elute_len, load_conc, salt_eq, salt_start) where {BindModel}
    # Piecewise-constant inlet profile: load phase, elution phase, then zero.
    if t < load_len
        for i in eachindex(load_conc)
            inlet_buf[i] = load_conc[i]
        end
        inlet_buf[end] = salt_eq
    elseif t < load_len + elute_len
        for i in eachindex(load_conc)
            inlet_buf[i] = 0.0
        end
        inlet_buf[end] = salt_start
    else
        fill!(inlet_buf, 0.0)
    end
    Δz = model.length / model.n_elements
    two_over_Δz = 2.0 / Δz
    velocity = flow_rate / (model.cross_section_area * model.col_porosity)
    mul!(dc, jac_wo_isotherm, c)
    @views for comp in 1:model.n_comp
        idx_start = element_start_index(:liquid, 1, comp, model)
        idx_end = element_end_index(:liquid, 1, comp, model)
        inlet_val = inlet_buf[comp]
        for j in idx_start:idx_end
            dc[j] += two_over_Δz * model.M_inv[j - idx_start + 1, 1] * velocity * inlet_val
        end
    end
    β_p = (1.0 - model.par_porosity) / model.par_porosity
    for i in 1:model.n_elements
        idx_start_particle = element_start_index(:particle, i, 1, model)
        idx_end_particle = element_end_index(:particle, i, model.n_comp, model)
        idx_start_solid = element_start_index(:solid, i, 1, model)
        idx_end_solid = element_end_index(:solid, i, model.n_comp, model)
        cp_element = reshape(view(c, idx_start_particle:idx_end_particle), model.n_points, model.n_comp)
        cs_element = reshape(view(c, idx_start_solid:idx_end_solid), model.n_points, model.n_comp)
        binding_flux!(iso_buf, model.binding_model, cp_element, cs_element, bind_params)
        for comp in 1:model.n_comp
            p_start = element_start_index(:particle, i, comp, model)
            s_start = element_start_index(:solid, i, comp, model)
            if comp <= model.binding_model.nBindingComps
                @inbounds for j in 1:model.n_points
                    dc[p_start + j - 1] -= β_p * iso_buf[j, comp]
                    dc[s_start + j - 1] = iso_buf[j, comp]
                end
            else
                @inbounds for j in 1:model.n_points
                    dc[s_start + j - 1] = 0.0
                end
            end
        end
    end
    return nothing
end

# Full Jacobian (only used for setup / jac_prototype, not hot path)
function jac!(J::JType, model::ModulatedBinding, p::pType, row_offset,
        idx_start_liquid, idx_start_solid, cp, cs) where {JType, pType}
    n_points = size(cp, 1)
    n_bind = model.nBindingComps
    n_comp = model.nComp
    ForwardDiff.jacobian!(view(J, row_offset:(row_offset + n_bind * n_points - 1),
            idx_start_liquid:(idx_start_liquid + n_comp * n_points - 1)),
        cp -> reshape(binding_flux(model, reshape(cp, n_points, n_comp), cs, p), :),
        reshape(cp, :))
    ForwardDiff.jacobian!(view(J, row_offset:(row_offset + n_bind * n_points - 1),
            idx_start_solid:(idx_start_solid + n_comp * n_points - 1)),
        cs -> reshape(binding_flux(model, cp, reshape(cs, n_points, n_comp), p), :),
        reshape(cs, :))
    return nothing
end

function jac!(J::JType, c::cType, bind_params::pType, flow_rate, t,
        model::AxialFlowModel{BindModel}, with_isotherm::Bool) where {JType, cType, pType, BindModel}
    velocity = flow_rate / (model.cross_section_area * model.col_porosity)
    n_points = model.n_points
    n_elements = model.n_elements
    Δz = model.length / n_elements
    two_over_Δz = 2.0 / Δz
    β_c = (1.0 - model.col_porosity) / model.col_porosity
    β_p = (1.0 - model.par_porosity) / model.par_porosity
    J .= 0.0
    @views begin
        for comp in 1:model.n_comp
            bulk_pore_factor = β_c * 3.0 / model.par_radius * model.film_diffusion_coeff[comp]
            for i in 1:n_elements
                cur_elem_start = element_start_index(:liquid, i, comp, model)
                cur_elem_end = element_end_index(:liquid, i, comp, model)
                idx_start_cp = element_start_index(:particle, i, comp, model)
                J[cur_elem_start:cur_elem_end, cur_elem_start:cur_elem_end] .= model.D
                J[cur_elem_start:cur_elem_end, cur_elem_start:cur_elem_end] .*= two_over_Δz .* (-velocity)
                for j in 0:(n_points - 1)
                    J[cur_elem_start + j, cur_elem_start + j] -= bulk_pore_factor
                    J[cur_elem_start + j, idx_start_cp + j] = bulk_pore_factor
                end
                J[cur_elem_start:cur_elem_end, cur_elem_start] .+= -velocity * two_over_Δz * model.M_inv[:, 1]
                J[cur_elem_start:cur_elem_end, cur_elem_end] .+= velocity * two_over_Δz * model.M_inv[:, end]
                if i > 1
                    J[cur_elem_start:cur_elem_end, element_end_index(:liquid, i - 1, comp, model)] .+= velocity * two_over_Δz * model.M_inv[:, 1]
                end
                J[cur_elem_start:cur_elem_end, cur_elem_end] .-= velocity * two_over_Δz * model.M_inv[:, end]
            end
        end
        par_factor = 3.0 / (model.par_radius * model.par_porosity)
        for comp in 1:model.n_comp
            for i in 1:n_elements
                idx_start_comp_particle = element_start_index(:particle, i, comp, model)
                idx_start_comp_liquid = element_start_index(:liquid, i, comp, model)
                for j in 0:(n_points - 1)
                    J[idx_start_comp_particle + j, idx_start_comp_particle + j] = -par_factor * model.film_diffusion_coeff[comp]
                    J[idx_start_comp_particle + j, idx_start_comp_liquid + j] = par_factor * model.film_diffusion_coeff[comp]
                end
            end
        end
        if with_isotherm
            for i in 1:n_elements
                idx_start_liquid = element_start_index(:particle, i, 1, model)
                idx_end_liquid = element_end_index(:particle, i, model.n_comp, model)
                idx_start_solid = element_start_index(:solid, i, 1, model)
                idx_end_solid = element_end_index(:solid, i, model.n_comp, model)
                c_element = reshape(view(c, idx_start_liquid:idx_end_liquid), n_points, model.n_comp)
                cs_element = reshape(view(c, idx_start_solid:idx_end_solid), n_points, model.n_comp)
                jac!(J, model.binding_model, bind_params, idx_start_solid, idx_start_liquid, idx_start_solid, c_element, cs_element)
                J[idx_start_liquid:idx_end_liquid, :] .-= β_p * J[idx_start_solid:idx_end_solid, :]
            end
        end
    end
    return nothing
end

function init_state!(u0, conc_liquid, conc_solid, model::AxialFlowModel)
    for i in 1:model.n_elements
        for comp in 1:model.n_comp
            u0[element_start_index(:liquid, i, comp, model):element_end_index(:liquid, i, comp, model)] .= conc_liquid[comp]
            u0[element_start_index(:particle, i, comp, model):element_end_index(:particle, i, comp, model)] .= conc_liquid[comp]
            u0[element_start_index(:solid, i, comp, model):element_end_index(:solid, i, comp, model)] .= conc_solid[comp]
        end
    end
    return nothing
end

function setup_problem()
    n_elements = 3
    n_degree = 1
    bind_params = ComponentArray(
        ka = [4.0, 5.5, 3.0] .* 1e-2,
        kd = [3.2, 22.0, 9.3] .* 1e-3,
        qmax = [3.0, 2.0, 6.5] .* 10.0,
        gamma = [-1.0, -0.5, 0.2],
        beta = [0.91, 0.82, 1.3],
        salt_ref = 1.0
    )
    lgl_basis = LobattoLegendre(n_degree)
    M_inv = legendre_inv_mass(lgl_basis)
    model = AxialFlowModel(
        4, 1.0 / 0.37, 0.014, fill(6.9e-6, 4), 0.37, 0.75, 45e-6,
        ModulatedBinding(4, 3),
        n_elements, n_degree + 1, (n_degree + 1) * 4, n_elements * (n_degree + 1) * 4, lgl_basis, lgl_basis.D, M_inv
    )
    save_idxs = [element_end_index(:liquid, n_elements, i, model) for i in 1:model.n_comp]
    salt_eq = 1.0
    salt_start = 1.5
    load_len = 10.0
    elute_len = 1410.0
    load_conc = [1.0, 1.0, 1.0]
    t_stop = [load_len, load_len + elute_len]

    tspan = (0.0, load_len + elute_len)
    num_dofs = model.n_points * model.n_elements * model.n_comp * 3
    u0 = zeros(Float64, num_dofs)
    init_state!(u0, [0.0, 0.0, 0.0, salt_eq], [0.0, 0.0, 0.0, salt_eq], model)
    flow_rate = 3.45 / 60 / 100

    inlet_buf = zeros(Float64, model.n_comp)
    iso_buf = zeros(Float64, model.n_points, model.binding_model.nBindingComps)

    jac_wo_iso = spzeros(length(u0), length(u0))
    jac!(jac_wo_iso, rand(Float64, length(u0)), bind_params, flow_rate, 0.0, model, false)
    jac_wo_iso_T = sparse(jac_wo_iso')

    jac0 = spzeros(length(u0), length(u0))
    jac!(jac0, rand(Float64, length(u0)), bind_params, flow_rate, 0.0, model, true)

    ode_jac = let model = model
        function (J, u, p, t)
            jac!(J, u, p, flow_rate, t, model, true)
            return nothing
        end
    end

    n_pts = model.n_points
    n_cmp = model.n_comp
    n_bnd = model.binding_model.nBindingComps
    n_bnd_pts = n_bnd * n_pts
    n_cmp_pts = n_cmp * n_pts
    β_p = (1.0 - model.par_porosity) / model.par_porosity

    # Pre-allocate VJP buffers
    w_buf = zeros(n_bnd_pts)
    Jv_cp_buf = zeros(n_cmp_pts)
    Jv_cs_buf = zeros(n_cmp_pts)

    vjp_fn = let jac_wo_iso_T = jac_wo_iso_T, w_buf = w_buf,
                 Jv_cp_buf = Jv_cp_buf, Jv_cs_buf = Jv_cs_buf,
                 model = model, β_p = β_p,
                 n_pts = n_pts, n_cmp = n_cmp, n_bnd_pts = n_bnd_pts
        function (Jv, v, u, p, t)
            # Constant part
            mul!(Jv, jac_wo_iso_T, v)

            # Isotherm correction per element (analytical, zero-allocation)
            for i in 1:model.n_elements
                cp_s = element_start_index(:particle, i, 1, model)
                cp_e = element_end_index(:particle, i, n_cmp, model)
                cs_s = element_start_index(:solid, i, 1, model)
                cs_e = element_end_index(:solid, i, n_cmp, model)

                # w = v[solid_binding] - β_p * v[particle_binding]
                @inbounds for k in 1:n_bnd_pts
                    w_buf[k] = v[cs_s + k - 1] - β_p * v[cp_s + k - 1]
                end

                cp_mat = reshape(view(u, cp_s:cp_e), n_pts, n_cmp)
                cs_mat = reshape(view(u, cs_s:cs_e), n_pts, n_cmp)

                # Analytical J_cp^T * w and J_cs^T * w
                binding_flux_vjp_state!(Jv_cp_buf, Jv_cs_buf, w_buf,
                    model.binding_model, cp_mat, cs_mat, p)

                # Accumulate into Jv
                @inbounds for k in 1:(n_cmp * n_pts)
                    Jv[cp_s + k - 1] += Jv_cp_buf[k]
                    Jv[cs_s + k - 1] += Jv_cs_buf[k]
                end
            end
            return nothing
        end
    end

    # Pre-allocate vjp_p buffers
    w_p_buf = zeros(n_bnd_pts)
    Jv_p_elem = ComponentArray(
        ka = zeros(n_bnd), kd = zeros(n_bnd), qmax = zeros(n_bnd),
        gamma = zeros(n_bnd), beta = zeros(n_bnd), salt_ref = 0.0)

    vjp_p_fn = let w_p_buf = w_p_buf, Jv_p_elem = Jv_p_elem,
                   model = model, β_p = β_p,
                   n_pts = n_pts, n_cmp = n_cmp, n_bnd_pts = n_bnd_pts
        function (out, λ, y, p, t)
            fill!(out, 0.0)

            for i in 1:model.n_elements
                cp_s = element_start_index(:particle, i, 1, model)
                cp_e = element_end_index(:particle, i, n_cmp, model)
                cs_s = element_start_index(:solid, i, 1, model)
                cs_e = element_end_index(:solid, i, n_cmp, model)

                @inbounds for k in 1:n_bnd_pts
                    w_p_buf[k] = λ[cs_s + k - 1] - β_p * λ[cp_s + k - 1]
                end

                cp_mat = reshape(view(y, cp_s:cp_e), n_pts, n_cmp)
                cs_mat = reshape(view(y, cs_s:cs_e), n_pts, n_cmp)

                # Analytical (∂binding_flux/∂p)^T * w
                binding_flux_vjp_params!(Jv_p_elem, w_p_buf,
                    model.binding_model, cp_mat, cs_mat, p)

                out .+= Jv_p_elem
            end
            return out
        end
    end

    # JVP: J * v = jac_wo_iso * v + C_iso * v (analytical)
    iso_effect_buf = zeros(n_bnd_pts)
    jvp_fn = let jac_wo_iso = jac_wo_iso, model = model, β_p = β_p,
                 iso_effect_buf = iso_effect_buf,
                 n_pts = n_pts, n_cmp = n_cmp, n_bnd_pts = n_bnd_pts
        function (Jv, v, u, p, t)
            mul!(Jv, jac_wo_iso, v)

            # JVP of the isotherm term: iso_effect = J_cp*v_cp + J_cs*v_cs.
            # Computed analytically per element from the same hand-coded
            # derivatives as the VJP, without forming the Jacobian blocks.
            for i in 1:model.n_elements
                cp_s = element_start_index(:particle, i, 1, model)
                cp_e = element_end_index(:particle, i, n_cmp, model)
                cs_s = element_start_index(:solid, i, 1, model)
                cs_e = element_end_index(:solid, i, n_cmp, model)

                cp_mat = reshape(view(u, cp_s:cp_e), n_pts, n_cmp)
                cs_mat = reshape(view(u, cs_s:cs_e), n_pts, n_cmp)
                v_cp = view(v, cp_s:cp_e)
                v_cs = view(v, cs_s:cs_e)

                n_bind = model.binding_model.nBindingComps

                fill!(iso_effect_buf, 0.0)
                @inbounds for ii in 1:n_pts
                    qf = 1.0
                    for k in 1:n_bind
                        qf -= cs_mat[ii, k] / p.qmax[k]
                    end

                    for j in 1:n_bind
                        E_ij = exp((cp_mat[ii, n_cmp] - p.salt_ref) * p.gamma[j])
                        D_ij = (cp_mat[ii, n_cmp] / p.salt_ref)^p.beta[j]
                        ads_base = p.ka[j] * E_ij

                        # Contribution from v_cp
                        eff = ads_base * p.qmax[j] * qf * v_cp[(j - 1) * n_pts + ii]
                        eff += (ads_base * p.gamma[j] * cp_mat[ii, j] * p.qmax[j] * qf -
                                p.kd[j] * p.beta[j] * D_ij / cp_mat[ii, n_cmp] * cs_mat[ii, j]) *
                               v_cp[(n_cmp - 1) * n_pts + ii]

                        # Contribution from v_cs
                        common_cs = -ads_base * cp_mat[ii, j] * p.qmax[j]
                        for m in 1:n_bind
                            diso_dcs_m = common_cs / p.qmax[m]
                            if m == j
                                diso_dcs_m -= p.kd[j] * D_ij
                            end
                            eff += diso_dcs_m * v_cs[(m - 1) * n_pts + ii]
                        end

                        iso_effect_buf[(j - 1) * n_pts + ii] = eff
                    end
                end

                # solid_binding += iso_effect, particle_binding -= β_p * iso_effect
                @inbounds for k in 1:n_bnd_pts
                    Jv[cs_s + k - 1] += iso_effect_buf[k]
                    Jv[cp_s + k - 1] -= β_p * iso_effect_buf[k]
                end
            end
            return nothing
        end
    end

    fun = ODE.ODEFunction(
        (du, u, pp, t) -> rhs_jac_noalloc!(du, u, pp, jac_wo_iso, flow_rate, t, model,
            inlet_buf, iso_buf, load_len, elute_len, load_conc, salt_eq, salt_start);
        jac_prototype = jac0, vjp = vjp_fn, jvp = jvp_fn, jac = ode_jac,
        vjp_p = vjp_p_fn)
    prob = ODE.ODEProblem(fun, u0, tspan)
    return prob, bind_params, save_idxs, t_stop
end

prob, bind_params, save_idxs, t_stop = setup_problem()

# =====================================================================
# Benchmark
# =====================================================================
println("=" ^ 70)
println("Test: FBDF + GaussAdjoint + analytical VJP + vjp_p (zero alloc)")
println("=" ^ 70)

alg = ODE.FBDF(autodiff = false, linsolve = UMFPACKFactorization())
sense_alg = SMS.GaussAdjoint(autojacvec = SMS.EnzymeVJP())

loss_fn = let prob = prob, save_idxs = save_idxs, sense_alg = sense_alg, t_stop = t_stop, alg = alg
    function (p)
        new_prob = ODE.remake(prob, p = p)
        sol = ODE.solve(new_prob, alg,
            abstol = 1e-8, reltol = 1e-6, tstops = t_stop, saveat = 1.0, maxiters = 1e6, sensealg = sense_alg)
        L = 0.0
        for i in eachindex(sol.t)
            @views L += sum(sol.u[i][save_idxs[1:(end - 1)]])
        end
        return L
    end
end

println("\nForward solve:")
@time loss_val = loss_fn(bind_params)
@time loss_val = loss_fn(bind_params)
println("  Loss value: $loss_val")

println("\nGradient — 1st call (compile):")
@time grad1 = Zygote.gradient(loss_fn, bind_params)
println("\nGradient — 2nd call:")
@time grad2 = Zygote.gradient(loss_fn, bind_params)
println("\nGradient — 3rd call:")
@time grad3 = Zygote.gradient(loss_fn, bind_params)
println("  Gradient: $(grad3[1])")

# Quick correctness check vs ForwardDiff-based VJP
println("\n--- Correctness check ---")
expected = (ka = [-148.98160180462295, 4.682207747927423, -272.7606465430993],
    kd = [1839.0723272591213, -9.160074574142724, 883.2396698989843],
    qmax = [-0.2185663494129382, -0.0016972984600429248, -0.13072785804121395],
    gamma = [-2.716768296358113, 0.0697274702316863, -3.8895439826353875],
    beta = [2.383488922767005, -0.07674961659760442, 3.325315684679378],
    salt_ref = -20.062437960598114)
max_rel = 0.0
for field in fieldnames(typeof(expected))
    v_exp = getfield(expected, field)
    v_got = getproperty(grad3[1], field)
    rel = maximum(abs.(v_exp .- v_got) ./ max.(abs.(v_exp), 1e-15))
    global max_rel = max(max_rel, rel)
end
println("  Max relative difference vs reference: $max_rel")
if max_rel < 1e-4
    println("  PASSED ✓")
else
    println("  FAILED ✗")
end

That’s amazing! Thank you very much.

I’ve upgraded to the latest versions and, using your optimized code, get gradients in 2.8 s on my machine. Really awesome!

Now, I actually require a discretization of at least n_elements = 10 and n_degree = 3.
With those values, I get:

Forward solve:
  0.131403 seconds (19.05 k allocations: 157.979 MiB, 4.50% gc time)

Gradient — 2nd call:
┌ Warning: Verbosity toggle: max_iters
│  Interrupted. Larger maxiters is needed. If you are using an integrator for non-stiff ODEs or an automatic switching algorithm (the default), you may want to consider using a method for stiff equations. See the solver pages for more details (e.g. https://docs.sciml.ai/DiffEqDocs/stable/solvers/ode_solve/#Stiff-Problems).
└ @ SciMLBase ~/.julia/dev/SciMLBase/src/integrator_interface.jl:679
363.556598 seconds (51.52 M allocations: 458.225 GiB, 6.54% gc time, 0.08% compilation time: 90% of which was recompilation)

The second-fastest version is probably the easier one to adapt, since it is not a fully analytical VJP, though it comes close. My guess is that the analytical version baked in something you weren’t expecting. Either way, the same procedure applies. The main optimizations here were:

  • FBDF instead of QNDF
  • Lots of improvements to the libraries
  • Manual vjp / vjp_p
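
For reference, the manual vjp / vjp_p item rests on the block structure described in the comments of the script that follows. A sketch of the identity, under my own notation: A stands for jac_wo_iso, e indexes the elements, s_e and p_e are the solid and particle binding rows of element e, and cp_e and cs_e are its particle and solid columns. These symbols are shorthand and do not appear in the code.

```latex
\begin{aligned}
J(u,p) &= A + C_{\mathrm{iso}}(u,p), \qquad A \text{ constant} \\
w_e &= v_{s_e} - \beta_p\, v_{p_e} \\
\left(J^{\mathsf T} v\right)_{cp_e} &= \left(A^{\mathsf T} v\right)_{cp_e} + J_{cp,e}^{\mathsf T}\, w_e \\
\left(J^{\mathsf T} v\right)_{cs_e} &= \left(A^{\mathsf T} v\right)_{cs_e} + J_{cs,e}^{\mathsf T}\, w_e \\
\left(\frac{\partial f}{\partial p}\right)^{\mathsf T} \lambda &= \sum_e J_{p,e}^{\mathsf T} \left(\lambda_{s_e} - \beta_p\, \lambda_{p_e}\right)
\end{aligned}
```

Here J_{cp,e}, J_{cs,e} and J_{p,e} are the small dense Jacobians of binding_flux with respect to cp, cs and p on element e. The factor -β_p enters because the particle rows receive -β_p times the same flux that the solid rows receive, so both contributions fold into the single vector w_e.
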
# Optimized VJP: dense block computation instead of sparse setindex!
# Also adds vjp_p for parameter VJP
#
# Key insight from profiling (InternalJunk issue #9):
#   - 58% of gradient time is in the user VJP
#   - 32% of TOTAL time is sparse setindex! (binary search in CSC columns)
#   - Only 3% is the actual J^T * v sparse matmul
#
# Fix: decompose J = jac_wo_iso (constant, pre-computed) + C_iso (per-step, small dense blocks)
# Compute J^T * v = jac_wo_iso^T * v + Σ_elements (dense block)^T * w
# Avoids all sparse setindex! operations.

import OrdinaryDiffEq as ODE
import SciMLSensitivity as SMS
using LinearAlgebra, LinearSolve
using PolynomialBases, ComponentArrays, SparseArrays
import ForwardDiff
import Zygote

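# Inverse mass matrix M⁻¹ = V * V', with V the Legendre Vandermonde matrix
# whose columns are scaled to orthonormal Legendre polynomials.
function legendre_inv_mass(b::NodalBasis{T}) where {T}
    V = legendre_vandermonde(b)
    # n runs over columns (polynomial index n - 1), so iterate the second axis
    for n in axes(V, 2)
        V[:, n] .= @view(V[:, n]) * sqrt((2 * n - 1) / 2)
    end
    return V * V'
end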

struct ModulatedBinding
    nComp::Int
    nBindingComps::Int
end

# In-place binding_flux for the RHS (zero allocation)
function binding_flux!(iso, model::ModulatedBinding, cp, cs, p)
    n_points = size(cp, 1)
    n_bind = model.nBindingComps
    @inbounds for j in 1:n_bind
        for i in 1:n_points
            q_free = 1.0
            for k in 1:n_bind
                q_free -= cs[i, k] / p.qmax[k]
            end
            ads = p.ka[j] * exp((cp[i, end] - p.salt_ref) * p.gamma[j]) *
                  cp[i, j] * p.qmax[j] * q_free
            des = p.kd[j] * (cp[i, end] / p.salt_ref)^p.beta[j] * cs[i, j]
            iso[i, j] = ads - des
        end
    end
    return nothing
end

# Allocating binding_flux used inside ForwardDiff (returns array for AD)
function binding_flux(model::ModulatedBinding, cp::cpType, cs::csType, p::pType) where {cpType, csType, pType}
    @views begin
        q_free = 1.0 .- sum(cs[:, 1:model.nBindingComps] ./ reshape(p.qmax, 1, :), dims = 2)
        ads = reshape(p.ka, 1, :) .* exp.((cp[:, end] .- p.salt_ref) .* reshape(p.gamma, 1, :)) .* cp[:, 1:model.nBindingComps] .* reshape(p.qmax, 1, :) .* q_free
        des = reshape(p.kd, 1, :) .* (cp[:, end] ./ p.salt_ref) .^ reshape(p.beta, 1, :) .* cs[:, 1:model.nBindingComps]
        return ads .- des
    end
end

struct AxialFlowModel{BindModel}
    n_comp::Int
    cross_section_area::Float64
    length::Float64
    film_diffusion_coeff::Vector{Float64}
    col_porosity::Float64
    par_porosity::Float64
    par_radius::Float64
    binding_model::BindModel
    n_elements::Int
    n_points::Int
    n_dof_per_element::Int
    n_dof_per_phase::Int
    lgl_basis::LobattoLegendre{Float64}
    D::Array{Float64, 2}
    M_inv::Array{Float64, 2}
end

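# First index of `component` in `element` for the given phase.
# The liquid block is stored component-major, the particle and solid blocks element-major.
function element_start_index(phase::Symbol, element::Int, component::Int, disc::AxialFlowModel)
    if phase == :liquid
        return (component - 1) * disc.n_elements * disc.n_points + (element - 1) * disc.n_points + 1
    elseif phase == :particle
        ip = 2
    elseif phase == :solid
        ip = 3
    else
        # Without this branch an unknown phase would fail later with an UndefVarError on `ip`
        throw(ArgumentError("unknown phase: $phase"))
    end
    return (ip - 1) * disc.n_dof_per_phase + (element - 1) * disc.n_dof_per_element + (component - 1) * disc.n_points + 1
end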

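# Last index of `component` in `element` for the given phase (same layout as element_start_index).
function element_end_index(phase::Symbol, element::Int, component::Int, disc::AxialFlowModel)
    if phase == :liquid
        return (component - 1) * disc.n_elements * disc.n_points + element * disc.n_points
    elseif phase == :particle
        ip = 2
    elseif phase == :solid
        ip = 3
    else
        # Without this branch an unknown phase would fail later with an UndefVarError on `ip`
        throw(ArgumentError("unknown phase: $phase"))
    end
    return (ip - 1) * disc.n_dof_per_phase + (element - 1) * disc.n_dof_per_element + component * disc.n_points
end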

# Zero-allocation RHS
function rhs_jac_noalloc!(dc, c, bind_params, jac_wo_isotherm, flow_rate, t,
        model::AxialFlowModel{BindModel}, inlet_buf, iso_buf,
        load_len, elute_len, load_conc, salt_eq, salt_start) where {BindModel}
    if t < load_len
        for i in 1:length(load_conc)
            inlet_buf[i] = load_conc[i]
        end
        inlet_buf[end] = salt_eq
    elseif t < load_len + elute_len
        for i in 1:length(load_conc)
            inlet_buf[i] = 0.0
        end
        inlet_buf[end] = salt_start
    else
        fill!(inlet_buf, 0.0)
    end

    Δz = model.length / model.n_elements
    two_over_Δz = 2.0 / Δz
    velocity = flow_rate / (model.cross_section_area * model.col_porosity)
    n_points = model.n_points
    n_elements = model.n_elements

    mul!(dc, jac_wo_isotherm, c)

    @views for comp in 1:model.n_comp
        idx_start = element_start_index(:liquid, 1, comp, model)
        idx_end = element_end_index(:liquid, 1, comp, model)
        inlet_val = inlet_buf[comp]
        for j in idx_start:idx_end
            dc[j] += two_over_Δz * model.M_inv[j - idx_start + 1, 1] * velocity * inlet_val
        end
    end

    β_p = (1.0 - model.par_porosity) / model.par_porosity
    for i in 1:n_elements
        idx_start_particle = element_start_index(:particle, i, 1, model)
        idx_end_particle = element_end_index(:particle, i, model.n_comp, model)
        idx_start_solid = element_start_index(:solid, i, 1, model)
        idx_end_solid = element_end_index(:solid, i, model.n_comp, model)
        cp_element = reshape(view(c, idx_start_particle:idx_end_particle), n_points, model.n_comp)
        cs_element = reshape(view(c, idx_start_solid:idx_end_solid), n_points, model.n_comp)

        binding_flux!(iso_buf, model.binding_model, cp_element, cs_element, bind_params)

        for comp in 1:model.n_comp
            p_start = element_start_index(:particle, i, comp, model)
            s_start = element_start_index(:solid, i, comp, model)
            if comp <= model.binding_model.nBindingComps
                @inbounds for j in 1:n_points
                    dc[p_start + j - 1] -= β_p * iso_buf[j, comp]
                    dc[s_start + j - 1] = iso_buf[j, comp]
                end
            else
                @inbounds for j in 1:n_points
                    dc[s_start + j - 1] = 0.0
                end
            end
        end
    end
    return nothing
end

# Full Jacobian (used for jac_prototype and ode_jac)
function jac!(J::JType, model::ModulatedBinding, p::pType, row_offset,
        idx_start_liquid, idx_start_solid, cp, cs) where {JType, pType}
    n_points = size(cp, 1)
    n_bind = model.nBindingComps
    n_comp = model.nComp
    ForwardDiff.jacobian!(view(J, row_offset:(row_offset + n_bind * n_points - 1),
            idx_start_liquid:(idx_start_liquid + n_comp * n_points - 1)),
        cp -> reshape(binding_flux(model, reshape(cp, n_points, n_comp), cs, p), :),
        reshape(cp, :))
    ForwardDiff.jacobian!(view(J, row_offset:(row_offset + n_bind * n_points - 1),
            idx_start_solid:(idx_start_solid + n_comp * n_points - 1)),
        cs -> reshape(binding_flux(model, cp, reshape(cs, n_points, n_comp), p), :),
        reshape(cs, :))
    return nothing
end

function jac!(J::JType, c::cType, bind_params::pType, flow_rate, t,
        model::AxialFlowModel{BindModel}, with_isotherm::Bool) where {JType, cType, pType, BindModel}
    velocity = flow_rate / (model.cross_section_area * model.col_porosity)
    n_points = model.n_points
    n_elements = model.n_elements
    Δz = model.length / n_elements
    two_over_Δz = 2.0 / Δz
    β_c = (1.0 - model.col_porosity) / model.col_porosity
    β_p = (1.0 - model.par_porosity) / model.par_porosity
    J .= 0.0
    @views begin
        for comp in 1:model.n_comp
            bulk_pore_factor = β_c * 3.0 / model.par_radius * model.film_diffusion_coeff[comp]
            for i in 1:n_elements
                cur_elem_start = element_start_index(:liquid, i, comp, model)
                cur_elem_end = element_end_index(:liquid, i, comp, model)
                idx_start_cp = element_start_index(:particle, i, comp, model)
                J[cur_elem_start:cur_elem_end, cur_elem_start:cur_elem_end] .= model.D
                J[cur_elem_start:cur_elem_end, cur_elem_start:cur_elem_end] .*= two_over_Δz .* (-velocity)
                for j in 0:(n_points - 1)
                    J[cur_elem_start + j, cur_elem_start + j] -= bulk_pore_factor
                    J[cur_elem_start + j, idx_start_cp + j] = bulk_pore_factor
                end
                J[cur_elem_start:cur_elem_end, cur_elem_start] .+= -velocity * two_over_Δz * model.M_inv[:, 1]
                J[cur_elem_start:cur_elem_end, cur_elem_end] .+= velocity * two_over_Δz * model.M_inv[:, end]
                if i > 1
                    J[cur_elem_start:cur_elem_end, element_end_index(:liquid, i - 1, comp, model)] .+= velocity * two_over_Δz * model.M_inv[:, 1]
                end
                J[cur_elem_start:cur_elem_end, cur_elem_end] .-= velocity * two_over_Δz * model.M_inv[:, end]
            end
        end
        par_factor = 3.0 / (model.par_radius * model.par_porosity)
        for comp in 1:model.n_comp
            for i in 1:n_elements
                idx_start_comp_particle = element_start_index(:particle, i, comp, model)
                idx_start_comp_liquid = element_start_index(:liquid, i, comp, model)
                for j in 0:(n_points - 1)
                    J[idx_start_comp_particle + j, idx_start_comp_particle + j] = -par_factor * model.film_diffusion_coeff[comp]
                    J[idx_start_comp_particle + j, idx_start_comp_liquid + j] = par_factor * model.film_diffusion_coeff[comp]
                end
            end
        end
        if with_isotherm
            for i in 1:n_elements
                # Note: despite their names, idx_start_liquid / idx_end_liquid index the particle phase here
                idx_start_liquid = element_start_index(:particle, i, 1, model)
                idx_end_liquid = element_end_index(:particle, i, model.n_comp, model)
                idx_start_solid = element_start_index(:solid, i, 1, model)
                idx_end_solid = element_end_index(:solid, i, model.n_comp, model)
                c_element = reshape(view(c, idx_start_liquid:idx_end_liquid), n_points, model.n_comp)
                cs_element = reshape(view(c, idx_start_solid:idx_end_solid), n_points, model.n_comp)
                jac!(J, model.binding_model, bind_params, idx_start_solid, idx_start_liquid, idx_start_solid, c_element, cs_element)
                J[idx_start_liquid:idx_end_liquid, :] .-= β_p * J[idx_start_solid:idx_end_solid, :]
            end
        end
    end
    return nothing
end

function init_state!(u0, conc_liquid, conc_solid, model::AxialFlowModel)
    n_elements = model.n_elements
    n_comp = model.n_comp
    for i in 1:n_elements
        for comp in 1:n_comp
            idx_start_liquid = element_start_index(:liquid, i, comp, model)
            idx_end_liquid = element_end_index(:liquid, i, comp, model)
            idx_start_particle = element_start_index(:particle, i, comp, model)
            idx_end_particle = element_end_index(:particle, i, comp, model)
            idx_start_solid = element_start_index(:solid, i, comp, model)
            idx_end_solid = element_end_index(:solid, i, comp, model)
            u0[idx_start_liquid:idx_end_liquid] .= conc_liquid[comp]
            u0[idx_start_particle:idx_end_particle] .= conc_liquid[comp]
            u0[idx_start_solid:idx_end_solid] .= conc_solid[comp]
        end
    end
    return nothing
end

function setup_problem()
    n_elements = 3
    n_degree = 1
    bind_params = ComponentArray(
        ka = [4.0, 5.5, 3.0] .* 1e-2,
        kd = [3.2, 22.0, 9.3] .* 1e-3,
        qmax = [3.0, 2.0, 6.5] .* 10.0,
        gamma = [-1.0, -0.5, 0.2],
        beta = [0.91, 0.82, 1.3],
        salt_ref = 1.0
    )
    lgl_basis = LobattoLegendre(n_degree)
    M_inv = legendre_inv_mass(lgl_basis)
    model = AxialFlowModel(
        4, 1.0 / 0.37, 0.014, fill(6.9e-6, 4), 0.37, 0.75, 45e-6,
        ModulatedBinding(4, 3),
        n_elements, n_degree + 1, (n_degree + 1) * 4, n_elements * (n_degree + 1) * 4, lgl_basis, lgl_basis.D, M_inv
    )
    save_idxs = [element_end_index(:liquid, n_elements, i, model) for i in 1:model.n_comp]
    salt_eq = 1.0
    salt_start = 1.5
    load_len = 10.0
    elute_len = 1410.0
    load_conc = [1.0, 1.0, 1.0]
    t_stop = [load_len, load_len + elute_len]

    tspan = (0.0, load_len + elute_len)
    num_dofs = model.n_points * model.n_elements * model.n_comp * 3
    u0 = zeros(Float64, num_dofs)
    init_state!(u0, [0.0, 0.0, 0.0, salt_eq], [0.0, 0.0, 0.0, salt_eq], model)
    flow_rate = 3.45 / 60 / 100

    # Pre-allocate buffers for the RHS
    inlet_buf = zeros(Float64, model.n_comp)
    iso_buf = zeros(Float64, model.n_points, model.binding_model.nBindingComps)

    # Constant Jacobian (without isotherm) and its transpose
    jac_wo_iso = spzeros(length(u0), length(u0))
    jac!(jac_wo_iso, rand(Float64, length(u0)), bind_params, flow_rate, 0.0, model, false)
    jac_wo_iso_T = sparse(jac_wo_iso')

    # Full Jacobian prototype (with isotherm, for sparsity pattern)
    jac0 = spzeros(length(u0), length(u0))
    jac!(jac0, rand(Float64, length(u0)), bind_params, flow_rate, 0.0, model, true)

    # ode_jac still uses the full sparse Jacobian (needed by FBDF Newton solver)
    ode_jac = let model = model
        function (J, u, p, t)
            jac!(J, u, p, flow_rate, t, model, true)
            return nothing
        end
    end

    # -----------------------------------------------------------------
    # Optimized VJP: dense blocks, no sparse setindex!
    #
    # J = jac_wo_iso + C_iso
    # C_iso per element:
    #   solid binding rows:    J_cp (6×8), J_cs (6×8)
    #   particle binding rows: -β_p * J_cp, -β_p * J_cs
    #
    # VJP = J^T * v = jac_wo_iso^T * v + C_iso^T * v
    # C_iso^T * v at particle_cols = J_cp^T * w
    # C_iso^T * v at solid_cols   = J_cs^T * w
    # where w = v[solid_binding] - β_p * v[particle_binding]
    # -----------------------------------------------------------------
    n_pts = model.n_points
    n_cmp = model.n_comp
    n_bnd = model.binding_model.nBindingComps
    n_bnd_pts = n_bnd * n_pts
    n_cmp_pts = n_cmp * n_pts

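    # Pre-allocate dense Jacobian buffers (sizes in comments are for n_degree = 1)
    J_cp_buf = zeros(n_bnd_pts, n_cmp_pts)   # n_bnd_pts × n_cmp_pts (6×8)
    J_cs_buf = zeros(n_bnd_pts, n_cmp_pts)   # n_bnd_pts × n_cmp_pts (6×8)
    w_buf = zeros(n_bnd_pts)                   # n_bnd_pts (6)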

    β_p = (1.0 - model.par_porosity) / model.par_porosity

    vjp_fn = let jac_wo_iso_T = jac_wo_iso_T, J_cp_buf = J_cp_buf, J_cs_buf = J_cs_buf,
                 w_buf = w_buf, model = model, β_p = β_p,
                 n_pts = n_pts, n_cmp = n_cmp, n_bnd_pts = n_bnd_pts
        function (Jv, v, u, p, t)
            # Constant part: jac_wo_iso^T * v
            mul!(Jv, jac_wo_iso_T, v)

            # Isotherm correction per element (dense blocks)
            for i in 1:model.n_elements
                cp_s = element_start_index(:particle, i, 1, model)
                cp_e = element_end_index(:particle, i, n_cmp, model)
                cs_s = element_start_index(:solid, i, 1, model)
                cs_e = element_end_index(:solid, i, n_cmp, model)

                # w = v[solid_binding] - β_p * v[particle_binding]
                @inbounds for k in 1:n_bnd_pts
                    w_buf[k] = v[cs_s + k - 1] - β_p * v[cp_s + k - 1]
                end

                cp_mat = reshape(view(u, cp_s:cp_e), n_pts, n_cmp)
                cs_mat = reshape(view(u, cs_s:cs_e), n_pts, n_cmp)

                # Dense J_cp: ∂binding_flux/∂cp (n_bnd_pts × n_cmp_pts)
                ForwardDiff.jacobian!(J_cp_buf,
                    x -> reshape(binding_flux(model.binding_model, reshape(x, n_pts, n_cmp), cs_mat, p), :),
                    vec(cp_mat))

                # Dense J_cs: ∂binding_flux/∂cs (n_bnd_pts × n_cmp_pts)
                ForwardDiff.jacobian!(J_cs_buf,
                    x -> reshape(binding_flux(model.binding_model, cp_mat, reshape(x, n_pts, n_cmp), p), :),
                    vec(cs_mat))

                # Accumulate: Jv[cp_range] += J_cp^T * w, Jv[cs_range] += J_cs^T * w
                mul!(view(Jv, cp_s:cp_e), J_cp_buf', w_buf, 1.0, 1.0)
                mul!(view(Jv, cs_s:cs_e), J_cs_buf', w_buf, 1.0, 1.0)
            end
            return nothing
        end
    end

    # -----------------------------------------------------------------
    # vjp_p: parameter VJP
    # (∂f/∂p)^T * v = Σ_elements J_p^T * w
    # where J_p = ∂binding_flux/∂p (n_bnd_pts × n_params)
    # and w = v[solid_binding] - β_p * v[particle_binding] (same as state VJP)
    # -----------------------------------------------------------------
    n_params = length(bind_params)
    J_p_buf = zeros(n_bnd_pts, n_params)  # 6×16
    w_p_buf = zeros(n_bnd_pts)             # separate buffer from vjp's w_buf

    vjp_p_fn = let J_p_buf = J_p_buf, w_p_buf = w_p_buf, model = model, β_p = β_p,
                   n_pts = n_pts, n_cmp = n_cmp, n_bnd_pts = n_bnd_pts
        function (out, λ, y, p, t)
            fill!(out, 0.0)
            for i in 1:model.n_elements
                cp_s = element_start_index(:particle, i, 1, model)
                cp_e = element_end_index(:particle, i, n_cmp, model)
                cs_s = element_start_index(:solid, i, 1, model)
                cs_e = element_end_index(:solid, i, n_cmp, model)

                @inbounds for k in 1:n_bnd_pts
                    w_p_buf[k] = λ[cs_s + k - 1] - β_p * λ[cp_s + k - 1]
                end

                cp_mat = reshape(view(y, cp_s:cp_e), n_pts, n_cmp)
                cs_mat = reshape(view(y, cs_s:cs_e), n_pts, n_cmp)

                # J_p: ∂binding_flux/∂p (n_bnd_pts × n_params)
                ForwardDiff.jacobian!(J_p_buf,
                    pp -> reshape(binding_flux(model.binding_model, cp_mat, cs_mat, pp), :),
                    p)

                # out += J_p^T * w
                mul!(out, J_p_buf', w_p_buf, 1.0, 1.0)
            end
            return out
        end
    end

    # JVP uses the same decomposition: J * v = jac_wo_iso * v + C_iso * v
    # C_iso * v at solid_binding = J_cp * v[cp_range] + J_cs * v[cs_range]
    # C_iso * v at particle_binding = -β_p * (J_cp * v[cp_range] + J_cs * v[cs_range])
    J_cp_jvp = zeros(n_bnd_pts, n_cmp_pts)
    J_cs_jvp = zeros(n_bnd_pts, n_cmp_pts)
    iso_effect_buf = zeros(n_bnd_pts)

    jvp_fn = let jac_wo_iso = jac_wo_iso, J_cp_jvp = J_cp_jvp, J_cs_jvp = J_cs_jvp,
                 iso_effect_buf = iso_effect_buf, model = model, β_p = β_p,
                 n_pts = n_pts, n_cmp = n_cmp, n_bnd_pts = n_bnd_pts
        function (Jv, v, u, p, t)
            mul!(Jv, jac_wo_iso, v)

            for i in 1:model.n_elements
                cp_s = element_start_index(:particle, i, 1, model)
                cp_e = element_end_index(:particle, i, n_cmp, model)
                cs_s = element_start_index(:solid, i, 1, model)
                cs_e = element_end_index(:solid, i, n_cmp, model)

                cp_mat = reshape(view(u, cp_s:cp_e), n_pts, n_cmp)
                cs_mat = reshape(view(u, cs_s:cs_e), n_pts, n_cmp)

                ForwardDiff.jacobian!(J_cp_jvp,
                    x -> reshape(binding_flux(model.binding_model, reshape(x, n_pts, n_cmp), cs_mat, p), :),
                    vec(cp_mat))
                ForwardDiff.jacobian!(J_cs_jvp,
                    x -> reshape(binding_flux(model.binding_model, cp_mat, reshape(x, n_pts, n_cmp), p), :),
                    vec(cs_mat))

                # iso_effect = J_cp * v[cp] + J_cs * v[cs]
                mul!(iso_effect_buf, J_cp_jvp, view(v, cp_s:cp_e))
                mul!(iso_effect_buf, J_cs_jvp, view(v, cs_s:cs_e), 1.0, 1.0)

                # solid_binding += iso_effect, particle_binding -= β_p * iso_effect
                @inbounds for k in 1:n_bnd_pts
                    Jv[cs_s + k - 1] += iso_effect_buf[k]
                    Jv[cp_s + k - 1] -= β_p * iso_effect_buf[k]
                end
            end
            return nothing
        end
    end

    fun = ODE.ODEFunction(
        (du, u, pp, t) -> rhs_jac_noalloc!(du, u, pp, jac_wo_iso, flow_rate, t, model,
            inlet_buf, iso_buf, load_len, elute_len, load_conc, salt_eq, salt_start);
        jac_prototype = jac0, vjp = vjp_fn, jvp = jvp_fn, jac = ode_jac,
        vjp_p = vjp_p_fn)
    prob = ODE.ODEProblem(fun, u0, tspan)
    return prob, bind_params, save_idxs, t_stop
end

prob, bind_params, save_idxs, t_stop = setup_problem()

# =====================================================================
# Benchmark
# =====================================================================
println("=" ^ 70)
println("Test: FBDF + GaussAdjoint + optimized dense-block VJP + vjp_p")
println("=" ^ 70)

alg = ODE.FBDF(autodiff = false, linsolve = UMFPACKFactorization())
sense_alg = SMS.GaussAdjoint(autojacvec = SMS.EnzymeVJP())

loss_fn = let prob = prob, save_idxs = save_idxs, sense_alg = sense_alg, t_stop = t_stop, alg = alg
    function (p)
        new_prob = ODE.remake(prob, p = p)
        sol = ODE.solve(new_prob, alg,
            abstol = 1e-8, reltol = 1e-6, tstops = t_stop, saveat = 1.0, maxiters = 1e6, sensealg = sense_alg)
        L = 0.0
        for i in eachindex(sol.t)
            @views L += sum(sol.u[i][save_idxs[1:(end - 1)]])
        end
        return L
    end
end

println("\nForward solve:")
@time loss_val = loss_fn(bind_params)
@time loss_val = loss_fn(bind_params)
println("  Loss value: $loss_val")

println("\nGradient — 1st call (compile):")
@time grad1 = Zygote.gradient(loss_fn, bind_params)
println("\nGradient — 2nd call:")
@time grad2 = Zygote.gradient(loss_fn, bind_params)
println("\nGradient — 3rd call:")
@time grad3 = Zygote.gradient(loss_fn, bind_params)
println("  Gradient: $(grad3[1])")

Maybe Enzyme can get there in the future, but for now this sets a nice speed-of-light baseline that is much faster than what you had before.
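
One way to sanity-check a hand-written VJP like the one above is to compare it against an explicitly assembled Jacobian on a toy problem. This is a minimal sketch with made-up sizes and values: `A`, `J_cp`, `J_cs`, `β`, the index ranges, and the helpers `full_jacobian` and `block_vjp!` are all hypothetical and are not taken from the model in this thread.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical toy layout: one element with 2 particle DOFs followed by 2 solid DOFs.
p_idx, s_idx = 1:2, 3:4
β = 0.5

# Constant part of the Jacobian (plays the role of jac_wo_iso).
A = sparse([1.0 0.2 0.0 0.0;
            0.1 1.0 0.0 0.0;
            0.3 0.0 0.0 0.0;
            0.0 0.3 0.0 0.0])

# Small dense blocks (play the role of ∂flux/∂cp and ∂flux/∂cs).
J_cp = [2.0 1.0; 0.0 3.0]
J_cs = [-1.0 0.5; 0.25 -2.0]

# Reference: assemble the full Jacobian. Solid rows receive the blocks,
# particle rows receive -β times the same blocks.
function full_jacobian(A, J_cp, J_cs, β, p_idx, s_idx)
    J = Matrix(A)
    J[s_idx, p_idx] .+= J_cp
    J[s_idx, s_idx] .+= J_cs
    J[p_idx, p_idx] .-= β .* J_cp
    J[p_idx, s_idx] .-= β .* J_cs
    return J
end

# Dense-block VJP: constant transpose product plus two small block products.
function block_vjp!(out, At, J_cp, J_cs, β, p_idx, s_idx, v)
    mul!(out, At, v)                              # A' * v
    w = v[s_idx] .- β .* v[p_idx]                 # folded cotangent
    mul!(view(out, p_idx), J_cp', w, 1.0, 1.0)    # out[p] += J_cp' * w
    mul!(view(out, s_idx), J_cs', w, 1.0, 1.0)    # out[s] += J_cs' * w
    return out
end

v = [1.0, -2.0, 0.5, 4.0]
ref = full_jacobian(A, J_cp, J_cs, β, p_idx, s_idx)' * v
out = block_vjp!(similar(v), sparse(A'), J_cp, J_cs, β, p_idx, s_idx, v)
println("block VJP matches full J' * v: ", out ≈ ref)
```

The same comparison could in principle be run against the real `vjp_fn` and `ode_jac` at a random state before trusting a long adjoint solve, especially after changing `n_elements` or `n_degree`.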