SciMLSensitivity.jl: adjoint_sensitivities and CVODE_BDF crash for a medium-sized stiff ODE model

We recently ran benchmarks to test the performance of DifferentialEquations.jl for parameter estimation of ODE models in systems biology. Overall, DifferentialEquations.jl has performed great, but we have run into problems when computing the gradient for stiff models via adjoint sensitivity analysis with a discrete cost function.

Below is an MVE for the Bachmann model (a stiff 25-state ODE model commonly used for benchmarks). For most parameter vectors adjoint_sensitivities computes the gradient correctly, but for a subset of parameter vectors some entries in dgdu_discrete become very large (below, specifically, dgdu_discrete[22] = -9e10), which I think is why we get the following error message from CVODE_BDF:
Internal t = 120 and h = -7.58437e-17 are such that t + h = t on the next step. The solver will continue anyway.
Following this error message the code crashes (top of the stacktrace below).
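
The step-size message can be reproduced in isolation. The sketch below uses only the two numbers from the message (t = 120 and h = -7.58437e-17) and Julia's built-in eps. It illustrates the floating-point effect, not what CVODE does internally.

```julia
# Values copied from the CVODE_BDF message above
t = 120.0
h = -7.58437e-17

# Distance from t to the next representable Float64
spacing = eps(t)

# |h| is far below half the spacing, so adding h leaves t unchanged
step_is_lost = (t + h == t)

println("eps(t) = ", spacing)
println("t + h == t ? ", step_is_lost)
```

A step this small is absorbed by rounding, so in floating point the integrator cannot move away from t = 120.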

We have tested several ODE solvers for the adjoint system. QNDF and FBDF run into the dtmin problem, and TRBDF2 and the Rosenbrock solvers, albeit successful, are too slow. We have also tested the AMICI interface to Sundials (GitHub - AMICI-dev/AMICI: Advanced Multilanguage Interface to CVODES and IDAS), and it manages to compute the gradient via adjoint sensitivity analysis (albeit 67k integration steps are required when solving the adjoint ODE). All in all, I wonder whether there is some option I have set incorrectly, or whether there is anything I can do to obtain the gradient for challenging cases like this?

MVE (running on Julia 1.8.5 on Ubuntu)

using ModelingToolkit # version = "8.46.0"
using OrdinaryDiffEq # version = "6.41.0"
using Sundials # version = "4.13.0"
using SciMLSensitivity # version = "7.19.0"

# Model name: model_Bachmann_MSB2011
# Number of parameters: 37
# Number of species: 25
function get_Bachmann_MSB2011()

    
    ModelingToolkit.@variables t p1EpoRpJAK2(t) pSTAT5(t) EpoRJAK2_CIS(t) SOCS3nRNA4(t) SOCS3RNA(t) SHP1(t) STAT5(t) EpoRJAK2(t) CISnRNA1(t) SOCS3nRNA1(t) SOCS3nRNA2(t) CISnRNA3(t) CISnRNA4(t) SOCS3(t) CISnRNA5(t) SOCS3nRNA5(t) SOCS3nRNA3(t) SHP1Act(t) npSTAT5(t) p12EpoRpJAK2(t) p2EpoRpJAK2(t) CIS(t) EpoRpJAK2(t) CISnRNA2(t) CISRNA(t)

    stateArray = [p1EpoRpJAK2, pSTAT5, EpoRJAK2_CIS, SOCS3nRNA4, SOCS3RNA, SHP1, STAT5, EpoRJAK2, CISnRNA1, SOCS3nRNA1, SOCS3nRNA2, CISnRNA3, CISnRNA4, SOCS3, CISnRNA5, SOCS3nRNA5, SOCS3nRNA3, SHP1Act, npSTAT5, p12EpoRpJAK2, p2EpoRpJAK2, CIS, EpoRpJAK2, CISnRNA2, CISRNA]


    ### Define parameters
    ModelingToolkit.@parameters STAT5Exp STAT5Imp init_SOCS3_multiplier EpoRCISRemove STAT5ActEpoR SHP1ActEpoR JAK2EpoRDeaSHP1 CISTurn SOCS3Turn init_EpoRJAK2_CIS SOCS3Inh ActD init_CIS_multiplier cyt CISRNAEqc JAK2ActEpo Epo SOCS3oe CISInh SHP1Dea SOCS3EqcOE CISRNADelay init_SHP1 CISEqcOE EpoRActJAK2 SOCS3RNAEqc CISEqc SHP1ProOE SOCS3RNADelay init_STAT5 CISoe CISRNATurn init_SHP1_multiplier init_EpoRJAK2 nuc EpoRCISInh STAT5ActJAK2 SOCS3RNATurn SOCS3Eqc

    ### Store parameters in array for ODESystem command
    parameterArray = [STAT5Exp, STAT5Imp, init_SOCS3_multiplier, EpoRCISRemove, STAT5ActEpoR, SHP1ActEpoR, JAK2EpoRDeaSHP1, CISTurn, SOCS3Turn, init_EpoRJAK2_CIS, SOCS3Inh, ActD, init_CIS_multiplier, cyt, CISRNAEqc, JAK2ActEpo, Epo, SOCS3oe, CISInh, SHP1Dea, SOCS3EqcOE, CISRNADelay, init_SHP1, CISEqcOE, EpoRActJAK2, SOCS3RNAEqc, CISEqc, SHP1ProOE, SOCS3RNADelay, init_STAT5, CISoe, CISRNATurn, init_SHP1_multiplier, init_EpoRJAK2, nuc, EpoRCISInh, STAT5ActJAK2, SOCS3RNATurn, SOCS3Eqc]

    ### Define an operator for the differentiation w.r.t. time
    D = Differential(t)

    ### Derivatives ###
    eqs = [
    D(p1EpoRpJAK2) ~ +1.0 * ( 1 /cyt ) * (cyt * (EpoRpJAK2 * EpoRActJAK2 / (SOCS3 * SOCS3Inh / SOCS3Eqc + 1)))-1.0 * ( 1 /cyt ) * (cyt * (3 * EpoRActJAK2 * p1EpoRpJAK2 / ((SOCS3 * SOCS3Inh / SOCS3Eqc + 1) * (EpoRCISInh * EpoRJAK2_CIS + 1))))-1.0 * ( 1 /cyt ) * (cyt * (JAK2EpoRDeaSHP1 * SHP1Act * p1EpoRpJAK2 / init_SHP1)),
    D(pSTAT5) ~ +1.0 * ( 1 /cyt ) * (cyt * (STAT5 * STAT5ActJAK2 * (EpoRpJAK2 + p12EpoRpJAK2 + p1EpoRpJAK2 + p2EpoRpJAK2) / (init_EpoRJAK2 * (SOCS3 * SOCS3Inh / SOCS3Eqc + 1))))+1.0 * ( 1 /cyt ) * (cyt * (STAT5 * STAT5ActEpoR * (p12EpoRpJAK2 + p1EpoRpJAK2)^(2) / ((init_EpoRJAK2)^(2) * (SOCS3 * SOCS3Inh / SOCS3Eqc + 1) * (CIS * CISInh / CISEqc + 1))))-1.0 * ( 1 /cyt ) * (cyt * STAT5Imp * pSTAT5),
    D(EpoRJAK2_CIS) ~ -1.0 * ( 1 /cyt ) * (cyt * (EpoRJAK2_CIS * EpoRCISRemove * (p12EpoRpJAK2 + p1EpoRpJAK2) / init_EpoRJAK2)),
    D(SOCS3nRNA4) ~ +1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA3 * SOCS3RNADelay)-1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA4 * SOCS3RNADelay),
    D(SOCS3RNA) ~ +1.0 * ( 1 /cyt ) * (nuc * SOCS3nRNA5 * SOCS3RNADelay)-1.0 * ( 1 /cyt ) * (cyt * SOCS3RNA * SOCS3RNATurn),
    D(SHP1) ~ -1.0 * ( 1 /cyt ) * (cyt * (SHP1 * SHP1ActEpoR * (EpoRpJAK2 + p12EpoRpJAK2 + p1EpoRpJAK2 + p2EpoRpJAK2) / init_EpoRJAK2))+1.0 * ( 1 /cyt ) * (cyt * SHP1Dea * SHP1Act),
    D(STAT5) ~ -1.0 * ( 1 /cyt ) * (cyt * (STAT5 * STAT5ActJAK2 * (EpoRpJAK2 + p12EpoRpJAK2 + p1EpoRpJAK2 + p2EpoRpJAK2) / (init_EpoRJAK2 * (SOCS3 * SOCS3Inh / SOCS3Eqc + 1))))-1.0 * ( 1 /cyt ) * (cyt * (STAT5 * STAT5ActEpoR * (p12EpoRpJAK2 + p1EpoRpJAK2)^(2) / ((init_EpoRJAK2)^(2) * (SOCS3 * SOCS3Inh / SOCS3Eqc + 1) * (CIS * CISInh / CISEqc + 1))))+1.0 * ( 1 /cyt ) * (nuc * STAT5Exp * npSTAT5),
    D(EpoRJAK2) ~ -1.0 * ( 1 /cyt ) * (cyt * (Epo * EpoRJAK2 * JAK2ActEpo / (SOCS3 * SOCS3Inh / SOCS3Eqc + 1)))+1.0 * ( 1 /cyt ) * (cyt * (EpoRpJAK2 * JAK2EpoRDeaSHP1 * SHP1Act / init_SHP1))+1.0 * ( 1 /cyt ) * (cyt * (JAK2EpoRDeaSHP1 * SHP1Act * p1EpoRpJAK2 / init_SHP1))+1.0 * ( 1 /cyt ) * (cyt * (JAK2EpoRDeaSHP1 * SHP1Act * p2EpoRpJAK2 / init_SHP1))+1.0 * ( 1 /cyt ) * (cyt * (JAK2EpoRDeaSHP1 * SHP1Act * p12EpoRpJAK2 / init_SHP1)),
    D(CISnRNA1) ~ +1.0 * ( 1 /nuc ) * (nuc * (CISRNAEqc * CISRNATurn * npSTAT5 * ActD / init_STAT5))-1.0 * ( 1 /nuc ) * (nuc * CISnRNA1 * CISRNADelay),
    D(SOCS3nRNA1) ~ +1.0 * ( 1 /nuc ) * (nuc * (SOCS3RNAEqc * SOCS3RNATurn * npSTAT5 * ActD / init_STAT5))-1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA1 * SOCS3RNADelay),
    D(SOCS3nRNA2) ~ +1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA1 * SOCS3RNADelay)-1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA2 * SOCS3RNADelay),
    D(CISnRNA3) ~ +1.0 * ( 1 /nuc ) * (nuc * CISnRNA2 * CISRNADelay)-1.0 * ( 1 /nuc ) * (nuc * CISnRNA3 * CISRNADelay),
    D(CISnRNA4) ~ +1.0 * ( 1 /nuc ) * (nuc * CISnRNA3 * CISRNADelay)-1.0 * ( 1 /nuc ) * (nuc * CISnRNA4 * CISRNADelay),
    D(SOCS3) ~ +1.0 * ( 1 /cyt ) * (cyt * (SOCS3RNA * SOCS3Eqc * SOCS3Turn / SOCS3RNAEqc))-1.0 * ( 1 /cyt ) * (cyt * SOCS3 * SOCS3Turn)+1.0 * ( 1 /cyt ) * (cyt * SOCS3oe * SOCS3Eqc * SOCS3Turn * SOCS3EqcOE),
    D(CISnRNA5) ~ +1.0 * ( 1 /nuc ) * (nuc * CISnRNA4 * CISRNADelay)-1.0 * ( 1 /nuc ) * (nuc * CISnRNA5 * CISRNADelay),
    D(SOCS3nRNA5) ~ +1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA4 * SOCS3RNADelay)-1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA5 * SOCS3RNADelay),
    D(SOCS3nRNA3) ~ +1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA2 * SOCS3RNADelay)-1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA3 * SOCS3RNADelay),
    D(SHP1Act) ~ +1.0 * ( 1 /cyt ) * (cyt * (SHP1 * SHP1ActEpoR * (EpoRpJAK2 + p12EpoRpJAK2 + p1EpoRpJAK2 + p2EpoRpJAK2) / init_EpoRJAK2))-1.0 * ( 1 /cyt ) * (cyt * SHP1Dea * SHP1Act),
    D(npSTAT5) ~ +1.0 * ( 1 /nuc ) * (cyt * STAT5Imp * pSTAT5)-1.0 * ( 1 /nuc ) * (nuc * STAT5Exp * npSTAT5),
    D(p12EpoRpJAK2) ~ +1.0 * ( 1 /cyt ) * (cyt * (3 * EpoRActJAK2 * p1EpoRpJAK2 / ((SOCS3 * SOCS3Inh / SOCS3Eqc + 1) * (EpoRCISInh * EpoRJAK2_CIS + 1))))+1.0 * ( 1 /cyt ) * (cyt * (EpoRActJAK2 * p2EpoRpJAK2 / (SOCS3 * SOCS3Inh / SOCS3Eqc + 1)))-1.0 * ( 1 /cyt ) * (cyt * (JAK2EpoRDeaSHP1 * SHP1Act * p12EpoRpJAK2 / init_SHP1)),
    D(p2EpoRpJAK2) ~ +1.0 * ( 1 /cyt ) * (cyt * (3 * EpoRpJAK2 * EpoRActJAK2 / ((SOCS3 * SOCS3Inh / SOCS3Eqc + 1) * (EpoRCISInh * EpoRJAK2_CIS + 1))))-1.0 * ( 1 /cyt ) * (cyt * (EpoRActJAK2 * p2EpoRpJAK2 / (SOCS3 * SOCS3Inh / SOCS3Eqc + 1)))-1.0 * ( 1 /cyt ) * (cyt * (JAK2EpoRDeaSHP1 * SHP1Act * p2EpoRpJAK2 / init_SHP1)),
    D(CIS) ~ +1.0 * ( 1 /cyt ) * (cyt * (CISRNA * CISEqc * CISTurn / CISRNAEqc))-1.0 * ( 1 /cyt ) * (cyt * CIS * CISTurn)+1.0 * ( 1 /cyt ) * (cyt * CISEqc * CISTurn * CISEqcOE * CISoe),
    D(EpoRpJAK2) ~ +1.0 * ( 1 /cyt ) * (cyt * (Epo * EpoRJAK2 * JAK2ActEpo / (SOCS3 * SOCS3Inh / SOCS3Eqc + 1)))-1.0 * ( 1 /cyt ) * (cyt * (EpoRpJAK2 * JAK2EpoRDeaSHP1 * SHP1Act / init_SHP1))-1.0 * ( 1 /cyt ) * (cyt * (EpoRpJAK2 * EpoRActJAK2 / (SOCS3 * SOCS3Inh / SOCS3Eqc + 1)))-1.0 * ( 1 /cyt ) * (cyt * (3 * EpoRpJAK2 * EpoRActJAK2 / ((SOCS3 * SOCS3Inh / SOCS3Eqc + 1) * (EpoRCISInh * EpoRJAK2_CIS + 1)))),
    D(CISnRNA2) ~ +1.0 * ( 1 /nuc ) * (nuc * CISnRNA1 * CISRNADelay)-1.0 * ( 1 /nuc ) * (nuc * CISnRNA2 * CISRNADelay),
    D(CISRNA) ~ +1.0 * ( 1 /cyt ) * (nuc * CISnRNA5 * CISRNADelay)-1.0 * ( 1 /cyt ) * (cyt * CISRNA * CISRNATurn)
    ]

    @named sys = ODESystem(eqs, t, stateArray, parameterArray)

    ### Initial species concentrations ###
    initialSpeciesValues = [
    p1EpoRpJAK2 => 0.0,
    pSTAT5 => 0.0,
    EpoRJAK2_CIS => init_EpoRJAK2_CIS,
    SOCS3nRNA4 => 0.0,
    SOCS3RNA => 0.0,
    SHP1 => init_SHP1 * (init_SHP1_multiplier * SHP1ProOE + 1),
    STAT5 => init_STAT5,
    EpoRJAK2 => init_EpoRJAK2,
    CISnRNA1 => 0.0,
    SOCS3nRNA1 => 0.0,
    SOCS3nRNA2 => 0.0,
    CISnRNA3 => 0.0,
    CISnRNA4 => 0.0,
    SOCS3 => init_SOCS3_multiplier * SOCS3EqcOE * SOCS3Eqc,
    CISnRNA5 => 0.0,
    SOCS3nRNA5 => 0.0,
    SOCS3nRNA3 => 0.0,
    SHP1Act => 0.0,
    npSTAT5 => 0.0,
    p12EpoRpJAK2 => 0.0,
    p2EpoRpJAK2 => 0.0,
    CIS => init_CIS_multiplier * CISEqc * CISEqcOE,
    EpoRpJAK2 => 0.0,
    CISnRNA2 => 0.0,
    CISRNA => 0.0
    ]

    ### SBML file parameter values ###
    trueParameterValues = [
    STAT5Exp => 0.0745150819016423,
    STAT5Imp => 0.0268865083829685,
    init_SOCS3_multiplier => 0.0,
    EpoRCISRemove => 5.42980693903448,
    STAT5ActEpoR => 38.9957991073948,
    SHP1ActEpoR => 0.00100000000000006,
    JAK2EpoRDeaSHP1 => 142.72332309738,
    CISTurn => 0.0083988695167017,
    SOCS3Turn => 9999.99999999912,
    init_EpoRJAK2_CIS => 0.0,
    SOCS3Inh => 10.4078649133666,
    ActD => 1.25e-7,
    init_CIS_multiplier => 0.0,
    cyt => 0.4,
    CISRNAEqc => 1.0,
    JAK2ActEpo => 633167.430600806,
    Epo => 1.25e-7,
    SOCS3oe => 1.25e-7,
    CISInh => 7.85269991450496e8,
    SHP1Dea => 0.00816220490950374,
    SOCS3EqcOE => 0.679165515556864,
    CISRNADelay => 0.14477775532111,
    init_SHP1 => 26.7251164277109,
    CISEqcOE => 0.530264447119609,
    EpoRActJAK2 => 0.267304849333058,
    SOCS3RNAEqc => 1.0,
    CISEqc => 432.860413434913,
    SHP1ProOE => 2.82568153411555,
    SOCS3RNADelay => 1.06458446742251,
    init_STAT5 => 79.75363993771,
    CISoe => 1.25e-7,
    CISRNATurn => 999.999999999946,
    init_SHP1_multiplier => 1.0,
    init_EpoRJAK2 => 3.97622369384192,
    nuc => 0.275,
    EpoRCISInh => 999999.999999912,
    STAT5ActJAK2 => 0.0781068855795467,
    SOCS3RNATurn => 0.00830917643120369,
    SOCS3Eqc => 173.64470023136
    ]

    return sys, initialSpeciesValues, trueParameterValues
end

# Observable function G
function computeG(u, p, t, σ)
    h = u[22]
    dataObserved = [29.061794973113646, 26.097567191289983, 19.65239347179184]
    σ = exp10(-3.0)
    G = 0.0
    for i in eachindex(dataObserved)
        # Residual is scaled by σ before squaring (log10-normal noise model)
        G += log(σ) + 0.5*log(2*pi) + log(log(10)) + log(10)*log10(dataObserved[i]) + 0.5*((log10(h) - log10(dataObserved[i])) / σ)^2
    end
    return G
end


function compute∂G∂u(out, u, p, t, i)
    dataObserved = [29.061794973113646, 26.097567191289983, 19.65239347179184]
    σ = exp10(-3.0)
    h = u[22]
    out .= 0.0
    for i in eachindex(dataObserved)
        ∂h∂u = zeros(length(u))
        ∂h∂u[22] = 1
        ∂h∂u .*= (1 / (log(10) * h))  * (log10(exp10(h)) - log10(dataObserved[i])) / σ^2 
        out .+= ∂h∂u
    end
end

sys, stateMap, parameterMap = get_Bachmann_MSB2011()
odeProblem = ODEProblem{true, SciMLBase.FullSpecialize}(sys, stateMap, [0.0, 130.0], parameterMap, jac=true)

# Parameter vector and initial value vector that crashes 
p = [26.56087782946684, 0.0011497569953977356, 0.0, 0.037649358067924674, 0.11497569953977356, 2.848035868435802, 1.291549665014884, 7.56463327554629, 0.02595024211399736, 0.0, 0.007054802310718645, 1.0, 0.0, 0.4, 1.0, 1.629750834620647e6, 1.25e-7, 0.0, 2.1544346900318843, 0.024770763559917114, 0.001747528400007683, 0.5336699231206307, 0.014174741629268055, 0.3511191734215131, 104.7615752789664, 1.0, 0.0016297508346206436, 1.232846739442066, 1.6297508346206435, 0.0657933224657568, 0.0, 0.005336699231206312, 0.0, 13.219411484660288, 0.275, 351119.17342151277, 0.003511191734215131, 0.6135907273413176, 0.1519911082952933]
u0 = [0.0, 0.0, 0.0, 0.0, 0.0, 0.014174741629268055, 0.0657933224657568, 13.219411484660288, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
odeProblem.p .= p
odeProblem.u0 .= u0

solForward = solve(odeProblem, CVODE_BDF(), abstol=1e-8, reltol=1e-8)
du, dp = adjoint_sensitivities(solForward, CVODE_BDF(), dgdu_discrete=compute∂G∂u, t=[120.0], 
                               sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP()), 
                               abstol=1e-8, reltol=1e-8)

Stacktrace (only the top, as otherwise I hit the character limit for the post)

ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [0]
Stacktrace:
  [1] getindex
    @ ./array.jl:924 [inlined]

If CVODE_BDF is failing but you say Sundials worked in AMICI… both are using the same solver. So are you sure you translated it correctly? What’s your validation that the two are exactly the same: a random u and p test? Can you share the validation script?

Forward mode will be a lot faster for this case. But from the sound of it, it seems like the ODE is just implemented incorrectly right now and so it’s taking 67k steps? That’s pretty absurd for a case like this so I’d just fix that first.

InterpolatingAdjoint(autojacvec=ReverseDiffVJP()) is also just a very slow choice. Just leave it blank and let it choose the sensealg.

Thanks for the reply!

I am confident the model is implemented correctly. Below is a code snippet that compares the cost and gradient (computed via adjoint sensitivity analysis, ForwardDiff and, for completeness, FiniteDifferences.jl) for the full Bachmann model against the cost and gradient computed by PyPesto (which uses AMICI for sensitivity computations, GitHub - ICB-DCM/pyPESTO: python Parameter EStimation TOolbox). Moreover, when we run parameter estimation for the Bachmann model (computing the gradient via ForwardDiff.jl), we arrive at the reported minimum several times from random starting points.

The Bachmann model consists of several experimental conditions (we must simulate the ODE multiple times with different initial values), and we have several observables at different time points. In the posted MVE I have extracted the experimental condition and observable that cause adjoint_sensitivities to crash. I believe the code crashes because a bad parameter vector makes dgdu_discrete very large (specifically dgdu_discrete[22] = -9e10), so that the states of λ have widely varying scales (and this multiscale issue could be why AMICI requires 67k steps to solve for λ).
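
To illustrate how a single entry of dgdu_discrete can become this large, here is a minimal sketch assuming a log10-normal noise model with σ = 1e-3 and the first data point from the MVE. The function names nll and dnll_dh and the trial values of h are hypothetical, chosen for illustration. They are not values taken from the actual simulation.

```julia
# One observation's negative log-likelihood under an assumed log10-normal noise model
nll(h, y, σ) = log(σ) + 0.5 * log(2π) + log(log(10)) + log(10) * log10(y) +
               0.5 * ((log10(h) - log10(y)) / σ)^2

# Analytic derivative with respect to the simulated value h
dnll_dh(h, y, σ) = (log10(h) - log10(y)) / (σ^2 * log(10) * h)

y = 29.061794973113646   # first entry of dataObserved in the MVE
σ = 1e-3                 # same σ as in the MVE

# Hypothetical simulated values: the derivative grows roughly like 1/h as h shrinks
for h in (1.0, 1e-2, 1e-4)
    println("h = ", h, "  dG/dh = ", dnll_dh(h, y, σ))
end

# Central finite-difference check of the analytic derivative at h = 1e-3
h0 = 1e-3
δ = 1e-9
fd = (nll(h0 + δ, y, σ) - nll(h0 - δ, y, σ)) / (2δ)
println("analytic = ", dnll_dh(h0, y, σ), "  finite difference = ", fd)
```

With a small σ, the factor 1/(σ^2 * h) dominates, so a simulated value far below the data yields a very large terminal condition for λ.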

And you are correct, for this problem ForwardDiff.jl is adequate (almost as fast as adjoint with EnzymeVJP). However, for bigger systems biology models we are going to need adjoint sensitivity analysis. As we typically perform parameter estimation by multi-start local gradient-based optimization, failing to compute the gradient is quite detrimental. Since the problem with adjoint_sensitivities manifested for a medium-sized stiff model, I think it will likely manifest for larger stiff models as well (hence it would be great to have an approach to handle cases like these).

Code-snippet. If you want to run the code you must use the PEtab importer (PEtab – a data format for specifying parameter estimation problems in systems biology — PEtab latest documentation) we are currently working on (a master student of mine started the project which is very much under construction GitHub - CleonII/Master-Thesis, the code file can be found at Master-Thesis/Bachmann.jl at main · CleonII/Master-Thesis · GitHub)

using ModelingToolkit 
using DifferentialEquations
using DataFrames
using CSV 
using ForwardDiff
using ReverseDiff
using StatsBase
using Random
using LinearAlgebra
using Distributions
using Printf
using Zygote
using SciMLSensitivity
using Sundials
using YAML
using FiniteDifferences


# Relevant PeTab structs for computations 
include(joinpath(pwd(), "src", "PeTab_structs.jl"))

# PeTab importer to get cost, grad etc 
include(joinpath(pwd(), "src", "Create_PEtab_model.jl"))

# For converting to SBML 
include(joinpath(pwd(), "src", "SBML", "SBML_to_ModellingToolkit.jl"))

include(joinpath(pwd(), "tests", "Common.jl"))


petabModel = readPEtabModel(joinpath(@__DIR__, "Bachmann", "Bachmann_MSB2011.yaml"), forceBuildJuliaFiles=false)

solver, tol = Rodas5P(), 1e-9
petabProblem1 = setUpPEtabODEProblem(petabModel, solver, solverAbsTol=tol, solverRelTol=tol, 
                                     odeSolverAdjoint=solver, solverAdjointAbsTol=tol, solverAdjointRelTol=tol,
                                     sensealgAdjoint=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)))

# File with random parameter values 
paramVals = CSV.read(pwd() * "/tests/Bachmann/Params.csv", DataFrame)
paramMat = paramVals[!, Not([:Id, :SOCS3RNAEqc, :CISRNAEqc])]

# PyPesto hessian, gradient and cost values for the random parameter vectors 
costPython = (CSV.read(pwd() * "/tests/Bachmann/Cost.csv", DataFrame))[!, :Cost]
gradPythonMat = CSV.read(pwd() * "/tests/Bachmann/Grad.csv", DataFrame)
gradPythonMat = gradPythonMat[!, Not([:Id, :SOCS3RNAEqc, :CISRNAEqc])]
 
paramEstNames = string.(petabProblem1.θ_estNames)
# For correct indexing when comparing gradient (we have different ordering compared to PyPesto)
iUse = [findfirst(x -> x == paramEstNames[i], names(paramMat)) for i in eachindex(paramEstNames)]
nParam = ncol(paramMat)

i = 3
p = collect(paramMat[i, iUse])
referenceCost = costPython[i] # Cost from PyPesto
referenceGradient = collect(gradPythonMat[i, iUse]) # Gradient from PyPesto

cost = petabProblem1.computeCost(p)
@printf("Cost python = %.3e and cost Julia = %.3e\n", referenceCost, cost)
@printf("abs(Difference cost) = %.3e\n", abs(referenceCost - cost))

# Gradient via forward-mode automatic differentiation 
gradientForwardDiff = zeros(length(p))        
petabProblem1.computeGradientAutoDiff(gradientForwardDiff, p)
@printf("norm(gradientPyPesto - gradientForwardDiffJulia) = %.3e\n", norm(gradientForwardDiff - referenceGradient))

# Gradient via adjoint sensitivity analysis
gradientAdjoint = zeros(length(p))
petabProblem1.computeGradientAdjoint(gradientAdjoint, p)
@printf("norm(gradientPyPesto - gradientAdjointJulia) = %.3e\n", norm(gradientAdjoint - referenceGradient))
        
# Gradient from FiniteDifferences 
gradientFinite = FiniteDifferences.grad(central_fdm(5, 1), petabProblem1.computeCost, p)[1]
@printf("norm(gradientPyPesto - gradientFiniteDifferences) = %.3e\n", norm(gradientFinite - referenceGradient))

println("Gradient Adjoint = ", gradientAdjoint)
println("Gradient PyPesto = ", referenceGradient)

with output

Cost python = 4.063e+03 and cost Julia = 4.063e+03
abs(Difference cost) = 7.118e-07
norm(gradientPyPesto - gradientForwardDiffJulia) = 1.826e-04
norm(gradientPyPesto - gradientAdjointJulia) = 1.802e-04
norm(gradientPyPesto - gradientFiniteDifferences) = 1.820e-04
Gradient Adjoint = [-0.001742872749390386, 23.182193161231172, 26.667024348020682, -1208.1392036144641, -0.044490741702237976, -1616.968211386588, -18.218890222459788, -0.03150640831308554, 0.49554683059008436, -430.87935281172435, 128.87811202482897, 128.8073506405033, -11.547559160284457, 1.396664087168968, -7.665005991960116e-6, -0.09766694700859709, 21.698060724580667, 7.249454614081832, 21.409178206215117, 0.00015900271607536542, -61.77100236784556, -496.8430766900005, 1454.3254497506066, -1170.6909806871959, 1.0933429656735706e-9, -2.3657656100763587e-8, -1.3301490462467032e-6, -7.731475689693673e-5, -16999.400179423406, 62.458003373494115, 29.36127205996182, 136.60147747827827, -0.6171528226705253, -113.70599997716141, 6.87865522912127, 6.834873393661888, 26.639151770254436, 6.809079702860347, 13.256034121144639, 459.6378273699804, 2.4707529258324135, 349.0033439753882, 0.2648887481005327, 517.2219025036472, 0.22218791774224267, 1494.6079704383305, 2.016931701467492, 332.6247586717753, 28.20760159906679, -1039.135512202305, -1008.0582892017982, -0.15162841512057865, -0.1325395183121428, -0.10368814165606167, 10.786962301857617, 4.811572486407348, 4.985814949821249, 3.5038531160602098, 5.321842235168024, 1.1497618919926667, 6.7905724834665335, 3.377951436958088, 6.3849128420778705, 10.670122539973717, 5.780751124589986, 4.978514464316092, 12.335203706568288, 6.805909227510565, 12.082122059188235, 21.72784743036225, 5.906622190496201, 2.3732623925916605, 12.37643017870708, 1.361191489426026, 6.307397148653578, 0.3920820907506159, 5.752727282226018, 0.4288666665069182, 7.937531190797928, 3.2268779069309423, 7.165822244788378, 1.3904137839748236, 15.103723515996471, 1.8896015103305503, 15.565917741064851, 8.51766380815996, 5.871317873240681, 0.4279304926071382, 9.883427290269266, 1.0924221994470078, 5.3768631679382715, 0.26673505272258063, 5.213322347478952, 0.9382785922669036, -0.07391141977481899, 12.941364500818205, 2.0492294092348193, 4.0222063782901465, 
0.3334912608962849, 14.059119670362621, 1.8366935181366202, 0.020172355736484654, 5.917200479235095, 0.5992255161651585, 2.843102324633841, 0.018986272440006613, -0.10640227141043175, -0.07665400419428463, -0.1289355418322084, 6.467619144390042, 1.533522307109798, 1.6678832592905413, 1.597177102968559]
Gradient PyPesto = [-0.0017427767483265042, 23.182193161238555, 26.6670510799979, -1208.139089020869, -0.04449044543696711, -1616.9681853768648, -18.21896401097584, -0.0315064338732992, 0.49554709785503803, -430.87942964647067, 128.87813973050484, 128.8073783396324, -11.547559623304947, 1.3966640868882763, -7.66798931894116e-6, -0.09766694575564007, 21.69806086169711, 7.249454650667229, 21.409178330797637, 0.0001590027175223748, -61.77106297812481, -496.8430763802603, 1454.3254500738078, -1170.6909444156354, 4.758723841742529e-13, -1.9735372506003256e-8, 2.2457634515700627e-6, -7.731475648743016e-5, -16999.400183093712, 62.4580033732895, 29.361272066276587, 136.60147748181004, -0.6171528226496582, -113.70599303921081, 6.878655229121269, 6.834873393661853, 26.639151770862355, 6.809079702860347, 13.256034121144632, 459.63782735311565, 2.4707529129475856, 349.0033439715811, 0.2648887480410858, 517.2219024959204, 0.22218791680976407, 1494.6079704307463, 2.016931690868201, 332.62475867709026, 28.207601589961683, -1039.1355125397265, -1008.0582895392196, -0.15162841517288983, -0.1325395183526757, -0.10368814169368626, 10.786962299811659, 4.811572490157788, 4.985814954822384, 3.50385310285312, 5.321842234303718, 1.1497618945100356, 6.790572482388598, 3.3779514397942907, 6.384912841989932, 10.670122540525151, 5.780751125679641, 4.978514459098822, 12.335203702555713, 6.8059092363205425, 12.082122058500968, 21.727847431604637, 5.906622194102209, 2.373262384457991, 12.37643017811895, 1.361191490385354, 6.307397148452184, 0.3920820913226305, 5.752727282204861, 0.4288666669155108, 7.937531190796522, 3.2268779069517635, 7.165822245497411, 1.3904137825645682, 15.103723515048582, 1.88960151232024, 15.565917740911573, 8.517663808418005, 5.871317873245986, 0.4279304928879557, 9.883427289968784, 1.0924221978787483, 5.376863167802295, 0.26673505206003834, 5.213322347383734, 0.9382785897492021, -0.0739114174413833, 12.94136450045551, 2.049229409952603, 4.022206378047826, 
0.3334912603386871, 14.059117195375393, 1.8366935179768082, 0.02017235567474185, 5.917200479370551, 0.5992255159287493, 2.843102324546546, 0.01898627242836005, -0.1064022713349269, -0.07665400413296508, -0.12893554173904256, 6.467619144390046, 1.5335223079830143, 1.6678832603482263, 1.5971771041207057]

I’m a bit confused now, what exactly is the issue? It sounds like you now have it working with EnzymeVJP which is what I would expect to work out well (or just use the defaults)? Is it that adjoints (with EnzymeVJP) can be unstable with CVODE_BDF/QNDF/FBDF but this isn’t seen with the Base Sundials?

A difference of about 1.8e-4 is well above the tolerances? Is PyPesto not respecting the tolerance on the gradient calculation? Since they all level off at the same point I would suspect it’s PyPesto that’s only correct to 1e-4 here, otherwise they would all be different values. That would explain a memory difference too since the adjoint is much larger than the forward pass, and so if the adjoint is taking a lot more steps for us then it’s a much bigger memory difference.
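
One way to put the 1.8e-4 figure in context is to compare it with the size of the gradient itself. The sketch below uses only three entries (positions 4, 6 and 29) copied from the printed adjoint and PyPesto gradients above, so the numbers it prints are for that subset, not for the full vector.

```julia
using LinearAlgebra

# Entries 4, 6 and 29 of the two printed gradients
grad_adjoint = [-1208.1392036144641, -1616.968211386588, -16999.400179423406]
grad_pypesto = [-1208.139089020869, -1616.9681853768648, -16999.400183093712]

# Absolute error, as in the printed norm(...) comparisons
abs_err = norm(grad_adjoint - grad_pypesto)

# The same error relative to the magnitude of the reference gradient
rel_err = abs_err / norm(grad_pypesto)

println("absolute error = ", abs_err)
println("relative error = ", rel_err)
```

Whether the discrepancy is significant depends on which of the two measures the solver tolerance is meant to bound.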

Try QuadratureAdjoint. That would boost the memory usage but get around cubic scaling and some numerical issues that could pop up here.

https://aip.scitation.org/doi/10.1063/5.0060697

We should have a paper in the near future that describes a new adjoint technique that completely removes the memory issue too.

Sorry for being unclear. The random parameter vector in my reply is not the same as the parameter vector that causes the crash in the original MVE. The parameter vector in the reply was chosen only to show that the model is implemented correctly, as we obtain similar values to PyPesto. Also, for the case in the reply, PyPesto controls the gradient accuracy via the tolerances of the ODE solver when solving the forward sensitivity equations.

So the issue is that for the parameter vector in the MVE (top of the post) adjoint sensitivity analysis crashes, whereas AMICI manages to compute the gradient (albeit it takes time and many integration steps). Both QuadratureAdjoint and InterpolatingAdjoint crash, regardless of whether I use ReverseDiffVJP or EnzymeVJP. QNDF, FBDF and TRBDF2 also fail, while Rosenbrock solvers like Rodas5P can solve the problem (but they are slow for models of this size).

In the same way at the same steps?

And to clarify, AMICI uses Sundials’ CVODES adjoint methods?

What are the results if you just use the AD interface instead of writing compute∂G∂u by hand?

Hi,

Yes, AMICI uses Sundials CVODES adjoint methods.

Yes, both QuadratureAdjoint and InterpolatingAdjoint fail at t = 120 (the time point for which we have data).

When I use the AD interface via Zygote the code also crashes at the exact same time-point, t=120 (see code below).

using ModelingToolkit
using OrdinaryDiffEq
using Sundials
using SciMLSensitivity
using Zygote

#= 
    Minimal working example of how adjoint sensitivity analysis crashes on the systems biology
    Bachmann_MSB2011 model
=#


# Model name: model_Bachmann_MSB2011
# Number of parameters: 37
# Number of species: 25
function get_Bachmann_MSB2011()

    
    ModelingToolkit.@variables t p1EpoRpJAK2(t) pSTAT5(t) EpoRJAK2_CIS(t) SOCS3nRNA4(t) SOCS3RNA(t) SHP1(t) STAT5(t) EpoRJAK2(t) CISnRNA1(t) SOCS3nRNA1(t) SOCS3nRNA2(t) CISnRNA3(t) CISnRNA4(t) SOCS3(t) CISnRNA5(t) SOCS3nRNA5(t) SOCS3nRNA3(t) SHP1Act(t) npSTAT5(t) p12EpoRpJAK2(t) p2EpoRpJAK2(t) CIS(t) EpoRpJAK2(t) CISnRNA2(t) CISRNA(t)

    stateArray = [p1EpoRpJAK2, pSTAT5, EpoRJAK2_CIS, SOCS3nRNA4, SOCS3RNA, SHP1, STAT5, EpoRJAK2, CISnRNA1, SOCS3nRNA1, SOCS3nRNA2, CISnRNA3, CISnRNA4, SOCS3, CISnRNA5, SOCS3nRNA5, SOCS3nRNA3, SHP1Act, npSTAT5, p12EpoRpJAK2, p2EpoRpJAK2, CIS, EpoRpJAK2, CISnRNA2, CISRNA]


    ### Define parameters
    ModelingToolkit.@parameters STAT5Exp STAT5Imp init_SOCS3_multiplier EpoRCISRemove STAT5ActEpoR SHP1ActEpoR JAK2EpoRDeaSHP1 CISTurn SOCS3Turn init_EpoRJAK2_CIS SOCS3Inh ActD init_CIS_multiplier cyt CISRNAEqc JAK2ActEpo Epo SOCS3oe CISInh SHP1Dea SOCS3EqcOE CISRNADelay init_SHP1 CISEqcOE EpoRActJAK2 SOCS3RNAEqc CISEqc SHP1ProOE SOCS3RNADelay init_STAT5 CISoe CISRNATurn init_SHP1_multiplier init_EpoRJAK2 nuc EpoRCISInh STAT5ActJAK2 SOCS3RNATurn SOCS3Eqc

    ### Store parameters in array for ODESystem command
    parameterArray = [STAT5Exp, STAT5Imp, init_SOCS3_multiplier, EpoRCISRemove, STAT5ActEpoR, SHP1ActEpoR, JAK2EpoRDeaSHP1, CISTurn, SOCS3Turn, init_EpoRJAK2_CIS, SOCS3Inh, ActD, init_CIS_multiplier, cyt, CISRNAEqc, JAK2ActEpo, Epo, SOCS3oe, CISInh, SHP1Dea, SOCS3EqcOE, CISRNADelay, init_SHP1, CISEqcOE, EpoRActJAK2, SOCS3RNAEqc, CISEqc, SHP1ProOE, SOCS3RNADelay, init_STAT5, CISoe, CISRNATurn, init_SHP1_multiplier, init_EpoRJAK2, nuc, EpoRCISInh, STAT5ActJAK2, SOCS3RNATurn, SOCS3Eqc]

    ### Define an operator for the differentiation w.r.t. time
    D = Differential(t)

    ### Derivatives ###
    eqs = [
    D(p1EpoRpJAK2) ~ +1.0 * ( 1 /cyt ) * (cyt * (EpoRpJAK2 * EpoRActJAK2 / (SOCS3 * SOCS3Inh / SOCS3Eqc + 1)))-1.0 * ( 1 /cyt ) * (cyt * (3 * EpoRActJAK2 * p1EpoRpJAK2 / ((SOCS3 * SOCS3Inh / SOCS3Eqc + 1) * (EpoRCISInh * EpoRJAK2_CIS + 1))))-1.0 * ( 1 /cyt ) * (cyt * (JAK2EpoRDeaSHP1 * SHP1Act * p1EpoRpJAK2 / init_SHP1)),
    D(pSTAT5) ~ +1.0 * ( 1 /cyt ) * (cyt * (STAT5 * STAT5ActJAK2 * (EpoRpJAK2 + p12EpoRpJAK2 + p1EpoRpJAK2 + p2EpoRpJAK2) / (init_EpoRJAK2 * (SOCS3 * SOCS3Inh / SOCS3Eqc + 1))))+1.0 * ( 1 /cyt ) * (cyt * (STAT5 * STAT5ActEpoR * (p12EpoRpJAK2 + p1EpoRpJAK2)^(2) / ((init_EpoRJAK2)^(2) * (SOCS3 * SOCS3Inh / SOCS3Eqc + 1) * (CIS * CISInh / CISEqc + 1))))-1.0 * ( 1 /cyt ) * (cyt * STAT5Imp * pSTAT5),
    D(EpoRJAK2_CIS) ~ -1.0 * ( 1 /cyt ) * (cyt * (EpoRJAK2_CIS * EpoRCISRemove * (p12EpoRpJAK2 + p1EpoRpJAK2) / init_EpoRJAK2)),
    D(SOCS3nRNA4) ~ +1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA3 * SOCS3RNADelay)-1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA4 * SOCS3RNADelay),
    D(SOCS3RNA) ~ +1.0 * ( 1 /cyt ) * (nuc * SOCS3nRNA5 * SOCS3RNADelay)-1.0 * ( 1 /cyt ) * (cyt * SOCS3RNA * SOCS3RNATurn),
    D(SHP1) ~ -1.0 * ( 1 /cyt ) * (cyt * (SHP1 * SHP1ActEpoR * (EpoRpJAK2 + p12EpoRpJAK2 + p1EpoRpJAK2 + p2EpoRpJAK2) / init_EpoRJAK2))+1.0 * ( 1 /cyt ) * (cyt * SHP1Dea * SHP1Act),
    D(STAT5) ~ -1.0 * ( 1 /cyt ) * (cyt * (STAT5 * STAT5ActJAK2 * (EpoRpJAK2 + p12EpoRpJAK2 + p1EpoRpJAK2 + p2EpoRpJAK2) / (init_EpoRJAK2 * (SOCS3 * SOCS3Inh / SOCS3Eqc + 1))))-1.0 * ( 1 /cyt ) * (cyt * (STAT5 * STAT5ActEpoR * (p12EpoRpJAK2 + p1EpoRpJAK2)^(2) / ((init_EpoRJAK2)^(2) * (SOCS3 * SOCS3Inh / SOCS3Eqc + 1) * (CIS * CISInh / CISEqc + 1))))+1.0 * ( 1 /cyt ) * (nuc * STAT5Exp * npSTAT5),
    D(EpoRJAK2) ~ -1.0 * ( 1 /cyt ) * (cyt * (Epo * EpoRJAK2 * JAK2ActEpo / (SOCS3 * SOCS3Inh / SOCS3Eqc + 1)))+1.0 * ( 1 /cyt ) * (cyt * (EpoRpJAK2 * JAK2EpoRDeaSHP1 * SHP1Act / init_SHP1))+1.0 * ( 1 /cyt ) * (cyt * (JAK2EpoRDeaSHP1 * SHP1Act * p1EpoRpJAK2 / init_SHP1))+1.0 * ( 1 /cyt ) * (cyt * (JAK2EpoRDeaSHP1 * SHP1Act * p2EpoRpJAK2 / init_SHP1))+1.0 * ( 1 /cyt ) * (cyt * (JAK2EpoRDeaSHP1 * SHP1Act * p12EpoRpJAK2 / init_SHP1)),
    D(CISnRNA1) ~ +1.0 * ( 1 /nuc ) * (nuc * (CISRNAEqc * CISRNATurn * npSTAT5 * ActD / init_STAT5))-1.0 * ( 1 /nuc ) * (nuc * CISnRNA1 * CISRNADelay),
    D(SOCS3nRNA1) ~ +1.0 * ( 1 /nuc ) * (nuc * (SOCS3RNAEqc * SOCS3RNATurn * npSTAT5 * ActD / init_STAT5))-1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA1 * SOCS3RNADelay),
    D(SOCS3nRNA2) ~ +1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA1 * SOCS3RNADelay)-1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA2 * SOCS3RNADelay),
    D(CISnRNA3) ~ +1.0 * ( 1 /nuc ) * (nuc * CISnRNA2 * CISRNADelay)-1.0 * ( 1 /nuc ) * (nuc * CISnRNA3 * CISRNADelay),
    D(CISnRNA4) ~ +1.0 * ( 1 /nuc ) * (nuc * CISnRNA3 * CISRNADelay)-1.0 * ( 1 /nuc ) * (nuc * CISnRNA4 * CISRNADelay),
    D(SOCS3) ~ +1.0 * ( 1 /cyt ) * (cyt * (SOCS3RNA * SOCS3Eqc * SOCS3Turn / SOCS3RNAEqc))-1.0 * ( 1 /cyt ) * (cyt * SOCS3 * SOCS3Turn)+1.0 * ( 1 /cyt ) * (cyt * SOCS3oe * SOCS3Eqc * SOCS3Turn * SOCS3EqcOE),
    D(CISnRNA5) ~ +1.0 * ( 1 /nuc ) * (nuc * CISnRNA4 * CISRNADelay)-1.0 * ( 1 /nuc ) * (nuc * CISnRNA5 * CISRNADelay),
    D(SOCS3nRNA5) ~ +1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA4 * SOCS3RNADelay)-1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA5 * SOCS3RNADelay),
    D(SOCS3nRNA3) ~ +1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA2 * SOCS3RNADelay)-1.0 * ( 1 /nuc ) * (nuc * SOCS3nRNA3 * SOCS3RNADelay),
    D(SHP1Act) ~ +1.0 * ( 1 /cyt ) * (cyt * (SHP1 * SHP1ActEpoR * (EpoRpJAK2 + p12EpoRpJAK2 + p1EpoRpJAK2 + p2EpoRpJAK2) / init_EpoRJAK2))-1.0 * ( 1 /cyt ) * (cyt * SHP1Dea * SHP1Act),
    D(npSTAT5) ~ +1.0 * ( 1 /nuc ) * (cyt * STAT5Imp * pSTAT5)-1.0 * ( 1 /nuc ) * (nuc * STAT5Exp * npSTAT5),
    D(p12EpoRpJAK2) ~ +1.0 * ( 1 /cyt ) * (cyt * (3 * EpoRActJAK2 * p1EpoRpJAK2 / ((SOCS3 * SOCS3Inh / SOCS3Eqc + 1) * (EpoRCISInh * EpoRJAK2_CIS + 1))))+1.0 * ( 1 /cyt ) * (cyt * (EpoRActJAK2 * p2EpoRpJAK2 / (SOCS3 * SOCS3Inh / SOCS3Eqc + 1)))-1.0 * ( 1 /cyt ) * (cyt * (JAK2EpoRDeaSHP1 * SHP1Act * p12EpoRpJAK2 / init_SHP1)),
    D(p2EpoRpJAK2) ~ +1.0 * ( 1 /cyt ) * (cyt * (3 * EpoRpJAK2 * EpoRActJAK2 / ((SOCS3 * SOCS3Inh / SOCS3Eqc + 1) * (EpoRCISInh * EpoRJAK2_CIS + 1))))-1.0 * ( 1 /cyt ) * (cyt * (EpoRActJAK2 * p2EpoRpJAK2 / (SOCS3 * SOCS3Inh / SOCS3Eqc + 1)))-1.0 * ( 1 /cyt ) * (cyt * (JAK2EpoRDeaSHP1 * SHP1Act * p2EpoRpJAK2 / init_SHP1)),
    D(CIS) ~ +1.0 * ( 1 /cyt ) * (cyt * (CISRNA * CISEqc * CISTurn / CISRNAEqc))-1.0 * ( 1 /cyt ) * (cyt * CIS * CISTurn)+1.0 * ( 1 /cyt ) * (cyt * CISEqc * CISTurn * CISEqcOE * CISoe),
    D(EpoRpJAK2) ~ +1.0 * ( 1 /cyt ) * (cyt * (Epo * EpoRJAK2 * JAK2ActEpo / (SOCS3 * SOCS3Inh / SOCS3Eqc + 1)))-1.0 * ( 1 /cyt ) * (cyt * (EpoRpJAK2 * JAK2EpoRDeaSHP1 * SHP1Act / init_SHP1))-1.0 * ( 1 /cyt ) * (cyt * (EpoRpJAK2 * EpoRActJAK2 / (SOCS3 * SOCS3Inh / SOCS3Eqc + 1)))-1.0 * ( 1 /cyt ) * (cyt * (3 * EpoRpJAK2 * EpoRActJAK2 / ((SOCS3 * SOCS3Inh / SOCS3Eqc + 1) * (EpoRCISInh * EpoRJAK2_CIS + 1)))),
    D(CISnRNA2) ~ +1.0 * ( 1 /nuc ) * (nuc * CISnRNA1 * CISRNADelay)-1.0 * ( 1 /nuc ) * (nuc * CISnRNA2 * CISRNADelay),
    D(CISRNA) ~ +1.0 * ( 1 /cyt ) * (nuc * CISnRNA5 * CISRNADelay)-1.0 * ( 1 /cyt ) * (cyt * CISRNA * CISRNATurn)
    ]

    @named sys = ODESystem(eqs, t, stateArray, parameterArray)

    ### Initial species concentrations ###
    initialSpeciesValues = [
    p1EpoRpJAK2 => 0.0,
    pSTAT5 => 0.0,
    EpoRJAK2_CIS => init_EpoRJAK2_CIS,
    SOCS3nRNA4 => 0.0,
    SOCS3RNA => 0.0,
    SHP1 => init_SHP1 * (init_SHP1_multiplier * SHP1ProOE + 1),
    STAT5 => init_STAT5,
    EpoRJAK2 => init_EpoRJAK2,
    CISnRNA1 => 0.0,
    SOCS3nRNA1 => 0.0,
    SOCS3nRNA2 => 0.0,
    CISnRNA3 => 0.0,
    CISnRNA4 => 0.0,
    SOCS3 => init_SOCS3_multiplier * SOCS3EqcOE * SOCS3Eqc,
    CISnRNA5 => 0.0,
    SOCS3nRNA5 => 0.0,
    SOCS3nRNA3 => 0.0,
    SHP1Act => 0.0,
    npSTAT5 => 0.0,
    p12EpoRpJAK2 => 0.0,
    p2EpoRpJAK2 => 0.0,
    CIS => init_CIS_multiplier * CISEqc * CISEqcOE,
    EpoRpJAK2 => 0.0,
    CISnRNA2 => 0.0,
    CISRNA => 0.0
    ]

    ### SBML file parameter values ###
    trueParameterValues = [
    STAT5Exp => 0.0745150819016423,
    STAT5Imp => 0.0268865083829685,
    init_SOCS3_multiplier => 0.0,
    EpoRCISRemove => 5.42980693903448,
    STAT5ActEpoR => 38.9957991073948,
    SHP1ActEpoR => 0.00100000000000006,
    JAK2EpoRDeaSHP1 => 142.72332309738,
    CISTurn => 0.0083988695167017,
    SOCS3Turn => 9999.99999999912,
    init_EpoRJAK2_CIS => 0.0,
    SOCS3Inh => 10.4078649133666,
    ActD => 1.25e-7,
    init_CIS_multiplier => 0.0,
    cyt => 0.4,
    CISRNAEqc => 1.0,
    JAK2ActEpo => 633167.430600806,
    Epo => 1.25e-7,
    SOCS3oe => 1.25e-7,
    CISInh => 7.85269991450496e8,
    SHP1Dea => 0.00816220490950374,
    SOCS3EqcOE => 0.679165515556864,
    CISRNADelay => 0.14477775532111,
    init_SHP1 => 26.7251164277109,
    CISEqcOE => 0.530264447119609,
    EpoRActJAK2 => 0.267304849333058,
    SOCS3RNAEqc => 1.0,
    CISEqc => 432.860413434913,
    SHP1ProOE => 2.82568153411555,
    SOCS3RNADelay => 1.06458446742251,
    init_STAT5 => 79.75363993771,
    CISoe => 1.25e-7,
    CISRNATurn => 999.999999999946,
    init_SHP1_multiplier => 1.0,
    init_EpoRJAK2 => 3.97622369384192,
    nuc => 0.275,
    EpoRCISInh => 999999.999999912,
    STAT5ActJAK2 => 0.0781068855795467,
    SOCS3RNATurn => 0.00830917643120369,
    SOCS3Eqc => 173.64470023136
    ]

    return sys, initialSpeciesValues, trueParameterValues
end


# Cost function G: negative log-likelihood of the observable h = u[22] at t = 120
function computeG(u0, p)

    sol = solve(odeProblem, CVODE_BDF(), abstol=1e-8, reltol=1e-8, u0=u0, p=p,
                saveat=[120.0],
                sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP()))
    u = sol[:, 1]
    h = u[22]
    dataObserved = [29.061794973113646, 26.097567191289983, 19.65239347179184]
    σ = exp10(-3.0)
    G = 0.0
    for i in eachindex(dataObserved)
        G += log(σ) + 0.5*log(2*pi) + log(log(10)) + log(10)*log10(dataObserved[i]) + 0.5*(log10(h) - log10(dataObserved[i]) / σ)^2
    end
    return G
end

sys, stateMap, parameterMap = get_Bachmann_MSB2011()
odeProblem = ODEProblem{true, SciMLBase.FullSpecialize}(sys, stateMap, [0.0, 130.0], parameterMap, jac=true)

# Parameter vector and initial value vector for which the gradient computation crashes
p = [26.56087782946684, 0.0011497569953977356, 0.0, 0.037649358067924674, 0.11497569953977356, 2.848035868435802, 1.291549665014884, 7.56463327554629, 0.02595024211399736, 0.0, 0.007054802310718645, 1.0, 0.0, 0.4, 1.0, 1.629750834620647e6, 1.25e-7, 0.0, 2.1544346900318843, 0.024770763559917114, 0.001747528400007683, 0.5336699231206307, 0.014174741629268055, 0.3511191734215131, 104.7615752789664, 1.0, 0.0016297508346206436, 1.232846739442066, 1.6297508346206435, 0.0657933224657568, 0.0, 0.005336699231206312, 0.0, 13.219411484660288, 0.275, 351119.17342151277, 0.003511191734215131, 0.6135907273413176, 0.1519911082952933]
u0 = [0.0, 0.0, 0.0, 0.0, 0.0, 0.014174741629268055, 0.0657933224657568, 13.219411484660288, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

using Zygote # needed for Zygote.gradient; not loaded by the imports at the top

du0_2, dp_2 = Zygote.gradient(computeG, u0, p)

This gives the following output (truncated due to the character limit):

[CVODES WARNING]  CVode
  Internal t = 120 and h = -2.39125e-15 are such that t + h = t on the next step. The solver will continue anyway.

ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [0]
Stacktrace:
  [1] getindex
    @ ./array.jl:924 [inlined]
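
For reference, here is a minimal sketch of how the cost in computeG depends on the observable h = u[22] alone, with the ODE solve left out. It uses the same expression as computeG above, character for character, and compares the hand-derived derivative with a central finite difference. The names costG and dcostG_dh and the trial values of h are hypothetical and only for illustration; they are not part of the MWE. It shows that the derivative with respect to h scales like 1/h, so it becomes large in magnitude when the simulated observable is close to zero. Whether this is what triggers the CVODE_BDF warning is not established here.

```julia
# Same per-data-point expression as in computeG, but as a function of h = u[22] only.
# costG and dcostG_dh are illustrative helper names, not part of the original MWE.
function costG(h; data = [29.061794973113646, 26.097567191289983, 19.65239347179184],
               σ = exp10(-3.0))
    G = 0.0
    for d in data
        G += log(σ) + 0.5*log(2*pi) + log(log(10)) + log(10)*log10(d) +
             0.5*(log10(h) - log10(d) / σ)^2
    end
    return G
end

# Hand-derived derivative of the expression above:
# d/dh 0.5*(log10(h) - c)^2 = (log10(h) - c) / (h * log(10)), with c = log10(d) / σ
function dcostG_dh(h; data = [29.061794973113646, 26.097567191289983, 19.65239347179184],
                   σ = exp10(-3.0))
    return sum((log10(h) - log10(d) / σ) / (h * log(10)) for d in data)
end

# Cross-check against a central finite difference at a hypothetical value of h
hTest = 1e-3
δ = 1e-9
fd = (costG(hTest + δ) - costG(hTest - δ)) / (2δ)
println("analytic = ", dcostG_dh(hTest), ", finite difference = ", fd)

# The magnitude grows roughly like 1/h as h approaches zero
for hVal in (1e-3, 1e-6, 1e-9)
    println("h = ", hVal, " -> dG/dh = ", dcostG_dh(hVal))
end
```

Evaluating dcostG_dh at the value of sol[22, 1] for a failing parameter vector gives the size of the jump applied to the adjoint state at t = 120, which can then be compared with the solver tolerances.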
 

Very peculiar. Can you open an issue with this MWE? I’ll need to dig in to find out what may be the difference here.

Yes. For the SciMLSensitivity.jl repo?

Yes