Type error when using sciml_train with a stiff ODE solver

I’ve only been using Julia since today, so I might be missing something fundamental. I recently discovered the UDE for SciML paper and the super nice SciML library. I’m trying to rebuild something similar to the Lotka-Volterra example from the paper, but with a stiff system (the Oregonator), and I can’t get it to work.
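For reference, this is the Oregonator system that the orego function in the script implements, with p = (s, q, w) set to (77.27, 8.375e-6, 0.161) in p_true. In the UDE version, the first and third equations are kept and the right-hand side for y₂ is replaced by the neural network.

$$
\begin{aligned}
\frac{dy_1}{dt} &= s\,\bigl(y_2 + y_1\,(1 - q\,y_1 - y_2)\bigr)\\
\frac{dy_2}{dt} &= \frac{y_3 - (1 + y_1)\,y_2}{s}\\
\frac{dy_3}{dt} &= w\,(y_1 - y_3)
\end{aligned}
$$

The factor s in the first equation and 1/s in the second put y₁ and y₂ on very different time scales, which is a large part of why this system is stiff.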

When running

```julia
using DifferentialEquations, Plots, DiffEqFlux, Flux, Optim, OrdinaryDiffEq, DiffEqSensitivity

# build the true solution
tspan = (0.0, 360.0)
tsteps = 0.0:1.0:360.0
u0 = [1.0,2.0,3.0]

function orego(du,u,p,t)
    s,q,w = p
    y1,y2,y3 = u
    du[1] = s*(y2+y1*(1-q*y1-y2))
    du[2] = (y3-(1+y1)*y2)/s
    du[3] = w*(y1-y3)
end

p_true = [77.27,8.375e-6,0.161]
prob_true = ODEProblem(orego,u0,tspan,p_true)

sol = solve(prob_true,Rodas5(),saveat = tsteps)
true_sol = Array(sol)

plot(sol)

# build the UDE
ann = FastChain(FastDense(3,16,tanh), FastDense(16,16,tanh), FastDense(16,1))
p1 = initial_params(ann)
p2 = [0.5,1, 0.]
p_init = [p1;p2]   # network weights followed by s, q, w
theta = Flux.params(p_init)

function orego_ude!(du,u,p,t)
    s,q,w = p[end-2:end]
    y1,y2,y3 = u
    du[1] = s*(y2+y1*(1-q*y1-y2))
    du[2] = ann(u,p[1:length(p1)])[1]   # the network replaces the y2 equation
    du[3] = w*(y1-y3)
end

prob_mod = ODEProblem(orego_ude!,u0,tspan,p_init)
sol_mod_test = Array(solve(prob_mod,Rodas4P(), u0=u0, p=p_init,saveat = tsteps, abstol=1e-12, reltol=1e-12,sensealg = InterpolatingAdjoint(checkpointing=true)))

# ML part
function predict_mod(theta)
    # solve with the parameters being optimized
    return Array(solve(prob_mod,Rodas4P(), u0=u0, p=theta,saveat = tsteps, abstol=1e-12, reltol=1e-12,sensealg = InterpolatingAdjoint(checkpointing=true)))
end

loss(theta) = sum(abs2, predict_mod(theta).-true_sol)
l = loss(p_init)

cb = function (theta,l)
    println(l)
    return false
end

result_mod = DiffEqFlux.sciml_train(loss, p_init, ADAM(0.01), cb = cb,maxiters = 200)
```

Julia throws a huge stack trace on the last line, beginning with

```
LoadError: TypeError: in typeassert, expected Float64, got a value of type ForwardDiff.Dual{Nothing,Float64,12}

in expression starting at /somepath/oregonator_ude.jl:57
```

This confuses me in several ways. First, I don’t understand Julia’s approach to types: I assumed that one does not have to define types explicitly.
Second, I don’t understand why sol_mod_test is computed without problems, while essentially the same solve call later fails whenever it uses a stiff solver and/or the sensealg argument (with Rodas5() it runs but hits maxiters, which I somewhat expect given the nature of the system). Given this behaviour, I suspect it has something to do with the ODE solver evaluating Jacobians (or not) during the sensitivity calculation, but I don’t see what I need to change, even after consulting the documentation.
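To illustrate what this kind of error message means in general, here is a minimal plain-Julia sketch with hypothetical toy functions (not the actual solver code). A ::Float64 type assertion works for ordinary floats but throws a TypeError as soon as a different number type reaches it, such as the ForwardDiff.Dual numbers that forward-mode AD pushes through the code. BigFloat stands in for the dual number here so the example needs no packages.

```julia
# Hypothetical toy functions, for illustration only.

# Hard-coded assertion: the result must be exactly a Float64.
strict_scale(x) = (2 * x)::Float64

# Generic version: the output type simply follows the input type.
generic_scale(x) = 2 * x

println(strict_scale(1.5))          # plain Float64 input: fine
println(generic_scale(big"1.5"))    # BigFloat in, BigFloat out

# BigFloat stands in for ForwardDiff.Dual: any non-Float64 Number trips the assertion.
caught = try
    strict_scale(big"1.5")
    nothing
catch err
    err
end
println(caught isa TypeError)       # true
```

This is why generic Julia code avoids hard-coding Float64 wherever AD numbers can flow (for example, allocating caches with similar(u) or eltype(u) rather than a fixed Float64 array). According to the answer, though, the problem in this thread sits in the Rosenbrock interpolation code rather than in the script itself.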

The problem isn’t in the Jacobian calculation; it’s something odd in the specialized Rosenbrock interpolations. I’ll open an issue to look into it more deeply, but it’s not the real issue here. The real issue is that you don’t want to use a Rosenbrock method on this problem. Looking at the forward solve it might seem like a good idea, but remember that with N states and P parameters, the forward solve has size N while the reverse solve has size N+P. Rosenbrock methods are only recommended for fewer than ~100 ODEs; beyond that they get quite slow because they don’t reuse Jacobians. Since the reverse solve here is ~350 ODEs, TRBDF2 will do much better. (Indeed, if you run Rodas4P(autodiff=false) you’ll see that it’s slow, and it’s not autodiff’s fault there, since a full numerical Jacobian costs almost the same as the forward-mode one.)
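A quick back-of-the-envelope check of the "~350 ODEs" figure, as a sketch under two assumptions: each FastDense(in, out) layer holds out*in weights plus out biases, and the InterpolatingAdjoint reverse system has size N+P as stated above. The LU line uses the textbook ~2n³/3 flop estimate for one dense factorization, which a Rosenbrock method pays on every step; it is an order-of-magnitude illustration, not a benchmark. dense_params and lu_flops are hypothetical helper names.

```julia
# Assumption: FastDense(in, out) stores out*in weights plus out biases.
dense_params(nin, nout) = nout * nin + nout

# Layer sizes of `ann` from the question
layers = [(3, 16), (16, 16), (16, 1)]

P_nn = sum(dense_params(nin, nout) for (nin, nout) in layers)  # 64 + 272 + 17 = 353
P = P_nn + 3          # plus the mechanistic parameters s, q, w
N = 3                 # states y1, y2, y3

forward_size = N
reverse_size = N + P  # reverse (adjoint) solve size, per the answer

# Textbook flop estimate for one dense LU factorization of an n×n matrix
lu_flops(n) = 2n^3 / 3

println("P = $P, forward size = $forward_size, reverse size = $reverse_size")
println("LU cost ratio (reverse/forward) ≈ ",
        round(Int, lu_flops(reverse_size) / lu_flops(forward_size)))
```

This prints P = 356 and a reverse size of 359, so the reverse problem is about 120 times larger than the forward one and each dense factorization is roughly 1.7 million times more expensive. That is the regime where a method that refactorizes every step falls behind.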

So, TL;DR: just use TRBDF2 or KenCarp4 here, e.g.:

```julia
# ML part
function predict_mod(theta)
    # solve with the parameters being optimized
    return Array(solve(prob_mod,TRBDF2(), u0=u0, p=theta,saveat = tsteps, abstol=1e-8, reltol=1e-8,sensealg = InterpolatingAdjoint(checkpointing=true)))
end

loss(theta) = sum(abs2, predict_mod(theta).-true_sol)
l = loss(p_init)

cb = function (theta,l)
    println(l)
    return false
end

result_mod = DiffEqFlux.sciml_train(loss, p_init, ADAM(0.01), cb = cb,maxiters = 200)
```

That said, you’re going to need some extra tricks to fit something this stiff. A randomly initialized neural network will likely blow up over the time span, so you’ll need at least the approach from https://diffeqflux.sciml.ai/dev/examples/local_minima/ to get anything going. You’ll also want to weight the loss by relative errors, along with similar tricks.
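A minimal sketch of two of those ideas, assuming plain matrices shaped like true_sol (states × save points). relative_loss weights residuals by the magnitude of the data, and window_schedule is a hypothetical schedule for growing the fitted window in the spirit of the linked page. Both are illustrative helpers, not DiffEqFlux API, and scale_floor is an arbitrary guard against dividing by near-zero data.

```julia
# Hypothetical helpers for illustration; not part of DiffEqFlux.

# Relative-error loss: each residual is divided by the magnitude of the data,
# so states living on very different scales contribute comparably.
function relative_loss(pred::AbstractMatrix, data::AbstractMatrix; scale_floor = 1e-8)
    # If the solver stopped early, the prediction has fewer columns: treat it as a bad fit.
    size(pred) == size(data) || return Inf
    return sum(abs2, (pred .- data) ./ (abs.(data) .+ scale_floor))
end

# Number of leading save points to fit at each stage, doubling until the
# full series is covered (the "grow the window" idea).
function window_schedule(n::Integer, k0::Integer)
    ks = Int[]
    k = k0
    while k < n
        push!(ks, k)
        k *= 2
    end
    push!(ks, n)
    return ks
end

# Toy check: both rows carry a 10% error in the first column, but the second
# row is 100x larger, so it dominates the plain squared error.
data = [1.0 2.0; 100.0 200.0]
pred = [1.1 2.0; 110.0 200.0]
println(sum(abs2, pred .- data))     # ≈ 100.01, dominated by row 2
println(relative_loss(pred, data))   # ≈ 0.02, both rows weigh equally
println(window_schedule(361, 20))    # [20, 40, 80, 160, 320, 361]
```

With tsteps = 0.0:1.0:360.0 there are 361 save points. A simple staged loop would train on relative_loss(pred[:, 1:k], true_sol[:, 1:k]) for each k, starting every stage from the previous stage's parameters. The linked page instead re-solves on a shortened time span, which also avoids integrating an exploding network over the full range.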

Many thanks for your fast and concise answer; it does indeed work with the solvers you mentioned.
Coming mainly from Python, Fortran and Matlab, I must admit I’m a bit overwhelmed by the choice of solvers.
I also did not expect an answer from the main author of the paper, but since you’re here, let me compliment you on this amazing library! I was always a little hesitant about Julia because I doubted its maturity, but I was quite amazed by the features and state of SciML after I gave it a go this week. Many thanks, and I hope I can contribute in the future once I have a better understanding of the language.