Remake problem with modelingtoolkitized ODE function

Hi, I have a model that more or less follows the recommendations of the “Solving Large Stiff Equations” tutorial, and I am trying to optimize its parameters.
However, I have encountered a problem when applying remake to the modelingtoolkitized version of the code.

If, for example, I define my ODEFunction the conventional way, things seem to work fine: remake updates the model with the new parameters and the solve apparently succeeds.

# Jacobian sparsity pattern of the RHS with respect to u (p fixed at p0, t fixed at 1)
jac_sparsity = Symbolics.jacobian_sparsity((du, u) -> ADEST!(du, u, p0, 1),
    du0, u0) # add the sparsity pattern to speed up the solution
# Wrap the RHS and pass the sparsity pattern as the Jacobian prototype
rhs! = ODEFunction(ADEST!, jac_prototype=jac_sparsity)
# Solving the ODE
prob = ODEProblem(rhs!, u0, tspan, p0)
# Output is saved every 60*60 = 3600 time units (presumably seconds, i.e. hourly — confirm)
sol = solve(prob, Rosenbrock23(autodiff=true), saveat=60*60, reltol=1e-10, abstol=1e-10)
"""
    remake_prob(p)

Re-solve the global `prob` with the parameter vector `p`.

A fresh initial state is built (every entry `1e-16`, columns 4 and 5 overwritten with
values derived from the global `p0`, not from `p`), swapped into `prob` together with
`p` through `remake`, and the resulting problem is solved with the same solver settings
as the reference run. Returns the solution object.
"""
function remake_prob(p)
    ncells = first(size(x_points))
    state0 = fill(1e-16, ncells, 5)

    # Columns 4 and 5 start from values computed out of the global p0.
    fill!(view(state0, :, 4), (p0[1] * p0[11] / p0[12] - p0[2]) * p0[16])
    fill!(view(state0, :, 5), (p0[1] * p0[7] / p0[13] - p0[2]) * p0[16])

    updated = remake(prob; u0 = state0, p = p)
    return solve(updated, Rosenbrock23(autodiff = true);
                 saveat = 60 * 60, reltol = 1e-10, abstol = 1e-10)
end

# Exercise the remake helper with the alternative parameter set p_
sol_ = remake_prob(p_)

However, if the model is modelingtoolkitized first, then solving the updated model after remake just doesn't converge (MaxIters error).

# Convert the numerical ODEProblem into a ModelingToolkit system
@mtkbuild de = modelingtoolkitize(prob)
# Rebuild an ODEProblem from that system, requesting a generated sparse Jacobian
prob_mtk = ODEProblem(de, [], tspan, jac=true, sparse=true)
sol_mtk = solve(prob_mtk, Rosenbrock23(autodiff=true), saveat=60*60, reltol=1e-10, abstol=1e-10) # This works fine! Proof:

# Flatten every saved state so the two solutions can be compared element by element
mtk_u = [sol_mtk.u[i][:] for i in eachindex(sol_mtk.u)]
u = [sol.u[i][:] for i in eachindex(sol.u)]
u ≈ mtk_u #true

"""
    remake_mtk(p)

Re-solve the global ModelingToolkit-generated problem `prob_mtk` with the parameter
vector `p`.

A fresh initial state is built (every entry `1e-16`, columns 4 and 5 overwritten with
values derived from the global `p0`, not from `p`), swapped into `prob_mtk` together
with `p` through `remake`, and the resulting problem is solved. Returns the solution
object.
"""
function remake_mtk(p)
    ncells = first(size(x_points))
    state0 = fill(1e-16, ncells, 5)

    # Columns 4 and 5 start from values computed out of the global p0.
    fill!(view(state0, :, 4), (p0[1] * p0[11] / p0[12] - p0[2]) * p0[16])
    fill!(view(state0, :, 5), (p0[1] * p0[7] / p0[13] - p0[2]) * p0[16])

    # NOTE(review): state0 is matrix-shaped and p is a plain positional vector; verify
    # that remake on a ModelingToolkit-generated problem maps both onto the system's
    # own unknown/parameter ordering as intended.
    updated = remake(prob_mtk; u0 = state0, p = p)
    return solve(updated, Rosenbrock23(autodiff = true);
                 saveat = 60 * 60, reltol = 1e-10, abstol = 1e-10)
end

sol_mtk_ = remake_mtk(p_)  # MaxIters, doesn't converge.

Am I making some rookie mistake here? Should I treat the modelingtoolkitized version differently?

Thank you!

Full reproducible script:

using DifferentialEquations
using ForwardDiff
using Sundials
using Symbolics
using ModelingToolkit


"""
    create_ADEST(v, De, dx, c_in, nmob)

Build and return the in-place right-hand side `ADEST!(du, u, p, t)` of a 1-D
transport-plus-reaction model discretised on a grid with spacing `dx`.

# Arguments
- `v`: factor applied to the backward (upwind) difference of the mobile columns
  (the `advec` term).
- `De`: `1 × nmob` row of factors applied to the second difference of the mobile
  columns (the `disp` term).
- `dx`: grid spacing used in both differences.
- `c_in`: `1 × nmob` row prepended to the mobile columns before the upwind difference,
  i.e. the values seen upstream of the first cell.
- `nmob`: number of leading columns of `u` that are transported. The body writes
  transport terms for columns 1 to 3, so callers pass `nmob = 3`.

# Returned function
`ADEST!(du, u, p, t)` overwrites the `n × 5` matrix `du` and returns `nothing`.
Columns of `u`/`du` are, in order: O2, DOC, NO3, O2-consuming biomass (`B_o2`) and
NO3-consuming biomass (`B_no3`). `p[1:15]` are read positionally (see the unpacking
below); `t` is unused.
"""
function create_ADEST(v, De, dx, c_in, nmob)
    # Scalar reaction-rate kernels. They are broadcast over the grid inside `ADEST!`,
    # so each `du` column is filled in one fused pass without temporary arrays. The
    # order of operations matches the original element-wise expressions.

    # Nitrate consumption rate.
    function monod_nitrate(c_no3, c_doc, c_o2, B_no3, mu_max_no3, K_no3, K_doc_no3,
                           yield_no3, I_o2)
        return -mu_max_no3 / yield_no3 * c_no3 / (K_no3 + c_no3) * c_doc /
               (K_doc_no3 + c_doc) * I_o2 / (I_o2 + c_o2) * B_no3
    end

    # Oxygen consumption rate.
    function monod_oxygen(c_o2, c_doc, B_o2, mu_max_o2, K_o2, K_doc_o2, yield_o2)
        return -mu_max_o2 / yield_o2 * c_o2 / (K_o2 + c_o2) * c_doc /
               (K_doc_o2 + c_doc) * B_o2
    end

    # Net growth rate of the O2-consuming biomass (growth minus decay).
    function bac_rate_o2(c_o2, c_doc, B_o2, mu_max_o2, K_o2, K_doc_o2, decay_o2)
        return (mu_max_o2 * c_o2 / (K_o2 + c_o2) * c_doc / (K_doc_o2 + c_doc) -
                decay_o2) * B_o2
    end

    # Net growth rate of the NO3-consuming biomass (growth minus decay).
    function bac_rate_no3(c_no3, c_doc, c_o2, B_no3, mu_max_no3, K_no3, K_doc_no3,
                          decay_no3, I_o2)
        return (mu_max_no3 * c_no3 / (K_no3 + c_no3) * c_doc / (K_doc_no3 + c_doc) *
                I_o2 / (I_o2 + c_o2) - decay_no3) * B_no3
    end

    # DOC rate: release term scaled by `alpha` minus consumption by both reactions.
    function doc_rate(c_o2, c_no3, c_doc, B_o2, B_no3, alpha, Kb,
                      mu_max_no3, K_no3, yield_doc_no3, K_doc_no3, I_o2,
                      mu_max_o2, K_o2, yield_doc_o2, K_doc_o2)
        return alpha * (B_o2 / (Kb + B_o2) + B_no3 / (Kb + B_no3)) -
               mu_max_no3 / yield_doc_no3 * c_no3 / (K_no3 + c_no3) * c_doc /
               (K_doc_no3 + c_doc) * I_o2 / (I_o2 + c_o2) * B_no3 -
               mu_max_o2 / yield_doc_o2 * c_o2 / (K_o2 + c_o2) * c_doc /
               (K_doc_o2 + c_doc) * B_o2
    end

    function ADEST!(du, u, p, t)
        # Positional parameter layout; entries beyond 15 are not read here.
        alpha = p[1]
        Kb = p[2]
        mu_max_no3 = p[3]
        K_no3 = p[4]
        I_o2 = p[5]
        yield_no3 = p[6]
        yield_doc_no3 = p[7]
        mu_max_o2 = p[8]
        K_o2 = p[9]
        yield_o2 = p[10]
        yield_doc_o2 = p[11]
        decay = p[12]
        decay_no3 = p[13]
        K_doc_no3 = p[14]
        K_doc_o2 = p[15]

        # Non-allocating column views instead of `u[:, k]` copies.
        mobile = view(u, :, 1:nmob)
        c_o2 = view(u, :, 1)
        c_doc = view(u, :, 2)
        c_no3 = view(u, :, 3)
        B_o2 = view(u, :, 4)
        B_no3 = view(u, :, 5)

        # transport
        c_advec = [c_in; mobile]                       # prepend the upstream row
        advec = -v .* diff(c_advec, dims=1) ./ dx      # backward difference
        gradc = diff(mobile, dims=1) / dx
        # Difference of face gradients; the zero rows give zero gradient at both ends.
        disp = ([gradc; zeros(1, nmob)] - [zeros(1, nmob); gradc]) .* De

        du[:, 1] .= view(advec, :, 1) .+ view(disp, :, 1) .+
                    monod_oxygen.(c_o2, c_doc, B_o2, mu_max_o2, K_o2, K_doc_o2, yield_o2)
        du[:, 2] .= view(advec, :, 2) .+ view(disp, :, 2) .+
                    doc_rate.(c_o2, c_no3, c_doc, B_o2, B_no3, alpha, Kb,
                              mu_max_no3, K_no3, yield_doc_no3, K_doc_no3, I_o2,
                              mu_max_o2, K_o2, yield_doc_o2, K_doc_o2)
        du[:, 3] .= view(advec, :, 3) .+ view(disp, :, 3) .+
                    monod_nitrate.(c_no3, c_doc, c_o2, B_no3, mu_max_no3, K_no3,
                                   K_doc_no3, yield_no3, I_o2)
        du[:, 4] .= bac_rate_o2.(c_o2, c_doc, B_o2, mu_max_o2, K_o2, K_doc_o2, decay)
        du[:, 5] .= bac_rate_no3.(c_no3, c_doc, c_o2, B_no3, mu_max_no3, K_no3,
                                  K_doc_no3, decay_no3, I_o2)
        return nothing
    end
    return ADEST!
end
# Transport settings handed to create_ADEST
v = 1e-6                      # factor on the upwind difference (advection term)
De = [1e-7 1e-7 1e-7]         # one factor per mobile column (dispersion term)
c_in = [1e-3 1e-8 1e-6]       # upstream values for the three mobile columns
nmob = 3                      # number of transported (mobile) columns
dx = 0.01                     # grid spacing
ADEST! = create_ADEST(v, De, dx, c_in, nmob) # Creating the ADEST function

# Initial parameters:
# (entries 1-15 are read positionally by ADEST!; entry 16 is only used to build u0)
p0 = [
        6e-10,  # alpha
        1e-5,  # Kb
        2e-4,  # mu_no3
        3.7e-5,  # K_no3
        1e-7,  # I_o2
        0.024,  # yield_no3
        0.30,  # yield_doc_no3
        5.3e-4,  # mu_o2
        7.5e-5,  # K_o2
        0.032,  # yield_o2
        0.44,  # yield_doc_o2
        5e-6,  # dec-rate_o2
        5e-6,  # dec-rate_no3
        1e-6,  # K_doc_no3
        7e-6,  # K_doc_o2
        0.05, # starting ss ratio
        ]

# u0
len = 0.5
# Cell-centre coordinates from dx/2 to len - dx/2 in steps of dx
x_points = 0+dx/2:dx:len-dx/2
# One row per grid cell, five columns; every entry starts at 1e-16
u0 = zeros(size(x_points)[1], 5).+1e-16
# Columns 4 and 5 are overwritten with values computed from p0 (entries 1, 2, 7, 11, 12, 13, 16)
u0[:, 4] .= (p0[1]*p0[11]/p0[12]-p0[2])*p0[16]
u0[:, 5] .= (p0[1]*p0[7]/p0[13]-p0[2])*p0[16]

# ODE Problem
tspan = (0.0, 100.0*3600)
prob = ODEProblem(ADEST!, u0, tspan, p0)
# testing forward mode autodiff:
du0 = zeros(size(u0))
ForwardDiff.jacobian((du, u) -> ADEST!(du, u, p0, 1), du0, u0) # YES!

# Jacobian sparsity pattern of the RHS with respect to u (p fixed at p0, t fixed at 1)
jac_sparsity = Symbolics.jacobian_sparsity((du, u) -> ADEST!(du, u, p0, 1),
    du0, u0) # add the sparsity pattern to speed up the solution
rhs! = ODEFunction(ADEST!, jac_prototype=jac_sparsity)
# testing function with the sparsity pattern
rhs!(du0, u0, p0, 0)
ForwardDiff.jacobian((du, u) -> rhs!(du, u, p0, 1), du0, u0) # YES!

# Solving the ODE
# (prob is rebound here to the version that carries the sparse Jacobian prototype)
prob = ODEProblem(rhs!, u0, tspan, p0)
sol = solve(prob, Rosenbrock23(autodiff=true), saveat=60*60, reltol=1e-10, abstol=1e-10)

# testing the function with ModelingToolkit
@mtkbuild de = modelingtoolkitize(prob)
prob_mtk = ODEProblem(de, [], tspan, jac=true, sparse=true)
sol_mtk = solve(prob_mtk, Rosenbrock23(autodiff=true), saveat=60*60, reltol=1e-10, abstol=1e-10)

# Flatten every saved state so the two solutions can be compared element by element
mtk_u = [sol_mtk.u[i][:] for i in eachindex(sol_mtk.u)]
u = [sol.u[i][:] for i in eachindex(sol.u)]
u ≈ mtk_u #true

# testing remake:
# Alternative parameter set, same positional layout as p0
p_ = [
    1e-10,  # alpha
    1e-7,  # Kb
    1e-5,  # mu_no3
    1e-7,  # K_no3
    1e-10,  # I_o2
    0.001,  # yield_no3
    0.4,  # yield_doc_no3
    1e-5,  # mu_o2
    1e-8,  # K_o2
    0.001,  # yield_o2
    0.1,  # yield_doc_o2
    1e-10,  # dec-rate_o2
    1e-10,  # dec-rate_no3
    1e-7,  # K_doc_no3
    1e-7,  # K_doc_o2
    0.01,  # starting ss ratio
]

# remake with functions:

"""
    remake_prob(p)

Re-solve the global `prob` with the parameter vector `p`.

A fresh initial state is built (every entry `1e-16`, columns 4 and 5 overwritten with
values derived from the global `p0`, not from `p`), swapped into `prob` together with
`p` through `remake`, and the resulting problem is solved with the same solver settings
as the reference run. Returns the solution object.
"""
function remake_prob(p)
    ncells = first(size(x_points))
    state0 = fill(1e-16, ncells, 5)

    # Columns 4 and 5 start from values computed out of the global p0.
    fill!(view(state0, :, 4), (p0[1] * p0[11] / p0[12] - p0[2]) * p0[16])
    fill!(view(state0, :, 5), (p0[1] * p0[7] / p0[13] - p0[2]) * p0[16])

    updated = remake(prob; u0 = state0, p = p)
    return solve(updated, Rosenbrock23(autodiff = true);
                 saveat = 60 * 60, reltol = 1e-10, abstol = 1e-10)
end

"""
    remake_mtk(p)

Re-solve the global ModelingToolkit-generated problem `prob_mtk` with the parameter
vector `p`.

A fresh initial state is built (every entry `1e-16`, columns 4 and 5 overwritten with
values derived from the global `p0`, not from `p`), swapped into `prob_mtk` together
with `p` through `remake`, and the resulting problem is solved. Returns the solution
object.
"""
function remake_mtk(p)
    ncells = first(size(x_points))
    state0 = fill(1e-16, ncells, 5)

    # Columns 4 and 5 start from values computed out of the global p0.
    fill!(view(state0, :, 4), (p0[1] * p0[11] / p0[12] - p0[2]) * p0[16])
    fill!(view(state0, :, 5), (p0[1] * p0[7] / p0[13] - p0[2]) * p0[16])

    # NOTE(review): state0 is matrix-shaped and p is a plain positional vector; verify
    # that remake on a ModelingToolkit-generated problem maps both onto the system's
    # own unknown/parameter ordering as intended.
    updated = remake(prob_mtk; u0 = state0, p = p)
    return solve(updated, Rosenbrock23(autodiff = true);
                 saveat = 60 * 60, reltol = 1e-10, abstol = 1e-10)
end


# Test both remake helpers with the alternative parameter set p_
sol_ = remake_prob(p_)

sol_mtk_ = remake_mtk(p_)

# Testing remake without the helper functions:
# (same initial state as before: 1e-16 everywhere, columns 4 and 5 computed from p0)
u0 = zeros(size(x_points)[1], 5).+1e-16
u0[:, 4] .= (p0[1]*p0[11]/p0[12]-p0[2])*p0[16]
u0[:, 5] .= (p0[1]*p0[7]/p0[13]-p0[2])*p0[16]

prob_ = remake(prob, u0 = u0, p=p_)
sol_2 = solve(prob_, Rosenbrock23(autodiff=true), saveat=60*60, reltol=1e-10, abstol=1e-10)

# Remake seems to be working fine:
# the p_ solution differs from the p0 solution, and the helper matches the inline remake
sol.u ≈ sol_.u #false
sol_.u ≈ sol_2.u #true


# Same inline remake, applied to the ModelingToolkit-generated problem.
# NOTE(review): u0 is matrix-shaped and p_ is a plain positional vector; confirm that
# remake on prob_mtk interprets them in the system's own unknown/parameter ordering.
prob_mtk_ = remake(prob_mtk, u0 = u0, p=p_)
sol_mtk_2 = solve(prob_mtk_, Rosenbrock23(autodiff=true), saveat=60*60, reltol=1e-10, abstol=1e-10)