Error with maximum likelihood estimation using Optimization

using OptimizationOptimisers
using Optimization

# Weighted 2-norm of the residual between the model prediction and the training data.
loss(θ) = norm((predict(θ) - data_train) .* weights)
# The optimizer passes (variables, fixed parameters); the second argument is unused here.
optf = OptimizationFunction((x, p) -> loss(x), AutoReverseDiff(true))
optprob = OptimizationProblem(optf, θ)
# Adam(η, (β₁, β₂)) with learning rate η = 0.05.
# β₁ must be strictly below 1: with β₁ = 1.0 the first-moment estimate never leaves
# zero and the bias correction divides by 1 - β₁^t = 0, so every update is NaN.
opt = Adam(0.05, (0.9, 0.985))
@time results = Optimization.solve(
    optprob, opt, callback = callback, maxiters = 180
)

gives error:

ERROR: Function argument passed to autodiff cannot be proven readonly.
If the the function argument cannot contain derivative data, instead call autodiff(Mode, Const(f), ...)

Original code is here

Full code
include("make_OCFEM.jl")

"""
====================================================================
    Building OCFEM (orthogonal collocation on finite element method) 
        for z discretization with cubic hermite polynomials
====================================================================
""" 

begin
    # Discretization settings
    Ne = 42             # number of finite elements
    N = 2               # number of interior collocation points
    n_components = 1    # number of chemical species
    n_phases = 2        # 2 phases → 1 liquid (c profile) + 1 solid (q profile)
    p_order = 4         # polynomial order + 1

    # Total number of state variables over all species and phases
    n_variables = n_components * n_phases * (p_order + 2 * Ne - 2)

    # z-domain limits and the resulting finite-element size
    xₘᵢₙ, xₘₐₓ = 0.0, 1.0
    h = (xₘₐₓ - xₘᵢₙ) / Ne
end

# Collocation matrices returned by make_OCFEM (defined in make_OCFEM.jl):
# H holds the polynomial evaluations at the collocation points,
# A the first derivative and B the second derivative.
H, A, B = make_OCFEM(Ne, n_phases, n_components) 

# Mass matrix of the discretized system, stored as a BitMatrix.
# It is passed to ODEFunction as `mass_matrix` further down.
MM = BitMatrix(make_MM(Ne, n_phases, n_components)) # make mass matrix

"""
===================================================================
             Build UDE and solve & visualization    
===================================================================
"""
# -------------------Defining chromatography parameters---------------------
begin
    # Operating conditions and column geometry
    Qf = 5.0e-2                  # flow rate (dm^3/min)
    d = 0.5                      # column diameter (dm)
    L = 2.0                      # bed length (dm)
    a = pi * d^2 / 4             # column area (dm^2)
    ϵ = 0.5                      # bed porosity
    v = Qf / (a * ϵ)             # interstitial velocity (dm/min)

    # Transport
    Pe = 21.095632695978704      # Peclet number
    Da_x = v * L / Pe            # axial dispersion (dm^2/min)
    k_transf = 0.22              # mass transfer coefficient (1/min)

    # Feed and isotherm
    c_in = 5.5                   # feed concentration (mg/L)
    k_iso = 1.8                  # isotherm affinity parameter (L/mg)
    qmax = 55.54                 # isotherm saturation parameter (mg/L)

    # Scale parameter for the amount adsorbed in the solid phase: the isotherm
    # expression evaluated at the feed concentration. Change the exponent
    # according to each isotherm: 1.0 for Langmuir, 1.5 for Sips in this work.
    q_test = qmax * k_iso * c_in^1.0 / (1.0 + k_iso * c_in^1.0)
end


#----------Define the derivative matrices stencil and node index-----------------
using LinearAlgebra

begin
    # Boundary node indices (first and last node of the liquid-phase block)
    lb_idx = 1
    ub_idx = 2Ne + 2
    # Internal node indices: everything strictly between the two boundary nodes
    l_idx = lb_idx + 1
    u_idx = ub_idx - 1

    # With c(x, t) = H*x, the derivatives expressed on the nodal values are
    #   ∂x(c(x, t))  = A*x = A*inv(H)*c(x, t)
    #   ∂x²(c(x, t)) = B*x = B*inv(H)*c(x, t)
    ∂x = Array(A * inv(H))
    ∂x² = Array(B * inv(H))
end

#-------------------Importing experimental data-------------------
# Vermeulen’s kinetics and langmuir isotherm
using DelimitedFiles
# Comma-delimited training data. Further down, column 1 is used as the time axis
# and column 2 as the concentration signal to fit.
c_exp_data = readdlm("traindata_improved_quad_lang_2min.csv", ',')

# ------------------Initializing Neural networks----------------------
using Random, Lux, ComponentArrays

# Small fully connected network: 2 inputs → 22 hidden units (tanh_fast) → 1 output.
# Inside the right-hand side it is evaluated on the pair (q_eq, q).
nn = Chain(
  Dense(2, 22, tanh_fast),
  Dense(22, 1)
) 

# Initial parameters `ps` and layer states `st`, drawn from the default RNG
# (not seeded here, so the initialization differs between sessions).
ps, st = Lux.setup(Random.default_rng(), nn)

#---------------------building rhs function for DAE solver---------------------

"""
    Hybrid(L, h, v, ϵ, c_in, ∂x, ∂x², Da_x, k_iso, qmax, q_test)

Parameter container for the hybrid chromatography model. Instances are callable
as an in-place right-hand side `(du, u, p, t)`.

The struct is immutable: none of the fields is reassigned after construction, and
the derivative matrices can still be modified element-wise if ever needed.

# Fields
- `L`: bed length
- `h`: finite-element size
- `v`: interstitial velocity
- `ϵ`: bed porosity
- `c_in`: feed concentration
- `∂x`: first-derivative operator acting on the nodal values
- `∂x²`: second-derivative operator acting on the nodal values
- `Da_x`: axial dispersion coefficient
- `k_iso`: isotherm affinity parameter
- `qmax`: isotherm saturation parameter
- `q_test`: scale for the amount adsorbed in the solid phase
"""
struct Hybrid
    L::Float64
    h::Float64
    v::Float64
    ϵ::Float64
    c_in::Float64
    ∂x::Matrix{Float64}
    ∂x²::Matrix{Float64}
    Da_x::Float64
    k_iso::Float64
    qmax::Float64
    q_test::Float64
end

using UnPack

# In-place right-hand side of the hybrid chromatography model.
#
# Arguments follow the (du, u, p, t) convention: `du` is overwritten with the
# right-hand side evaluated at the state `u`, `p` holds the neural-network
# parameters, and `t` is not used. Returns `nothing`.
#
# Besides the fields of `Hybrid!`, the body reads the globals `lb_idx`, `ub_idx`,
# `l_idx`, `u_idx`, `nn` and `st`.
function (Hybrid!::Hybrid)(du, u, p, t)
    # Alias the parameters of the chromatography model
    @unpack L, h, v, ∂x, ∂x², Da_x, ϵ, c_in, k_iso, qmax, q_test = Hybrid! 

    #---------------------Mass transfer and equilibrium -----------------

    # Liquid-phase block of the state (a view, no copy)
    c = (@view u[lb_idx:ub_idx])  
    # Solid-phase block of the state (copied by the slice), divided by q_test
    q = u[ub_idx+1:2*ub_idx] / q_test # scaling dependent variables
    # Isotherm expression evaluated at the liquid concentration, divided by q_test.
    # Change the exponent according to each isotherm.
    q_eq = qmax * k_iso * c.^1.0 ./ (1.0 .+ k_iso * c.^1.0) / q_test# Change exponent according to each isotherm
    
    # Network input: adjoint of the two-column matrix [q_eq q], i.e. one row for
    # q_eq and one row for q
    x1x2 = [q_eq q]'
    # Apply the first- and second-derivative operators to the full state vector
    ∂x_u = ∂x*u
    ∂x²_u = ∂x²*u

    # Neural network output: first element of what the network call returns
    nn_out = nn(x1x2, p, st)[1]   # p of function Hybrid! stands for the parameters of the neural network
    # Drop the first and last entries so the result lines up with the internal nodes
    nn_internal = nn_out[2:end-1] 
    #------------------------Mass balance---------------------------
    
    # Concentration profile at the internal nodes, du[l_idx:u_idx] for ∂t(c(x, t)):
    # network term weighted by (1 - ϵ)/ϵ, first-derivative term scaled by v/(h*L),
    # second-derivative term scaled by Da_x/(h*L)^2
    du[l_idx:u_idx] = - (1 - ϵ) / ϵ * nn_internal .- v/(h*L) * ∂x_u[l_idx:u_idx] .+ Da_x / (h*L)^2 * ∂x²_u[l_idx:u_idx]
    # Boundary node equations.
    # NOTE(review): these rows are presumably algebraic constraints (zero rows of
    # the mass matrix MM), so both expressions are driven to 0 — confirm against make_MM.
    du[1] = Da_x / (h*L) * ∂x_u[1] - v * (c[1] - c_in)
    du[ub_idx] = ∂x_u[ub_idx] / (h*L)
   
    # Adsorption profile (solid phase): the network output is used directly
    du[ub_idx+1:2*ub_idx] = nn_out # du[ub_idx+1:2*ub_idx] for ∂t(q(x, t))
    nothing
end

using OrdinaryDiffEq
# Callable right-hand side carrying all model constants, wrapped together with
# the mass matrix.
rhs = Hybrid(L, h, v, ϵ, c_in, ∂x, ∂x², Da_x, k_iso, qmax, q_test)
chroma_func = ODEFunction(rhs, mass_matrix = MM)

#-----------------------Build and Solve ODE Problem-------------------------
begin
    # Time span taken from the first and last time stamps of the data
    t0, t_end = first(c_exp_data[:, 1]), last(c_exp_data[:, 1])
    tspan = (t0, t_end)
    datasize = size(c_exp_data, 1)
    # NOTE(review): this assumes the data are sampled uniformly in time; if they are
    # not, save at c_exp_data[:, 1] instead so predictions line up with the data.
    t_steps = range(t0, t_end; length = datasize)
    # Lux initializes parameters as Float32. Convert them to Float64 so the network
    # parameters, the state vector and the adjoint all share one element type
    # instead of mixing Float32 parameters with a Float64 state.
    ps_ca = ComponentArray{Float64}(ps)
    # Initial condition: uniform liquid concentration c0, with the solid phase set to
    # the isotherm expression evaluated at c0
    c0 = 1e-3
    u0 = c0 * ones(n_variables)
    u0[Int(n_variables/2) + 1:end] .= qmax * k_iso * c0^1.0 / (1.0 + k_iso * c0^1.0)
end

chroma_prob = ODEProblem(chroma_func, u0, tspan, ps_ca)

LinearAlgebra.BLAS.set_num_threads(5)
# Forward solve with the untrained network, saved at the data time points.
# autodiff = false disables automatic differentiation inside the implicit solver.
@time sol_init = Array(solve(chroma_prob, FBDF(autodiff = false), saveat = t_steps))

using Plots
begin
    # Simulated trajectory of state row n_variables/2 with the untrained network.
    # For the settings above this row equals ub_idx, the last liquid-phase node
    # (presumably the column outlet — confirm).
    plot(t_steps, sol_init[Int(n_variables/2), :])
    # Overlay the experimental data: column 2 against column 1
    scatter!(c_exp_data[:, 1], c_exp_data[:, 2])
end

"""
===================================================================
                    Training Neural Network   
===================================================================
"""

#-----------------------Training Neural Network----------------------
using SciMLSensitivity
"""
    predict(p)

Solve `chroma_prob` with the network parameters replaced by `p` and return the
trajectory of state row `Int(n_variables / 2)`, saved at `t_steps` and divided by
the feed concentration `c_in`.
"""
function predict(p)
    # Interpolating adjoint works well too
    adjoint_alg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))

    trajectory = solve(
        remake(chroma_prob; p = p), FBDF(autodiff = false);
        abstol = 5e-7, reltol = 5e-7, saveat = t_steps, sensealg = adjoint_alg,
    )
    states = Array(trajectory)

    return states[Int(n_variables / 2), :] / c_in
end

# Setting up training data

# Second data column divided by the feed concentration, matching the scaling
# applied at the end of `predict`.
data_train = c_exp_data[:, 2] / c_in

# Per-sample weights multiplied into the residual in `loss`; all ones here.
weights = ones(size(data_train, 1))

#--------------------------Loss function---------------------------

"""
    loss(p)

Return the norm of the weighted residual between `predict(p)` and `data_train`.
"""
function loss(p)
    residual = predict(p) - data_train
    return norm(residual .* weights)
end

#-------------------------testing gradients-----------------------
# Work on a copy so the initial parameters in `ps_ca` stay untouched.
p = copy(ps_ca)
# Smoke-test the forward pass; the first call of each also includes compilation time.
@time loss(p)
@time predict(p)
using Zygote
# Smoke-test the reverse-mode gradient of the loss with respect to the parameters.
@time grad_reverse = Zygote.gradient(loss, p)

#--------------------Maximum Likelihood estimation--------------------
using Optimization, DiffEqFlux

# Differentiate the objective with Zygote.
adtype = AutoZygote()
# The objective receives (optimization variables, fixed parameters); the second
# argument is ignored because `loss` only depends on the network parameters.
optf = OptimizationFunction((x, p) -> loss(x), adtype)
# Start the optimization from the parameter copy `p`.
optprob = OptimizationProblem(optf, p)

# Running iteration counter, advanced by `callback1` on every call.
iter = 1

# Progress callback: prints the current loss and the iteration number, then
# returns `true` once the loss falls below 1.0e-3 (the value is passed back to the
# optimizer as its early-stop signal).
callback1 = function (p, l)
    global iter
    foreach(println, (l, iter))
    iter += 1
    return l < 1.0e-3
end

using OptimizationOptimisers

# Adam(η, (β₁, β₂)) with learning rate η = 0.05.
# β₁ must be strictly below 1: with β₁ = 1.0 the first-moment estimate never leaves
# zero and the bias correction divides by 1 - β₁^t = 0, so every parameter update
# is 0/0 = NaN.
opt = Adam(0.05, (0.9, 0.985))

# Run at most 180 Adam iterations; `callback1` prints progress and can stop early.
@time results = Optimization.solve(
    optprob, opt, callback = callback1, maxiters = 180
)

make_OCFEM.jl (6.5 KB)
traindata_improved_quad_lang_2min.csv

Hi! Could you try to reduce the complexity of the example, and make the corresponding code self-contained (with all imports and definitions)?

Is the updated code clearer now that make_OCFEM.jl is uploaded and the data is linked? It’s just a DAE right-hand-side function that contains a neural network, together with the experimental data used to train the network’s parameters.

Can you maybe give the full stack trace of the error?

Warning
┌ Warning: At t=0.0, dt was forced below floating point epsilon 5.0e-324, and step error estimate = 1.0. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase ~/.julia/packages/SciMLBase/sYmAV/src/integrator_interface.jl:623
Error
ERROR: Function argument passed to autodiff cannot be proven readonly.
If the the function argument cannot contain derivative data, instead call autodiff(Mode, Const(f), ...)
See https://enzyme.mit.edu/index.fcgi/julia/stable/faq/#Activity-of-temporary-storage for more information.
The potentially writing call is   store {} addrspace(10)* %.fca.0.0.0.0.1.extract, {} addrspace(10)** %.fca.0.0.0.0.1.gep, align 8, !dbg !69, !noalias !95, using   %.fca.0.0.0.0.1.gep = getelementptr inbounds { { { { { {} addrspace(10)*, {} addrspace(10)*, { [12 x [1 x [12 x double]]], [2 x {} addrspace(10)*] }, { { [1 x { {} addrspace(10)*, {} addrspace(10)* }], [3 x {} addrspace(10)*], {} addrspace(10)*, {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, { {} addrspace(10)*, {} addrspace(10)*, i64 }, i8 }, { double, double }, i8, {} addrspace(10)*, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, { { [2 x i64], i64, [2 x i64], [2 x i64], i8, i8, [2 x i64] }, {} addrspace(10)*, {} addrspace(10)*, { i8 } }, { { {} addrspace(10)*, {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i8, {} addrspace(10)*, {} addrspace(10)*, i8 }, i8, i64, {} addrspace(10)*, i32 }, { {} addrspace(10)*, {} addrspace(10)* } }, [1 x {} addrspace(10)*] }, {} addrspace(10)*, i8, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, double, {} addrspace(10)* }, [1 x i8] }, { { { { { {} addrspace(10)*, {} addrspace(10)*, { [12 x [1 x [12 x double]]], [2 x {} addrspace(10)*] }, { { [1 x { {} addrspace(10)*, {} addrspace(10)* }], [3 x {} addrspace(10)*], {} addrspace(10)*, {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, { {} addrspace(10)*, {} addrspace(10)*, i64 }, i8 }, { double, double }, i8, {} addrspace(10)*, { {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, { { [2 x i64], i64, [2 x i64], [2 x i64], i8, i8, [2 x i64] }, {} addrspace(10)*, {} addrspace(10)*, { i8 } }, { { {} addrspace(10)*, {} addrspace(10)* }, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, i8, {} addrspace(10)*, {} addrspace(10)*, i8 }, i8, i64, {} addrspace(10)*, i32 }, { 
{} addrspace(10)*, {} addrspace(10)* } }, [1 x {} addrspace(10)*] }, {} addrspace(10)*, i8, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, {} addrspace(10)*, double, {} addrspace(10)* }, [1 x i8] }* %.innerparm, i64 0, i32 0, i32 0, i32 0, i32 0, i32 1, !dbg !69
Stacktrace:
  [1] NonlinearFunction
    @ ~/.julia/packages/SciMLBase/sYmAV/src/scimlfunctions.jl:0 [inlined]
  [2] diffejulia_NonlinearFunction_93809_inner_27wrap
    @ ~/.julia/packages/SciMLBase/sYmAV/src/scimlfunctions.jl:0
  [3] macro expansion
    @ ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:5340 [inlined]
  [4] enzyme_call
    @ ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:4878 [inlined]
  [5] CombinedAdjointThunk
    @ ~/.julia/packages/Enzyme/QsaeA/src/compiler.jl:4750 [inlined]
  [6] autodiff
    @ ~/.julia/packages/Enzyme/QsaeA/src/Enzyme.jl:503 [inlined]
  [7] autodiff
    @ ~/.julia/packages/Enzyme/QsaeA/src/Enzyme.jl:524 [inlined]
  [8] value_and_pullback!(f!::NonlinearFunction{…}, y::Vector{…}, tx::Tuple{…}, prep::DifferentiationInterfaceEnzymeExt.EnzymeReverseTwoArgPullbackPrep{…}, backend::AutoEnzyme{…}, x::Vector{…}, ty::Tuple{…}, contexts::DifferentiationInterface.Constant{…})
    @ DifferentiationInterfaceEnzymeExt ~/.julia/packages/DifferentiationInterface/F5K7v/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl:123
  [9] pullback!
    @ ~/.julia/packages/DifferentiationInterface/F5K7v/src/first_order/pullback.jl:321 [inlined]
 [10] #9
    @ ~/.julia/packages/SciMLJacobianOperators/XTOmd/src/SciMLJacobianOperators.jl:312 [inlined]
 [11] (::SciMLJacobianOperators.JacobianOperator{…})(v::Vector{…}, u::Vector{…}, p::ComponentVector{…})
    @ SciMLJacobianOperators ~/.julia/packages/SciMLJacobianOperators/XTOmd/src/SciMLJacobianOperators.jl:140
 [12] (::LineSearch.var"#4#6")(du::Vector{…}, u::Vector{…}, fu::Vector{…}, p::ComponentVector{…})
    @ LineSearch ~/.julia/packages/LineSearch/Ky1ZB/src/utils.jl:44
 [13] #10
    @ ~/.julia/packages/LineSearch/Ky1ZB/src/backtracking.jl:78 [inlined]
 [14] #14
    @ ~/.julia/packages/LineSearch/Ky1ZB/src/backtracking.jl:94 [inlined]
 [15] solve!(cache::LineSearch.BackTrackingCache{…}, u::Vector{…}, du::Vector{…})
    @ LineSearch ~/.julia/packages/LineSearch/Ky1ZB/src/backtracking.jl:97
 [16] step!(cache::NonlinearSolveFirstOrder.GeneralizedFirstOrderAlgorithmCache{…}; recompute_jacobian::Nothing)
    @ NonlinearSolveFirstOrder ~/.julia/packages/NonlinearSolveFirstOrder/o91ZE/src/solve.jl:277
 [17] step!
    @ ~/.julia/packages/NonlinearSolveFirstOrder/o91ZE/src/solve.jl:224 [inlined]
 [18] #step!#163
    @ ~/.julia/packages/NonlinearSolveBase/jA9TW/src/solve.jl:253 [inlined]
 [19] step!
    @ ~/.julia/packages/NonlinearSolveBase/jA9TW/src/solve.jl:247 [inlined]
 [20] solve!(cache::NonlinearSolveFirstOrder.GeneralizedFirstOrderAlgorithmCache{…})
    @ NonlinearSolveBase ~/.julia/packages/NonlinearSolveBase/jA9TW/src/solve.jl:18
 [21] __solve(::NonlinearProblem{…}, ::NonlinearSolveFirstOrder.GeneralizedFirstOrderAlgorithm{…}; kwargs::@Kwargs{…})
    @ NonlinearSolveBase ~/.julia/packages/NonlinearSolveBase/jA9TW/src/solve.jl:6
 [22] macro expansion
    @ ~/.julia/packages/NonlinearSolveBase/jA9TW/src/solve.jl:173 [inlined]
 [23] __generated_polysolve(::NonlinearProblem{…}, ::NonlinearSolveBase.NonlinearSolvePolyAlgorithm{…}; stats::SciMLBase.NLStats, alias_u0::Bool, verbose::Bool, initializealg::NonlinearSolveBase.NonlinearSolveDefaultInit, kwargs::@Kwargs{…})
    @ NonlinearSolveBase ~/.julia/packages/NonlinearSolveBase/jA9TW/src/solve.jl:130
 [24] __generated_polysolve
    @ ~/.julia/packages/NonlinearSolveBase/jA9TW/src/solve.jl:130 [inlined]
 [25] #__solve#154
    @ ~/.julia/packages/NonlinearSolveBase/jA9TW/src/solve.jl:127 [inlined]
 [26] __solve
    @ ~/.julia/packages/NonlinearSolveBase/jA9TW/src/solve.jl:124 [inlined]
 [27] #solve_call#35
    @ ~/.julia/packages/DiffEqBase/HGITF/src/solve.jl:635 [inlined]
 [28] solve_call
    @ ~/.julia/packages/DiffEqBase/HGITF/src/solve.jl:592 [inlined]
 [29] #solve_up#44
    @ ~/.julia/packages/DiffEqBase/HGITF/src/solve.jl:1112 [inlined]
 [30] solve_up
    @ ~/.julia/packages/DiffEqBase/HGITF/src/solve.jl:1106 [inlined]
 [31] #solve#43
    @ ~/.julia/packages/DiffEqBase/HGITF/src/solve.jl:1100 [inlined]
 [32] _initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, prob::ODEProblem{…}, alg::BrownFullBasicInit{…}, isinplace::Val{…})
    @ OrdinaryDiffEqNonlinearSolve ~/.julia/packages/OrdinaryDiffEqNonlinearSolve/s1Hh9/src/initialize_dae.jl:434
 [33] _initialize_dae!
    @ ~/.julia/packages/OrdinaryDiffEqCore/yTEch/src/initialize_dae.jl:56 [inlined]
 [34] initialize_dae!
    @ ~/.julia/packages/OrdinaryDiffEqCore/yTEch/src/initialize_dae.jl:40 [inlined]
 [35] initialize_dae!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/yTEch/src/initialize_dae.jl:40
 [36] __init(prob::ODEProblem{…}, alg::FBDF{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::CallbackSet{…}, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Float64, reltol::Float64, qmin::Rational{…}, qmax::Int64, qsteady_min::Rational{…}, qsteady_max::Rational{…}, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias::ODEAliasSpecifier, initializealg::OrdinaryDiffEqCore.DefaultInit, kwargs::@Kwargs{})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/yTEch/src/solve.jl:565
 [37] __init (repeats 5 times)
    @ ~/.julia/packages/OrdinaryDiffEqCore/yTEch/src/solve.jl:11 [inlined]
 [38] __solve(::ODEProblem{…}, ::FBDF{…}; kwargs::@Kwargs{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/yTEch/src/solve.jl:6
 [39] __solve
    @ ~/.julia/packages/OrdinaryDiffEqCore/yTEch/src/solve.jl:1 [inlined]
 [40] solve_call(_prob::ODEProblem{…}, args::FBDF{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/HGITF/src/solve.jl:635
 [41] solve_call
    @ ~/.julia/packages/DiffEqBase/HGITF/src/solve.jl:592 [inlined]
 [42] #solve_up#44
    @ ~/.julia/packages/DiffEqBase/HGITF/src/solve.jl:1128 [inlined]
 [43] solve_up
    @ ~/.julia/packages/DiffEqBase/HGITF/src/solve.jl:1106 [inlined]
 [44] #solve#42
    @ ~/.julia/packages/DiffEqBase/HGITF/src/solve.jl:1043 [inlined]
 [45] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::QuadratureAdjoint{…}, alg::FBDF{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, callback::Nothing, kwargs::@Kwargs{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/7UEpc/src/quadrature_adjoint.jl:349
 [46] _adjoint_sensitivities
    @ ~/.julia/packages/SciMLSensitivity/7UEpc/src/quadrature_adjoint.jl:337 [inlined]
 [47] #adjoint_sensitivities#63
    @ ~/.julia/packages/SciMLSensitivity/7UEpc/src/sensitivity_interface.jl:401 [inlined]
 [48] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#323"{…})(Δ::ODESolution{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/7UEpc/src/concrete_solve.jl:627
 [49] ZBack
    @ ~/.julia/packages/Zygote/TWpme/src/compiler/chainrules.jl:212 [inlined]
 [50] (::Zygote.var"#kw_zpullback#56"{…})(dy::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/chainrules.jl:238
 [51] #295
    @ ~/.julia/packages/Zygote/TWpme/src/lib/lib.jl:205 [inlined]
 [52] (::Zygote.var"#2169#back#297"{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
 [53] #solve#42
    @ ~/.julia/packages/DiffEqBase/HGITF/src/solve.jl:1043 [inlined]
 [54] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [55] #295
    @ ~/.julia/packages/Zygote/TWpme/src/lib/lib.jl:205 [inlined]
 [56] (::Zygote.var"#2169#back#297"{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
 [57] solve
    @ ~/.julia/packages/DiffEqBase/HGITF/src/solve.jl:1033 [inlined]
 [58] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [59] predict
    @ ~/文件/julia/chromatography/Training the hybrid UDE-based model.jl:175 [inlined]
 [60] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface2.jl:0
 [61] loss
    @ ~/文件/julia/chromatography/Training the hybrid UDE-based model.jl:195 [inlined]
 [62] #22
    @ ~/文件/julia/chromatography/Training the hybrid UDE-based model.jl:211 [inlined]
 [63] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:91
 [64] withgradient(::Function, ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/TWpme/src/compiler/interface.jl:213
 [65] value_and_gradient
    @ ~/.julia/packages/DifferentiationInterface/F5K7v/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:91 [inlined]
 [66] value_and_gradient!(f::Function, grad::ComponentVector{…}, prep::DifferentiationInterface.NoGradientPrep, backend::AutoZygote, x::ComponentVector{…}, contexts::DifferentiationInterface.Constant{…})
    @ DifferentiationInterfaceZygoteExt ~/.julia/packages/DifferentiationInterface/F5K7v/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:105
 [67] (::OptimizationZygoteExt.var"#fg!#16"{…})(res::ComponentVector{…}, θ::ComponentVector{…})
    @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/gvXsf/ext/OptimizationZygoteExt.jl:53
 [68] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/xC7Ic/src/OptimizationOptimisers.jl:101 [inlined]
 [69] macro expansion
    @ ~/.julia/packages/Optimization/qX4vR/src/utils.jl:32 [inlined]
 [70] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/xC7Ic/src/OptimizationOptimisers.jl:83
 [71] solve!(cache::OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/sYmAV/src/solve.jl:187
 [72] solve(::OptimizationProblem{…}, ::Adam{…}; kwargs::@Kwargs{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/sYmAV/src/solve.jl:95
 [73] macro expansion
    @ ./timing.jl:581 [inlined]
Some type information was truncated. Use `show(err)` to see complete types.

I think @avikpal might be better suited to tackle this one? I’m not even sure how Enzyme comes into play

Taking a step back, perhaps we should find out what is producing a stateful closure.

@gdalle can you check where it comes from
(i.e., DI, Optimization, etc.)?

From what I can tell, the closure is the SciMLBase.NonlinearFunction to which DI.pullback is applied?
I’ve taken great care to avoid creating closures in DI when Enzyme is involved.

The problem may be related to the AD method, as described here.

Slight Detour: We have had several questions regarding if we will be considering any other AD system for the reverse-diff backend. For now we will stick to Zygote.jl, however once we have tested Lux extensively with Enzyme.jl, we will make the switch.