Type stability for ODEProblem objects

I am working on a problem where I simulate an ODE as part of a larger optimization scheme (the ODE solve is part of the cost-function computation in a derivative-free setting), and I was trying to track down some inefficiencies that seem to be causing excessive runtime. I looked at the solver wrapper I am using, and it appears to have type instabilities both when creating the ODEProblem and in the return type of the solve call.

I tried this on the simple Lorenz demo from the docs, and it also shows type instability. Is there a better way to create these objects so that the compiler can assign them concrete types instead of Any?

Here is the code I used for the Lorenz sample:

using DifferentialEquations

function lorenz!(du,u,p,t)
 du[1] = 10.0*(u[2]-u[1])
 du[2] = u[1]*(28.0-u[3]) - u[2]
 du[3] = u[1]*u[2] - (8/3)*u[3]
end


function lorenzSim!(u)
    u0 = u
    tspan = (0.0,50.0)
    prob = ODEProblem(lorenz!,u0,tspan)
    sol = solve(prob)
    u0 = sol[end];

    tspan = (50.0,100.0)
    prob = ODEProblem(lorenz!,u0,tspan)
    sol = solve(prob)
    return sol[end]
end

Here is the @code_warntype output; note that prob, sol, and u0 are all inferred as Any.

julia> @code_warntype lorenzSim!([1.0;0.0;0.0])
Variables
  #self#::Core.Compiler.Const(lorenzSim!, false)
  u::Array{Float64,1}
  u0::Any
  tspan::Tuple{Float64,Float64}
  prob::Any
  sol::Any

Body::Any
1 ─       (u0 = u)
│         (tspan = Core.tuple(0.0, 50.0))
│         (prob = Main.ODEProblem(Main.lorenz!, u0::Array{Float64,1}, tspan::Core.Compiler.Const((0.0, 50.0), false)))
│         (sol = Main.solve(prob))
│   %5  = sol::Any
│   %6  = Base.lastindex(sol)::Any
│         (u0 = Base.getindex(%5, %6))
│         (tspan = Core.tuple(50.0, 100.0))
│         (prob = Main.ODEProblem(Main.lorenz!, u0, tspan::Core.Compiler.Const((50.0, 100.0), false)))
│         (sol = Main.solve(prob))
│   %11 = sol::Any
│   %12 = Base.lastindex(sol)::Any
│   %13 = Base.getindex(%11, %12)::Any
└──       return %13
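Reading `@code_warntype` output by eye is error-prone when the types are this long. A minimal sketch of turning the same check into an assertion uses `@inferred` from the `Test` standard library; `unstable` and `stable` below are hypothetical toy functions, not code from this thread. Applied here, the equivalent check would be `@inferred lorenzSim!([1.0; 0.0; 0.0])`.

```julia
using Test  # provides @inferred

# Hypothetical toy functions, not from the thread.
# `unstable` may return a Float64 or a String, so its inferred return type is a Union.
unstable(x) = x > 0 ? x : "negative"
# `stable` always returns the same type as its input.
stable(x) = x > 0 ? x : zero(x)

# @inferred returns the call's result when the runtime return type matches
# the inferred return type, and throws an error otherwise.
y = @inferred stable(2.0)

# Base.return_types lists the inferred return type(s) for the given argument types.
println(Base.return_types(unstable, (Float64,)))
println(Base.return_types(stable, (Float64,)))

# For the unstable version, @inferred raises an ErrorException instead of returning.
caught = try
    @inferred unstable(2.0)
    false
catch err
    err isa ErrorException
end
```

This makes inference regressions show up as test failures rather than something you have to spot in a long type dump.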

I just fixed the inference issue today in https://github.com/SciML/DiffEqBase.jl/pull/570. One last thing: you might want to choose in-place (iip) vs. out-of-place (oop) directly, i.e.

prob = ODEProblem{true}(lorenz!,u0,tspan)

so that it's known to be in-place, instead of relying on inferring that from the method table.
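To see why fixing the flag helps, here is a toy analogue (not DiffEqBase's actual implementation). The hypothetical `ToyProblem` carries the in-place flag as a type parameter, as `ODEProblem{iip}` does. When the caller supplies the flag, every type parameter is known at compile time; when the constructor derives it from the method table, the flag is an ordinary run-time `Bool`, and the compiler may be unable to determine the concrete type (compare the `where _A` in `prob2`'s type later in the thread).

```julia
using Test  # provides @inferred

# Hypothetical stand-in for ODEProblem: the in-place flag `iip` is a type parameter.
struct ToyProblem{iip,F,U,T}
    f::F
    u0::U
    tspan::T
end

# Explicit form, analogous to ODEProblem{true}(f, u0, tspan):
# the caller fixes iip, so every type parameter is known at compile time.
ToyProblem{iip}(f, u0, tspan) where {iip} =
    ToyProblem{iip,typeof(f),typeof(u0),typeof(tspan)}(f, u0, tspan)

# Rough analogue of "relying on the method table": treat f as in-place if it has
# a 4-argument method (du, u, p, t). `m.nargs` also counts the function itself.
has_inplace_method(f) = any(m -> m.nargs - 1 == 4, methods(f))

# Guessing form: iip comes from a run-time Bool, so it may not be inferable.
ToyProblem(f, u0, tspan) = ToyProblem{has_inplace_method(f)}(f, u0, tspan)

# Hypothetical right-hand sides in both styles.
rhs!(du, u, p, t) = (du .= -u; nothing)  # in-place
rhs(u, p, t) = -u                         # out-of-place

u0 = [1.0, 0.0, 0.0]
tspan = (0.0, 50.0)

# Passes: the explicit flag makes the concrete type inferable.
prob_explicit = @inferred ToyProblem{true}(rhs!, u0, tspan)

# Same run-time result, but here the flag is computed from the method table.
prob_guessed = ToyProblem(rhs!, u0, tspan)
prob_oop = ToyProblem(rhs, u0, tspan)
```

Both forms build the same object; the difference is only in what the compiler can know before the code runs.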


Try initializing the problem while explicitly specifying whether it is in-place or out-of-place: ODEProblem{true} or ODEProblem{false}.

Thanks. I gave that a try, and it fixes the inference for the problem itself, but the solution is still inferred as Any. That in turn affects the subsequent problem definition, which is also improperly inferred (since the initial point comes in as type Any). Does the PR you mentioned require Julia 1.6 to work properly? I keep a copy of the master branch around that I don't often use and could try this with, but on my 1.5.0 installation it didn't help.

Updated code:

using DifferentialEquations

function lorenz!(du,u,p,t)
 du[1] = 10.0*(u[2]-u[1])
 du[2] = u[1]*(28.0-u[3]) - u[2]
 du[3] = u[1]*u[2] - (8/3)*u[3]
end


function lorenzSim!(u::Array{Float64,1})
    tspan = (0.0,50.0)
    prob1 = ODEProblem{true}(lorenz!,u,tspan)
    sol1 = solve(prob1)

    u1 = sol1[end]

    tspan = (50.0,100.0)
    prob2 = ODEProblem{true}(lorenz!,u1,tspan)
    sol2 = solve(prob2)
    return sol2[end]
end

New type results:

julia> @code_warntype lorenzSim!([1.0;0.0;0.0])
Variables
  #self#::Core.Compiler.Const(lorenzSim!, false)
  u::Array{Float64,1}
  prob1::ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,DiffEqBase.NullParameters,ODEFunction{true,typeof(lorenz!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem}
  sol1::Any
  u1::Any
  tspan::Tuple{Float64,Float64}
  prob2::ODEProblem{_A,Tuple{Float64,Float64},true,DiffEqBase.NullParameters,ODEFunction{true,typeof(lorenz!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem} where _A
  sol2::Any

Body::Any
1 ─       (tspan = Core.tuple(0.0, 50.0))
│   %2  = Core.apply_type(Main.ODEProblem, true)::Core.Compiler.Const(ODEProblem{true,tType,isinplace,P,F,K,PT} where PT where K where F where P where isinplace where tType, false)
│         (prob1 = (%2)(Main.lorenz!, u, tspan::Core.Compiler.Const((0.0, 50.0), false)))
│         (sol1 = Main.solve(prob1::Core.Compiler.PartialStruct(ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,DiffEqBase.NullParameters,ODEFunction{true,typeof(lorenz!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem}, Any[Core.Compiler.Const(ODEFunction{true,typeof(lorenz!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}(lorenz!, LinearAlgebra.UniformScaling{Bool}(true), nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing), false), Array{Float64,1}, Core.Compiler.Const((0.0, 50.0), false), Core.Compiler.Const(DiffEqBase.NullParameters(), false), Core.Compiler.Const(Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}(), false), Core.Compiler.Const(DiffEqBase.StandardODEProblem(), false)])))
│   %5  = sol1::Any
│   %6  = Base.lastindex(sol1)::Any
│         (u1 = Base.getindex(%5, %6))
│         (tspan = Core.tuple(50.0, 100.0))
│   %9  = Core.apply_type(Main.ODEProblem, true)::Core.Compiler.Const(ODEProblem{true,tType,isinplace,P,F,K,PT} where PT where K where F where P where isinplace where tType, false)
│   %10 = u1::Any
│         (prob2 = (%9)(Main.lorenz!, %10, tspan::Core.Compiler.Const((50.0, 100.0), false)))
│         (sol2 = Main.solve(prob2))
│   %13 = sol2::Any
│   %14 = Base.lastindex(sol2)::Any
│   %15 = Base.getindex(%13, %14)::Any
└──       return %15


You’ll need to wait until that PR is merged, which should be in a few hours.

The fix is released.


Thanks, but it appears that the inference only works if you specify the algorithm in the solve call, and even then only for certain solvers. With the example I posted, the solution is still inferred as Any. Changing the solves to solve(prob1, Tsit5()) gives full inference, but using something else, such as solve(prob1, Rosenbrock23()), no longer infers the return type.

Is there a list of the algorithms that will allow a fully-specified return type?

Edit: this refers to the updated code from comment 3, not the original code; you still have to specify ODEProblem{true} to get the problem to infer fully.

Open an issue. This was unknown and is likely due to the factorization. But note that none of this should affect your downstream code: just use a function barrier. The main reason to fix this is just compile times.
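A function barrier confines an uninferred value to a single dynamic dispatch: the outer function may only know that a value is `Any`, but once it is passed into a separate function, Julia compiles that function for the value's actual concrete type. Below is a minimal plain-Julia sketch; `cost_kernel` and `cost` are hypothetical names, and `Ref{Any}(...)[]` is an assumed stand-in for an uninferred `sol1[end]`. In the thread's code the analogous step would be moving the second `ODEProblem{true}`/`solve` pair into its own function that takes `u1` as an argument.

```julia
using Test  # provides @inferred

# Kernel behind the barrier: compiled separately for the concrete type of `u`,
# so everything inside it is type stable.
function cost_kernel(u::AbstractVector{<:Real})
    s = zero(eltype(u))
    for x in u
        s += x^2
    end
    return s
end

function cost(u0::Vector{Float64})
    # Simulate an uninferred intermediate result (like `sol1[end]` inferred as Any):
    # the compiler only knows `u1::Any` here.
    u1 = Ref{Any}(copy(u0))[]
    # One dynamic dispatch happens at this call; inside cost_kernel the types are concrete.
    # The ::Float64 assertion lets callers of `cost` see a concrete return type again.
    return cost_kernel(u1)::Float64
end

total = cost([1.0, 2.0, 3.0])

# The kernel infers a concrete return type when called directly.
kernel_total = @inferred cost_kernel([1.0, 2.0, 3.0])
```

The cost of the unstable part is paid once per call to `cost`, not once per loop iteration inside the kernel, which is why an uninferred `solve` result need not slow down the surrounding computation much.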