Type stability for ODEProblem objects

I am working on a problem where I simulate an ODE as part of a larger optimization scheme (the ODE solve is part of the cost-function computation in a derivative-free setting), and I was trying to track down inefficiencies that were causing what seems like excessive runtime. I looked at the solver wrapper I am using, and it appears to have type instabilities both when creating the ODEProblem and in the return type of the solve routine.

I tried this on the simple Lorenz demo from the docs, and it also shows type instability. Is there a better way to create these objects so that the compiler can assign concrete types to them instead of Any?

Here is the code I used for the Lorenz sample:

using DifferentialEquations

function lorenz!(du,u,p,t)
 du[1] = 10.0*(u[2]-u[1])
 du[2] = u[1]*(28.0-u[3]) - u[2]
 du[3] = u[1]*u[2] - (8/3)*u[3]
end


function lorenzSim!(u)
    u0 = u
    tspan = (0.0,50.0)
    prob = ODEProblem(lorenz!,u0,tspan)
    sol = solve(prob)
    u0 = sol[end];

    tspan = (50.0,100.0)
    prob = ODEProblem(lorenz!,u0,tspan)
    sol = solve(prob)
    return sol[end]
end

Here is the @code_warntype output; note that prob, sol, and u0 are all inferred as Any.

julia> @code_warntype lorenzSim!([1.0;0.0;0.0])
Variables
  #self#::Core.Compiler.Const(lorenzSim!, false)
  u::Array{Float64,1}
  u0::Any
  tspan::Tuple{Float64,Float64}
  prob::Any
  sol::Any

Body::Any
1 ─       (u0 = u)
│         (tspan = Core.tuple(0.0, 50.0))
│         (prob = Main.ODEProblem(Main.lorenz!, u0::Array{Float64,1}, tspan::Core.Compiler.Const((0.0, 50.0), false)))
│         (sol = Main.solve(prob))
│   %5  = sol::Any
│   %6  = Base.lastindex(sol)::Any
│         (u0 = Base.getindex(%5, %6))
│         (tspan = Core.tuple(50.0, 100.0))
│         (prob = Main.ODEProblem(Main.lorenz!, u0, tspan::Core.Compiler.Const((50.0, 100.0), false)))
│         (sol = Main.solve(prob))
│   %11 = sol::Any
│   %12 = Base.lastindex(sol)::Any
│   %13 = Base.getindex(%11, %12)::Any
└──       return %13

I just fixed the inference issue today in https://github.com/SciML/DiffEqBase.jl/pull/570. One more thing: you might want to choose in-place (iip) vs. out-of-place (oop) explicitly, i.e.

prob = ODEProblem{true}(lorenz!,u0,tspan)

so that the problem is known to be in-place at compile time, instead of relying on inference over the method table of lorenz!.
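A minimal sketch of the explicit form, assuming OrdinaryDiffEq.jl (the ODE-solver subset of DifferentialEquations.jl, which also exports `ODEProblem`). The helper `iip_flag` is hypothetical; it only reads back the in-place type parameter, which the `@code_warntype` output in this thread shows as the third parameter of `ODEProblem`.

```julia
using OrdinaryDiffEq  # assumption: lighter alternative to `using DifferentialEquations`

function lorenz!(du, u, p, t)
    du[1] = 10.0 * (u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
    return nothing
end

u0 = [1.0, 0.0, 0.0]
tspan = (0.0, 50.0)

# `{true}` states up front that `lorenz!` mutates `du`, so the in-place flag is
# a compile-time constant rather than something deduced from `lorenz!`'s methods.
prob = ODEProblem{true}(lorenz!, u0, tspan)

# Hypothetical helper: the third type parameter of ODEProblem is the in-place flag.
iip_flag(::ODEProblem{uType, tType, iip}) where {uType, tType, iip} = iip
```

For an out-of-place right-hand side `f(u, p, t)` that returns the derivative, the analogous call would be `ODEProblem{false}(f, u0, tspan)`.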


Try initializing the problem with an explicit in-place or out-of-place flag: ODEProblem{true} or ODEProblem{false}.

Thanks. I gave that a try, and it fixes the inference for the problem itself, but the solution is still inferred as Any. That in turn breaks inference for the next problem definition, since its initial condition now comes in as type Any. Does the PR you mentioned require Julia 1.6 to work properly? I keep a copy of the master branch around that I don't often use and could try this with, but on my 1.5.0 install it didn't help.

Updated code:

using DifferentialEquations

function lorenz!(du,u,p,t)
 du[1] = 10.0*(u[2]-u[1])
 du[2] = u[1]*(28.0-u[3]) - u[2]
 du[3] = u[1]*u[2] - (8/3)*u[3]
end


function lorenzSim!(u::Array{Float64,1})
    tspan = (0.0,50.0)
    prob1 = ODEProblem{true}(lorenz!,u,tspan)
    sol1 = solve(prob1)

    u1 = sol1[end]

    tspan = (50.0,100.0)
    prob2 = ODEProblem{true}(lorenz!,u1,tspan)
    sol2 = solve(prob2)
    return sol2[end]
end

New type results:

julia> @code_warntype lorenzSim!([1.0;0.0;0.0])
Variables
  #self#::Core.Compiler.Const(lorenzSim!, false)
  u::Array{Float64,1}
  prob1::ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,DiffEqBase.NullParameters,ODEFunction{true,typeof(lorenz!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem}
  sol1::Any
  u1::Any
  tspan::Tuple{Float64,Float64}
  prob2::ODEProblem{_A,Tuple{Float64,Float64},true,DiffEqBase.NullParameters,ODEFunction{true,typeof(lorenz!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem} where _A
  sol2::Any

Body::Any
1 ─       (tspan = Core.tuple(0.0, 50.0))
│   %2  = Core.apply_type(Main.ODEProblem, true)::Core.Compiler.Const(ODEProblem{true,tType,isinplace,P,F,K,PT} where PT where K where F where P where isinplace where tType, false)
│         (prob1 = (%2)(Main.lorenz!, u, tspan::Core.Compiler.Const((0.0, 50.0), false)))
│         (sol1 = Main.solve(prob1::Core.Compiler.PartialStruct(ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,DiffEqBase.NullParameters,ODEFunction{true,typeof(lorenz!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem}, Any[Core.Compiler.Const(ODEFunction{true,typeof(lorenz!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}(lorenz!, LinearAlgebra.UniformScaling{Bool}(true), nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing), false), Array{Float64,1}, Core.Compiler.Const((0.0, 50.0), false), Core.Compiler.Const(DiffEqBase.NullParameters(), false), Core.Compiler.Const(Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}(), false), Core.Compiler.Const(DiffEqBase.StandardODEProblem(), false)])))
│   %5  = sol1::Any
│   %6  = Base.lastindex(sol1)::Any
│         (u1 = Base.getindex(%5, %6))
│         (tspan = Core.tuple(50.0, 100.0))
│   %9  = Core.apply_type(Main.ODEProblem, true)::Core.Compiler.Const(ODEProblem{true,tType,isinplace,P,F,K,PT} where PT where K where F where P where isinplace where tType, false)
│   %10 = u1::Any
│         (prob2 = (%9)(Main.lorenz!, %10, tspan::Core.Compiler.Const((50.0, 100.0), false)))
│         (sol2 = Main.solve(prob2))
│   %13 = sol2::Any
│   %14 = Base.lastindex(sol2)::Any
│   %15 = Base.getindex(%13, %14)::Any
└──       return %15
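Independently of the upstream fix, a common Julia workaround for this cascade is a type assertion on the value taken out of the solution: once `u1` is asserted to be a `Vector{Float64}`, the second problem is inferred concretely even if `sol1` is not. A minimal sketch, assuming OrdinaryDiffEq.jl and a hypothetical `lorenz_two_stage` that takes the solver as an argument; the assertion throws at runtime if the solution's state type is ever something other than `Vector{Float64}`.

```julia
using OrdinaryDiffEq  # assumption: lighter alternative to `using DifferentialEquations`

function lorenz!(du, u, p, t)
    du[1] = 10.0 * (u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
    return nothing
end

# Hypothetical helper: two consecutive solves, with the solver passed in.
function lorenz_two_stage(u::Vector{Float64}, alg)
    prob1 = ODEProblem{true}(lorenz!, u, (0.0, 50.0))
    sol1 = solve(prob1, alg)
    # Type assertion: even if `sol1` is inferred as Any, `u1` is known to be a
    # Vector{Float64} from here on, so `prob2` gets a concrete type.
    u1 = sol1.u[end]::Vector{Float64}
    prob2 = ODEProblem{true}(lorenz!, u1, (50.0, 100.0))
    sol2 = solve(prob2, alg)
    # Same trick on the return value, so callers also see a concrete type.
    return sol2.u[end]::Vector{Float64}
end
```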


You’ll need to wait until that PR is merged, which should be in a few hours.

The fix is released.


Thanks, but it appears that inference only works if you specify the algorithm in the solve call, and even then only for certain solvers. The original example I posted still doesn't infer the solution (it is still Any); changing the solves to solve(prob1, Tsit5()) gives full inference, but something else like solve(prob1, Rosenbrock23()) no longer infers the return type.

Is there a list of the algorithms that will allow a fully-specified return type?

Edit: This refers to the updated code from comment 3, not the original code; you still have to specify ODEProblem{true} for the problem to infer fully.
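Rather than relying on a fixed list, one can ask the compiler directly for a given problem and solver. A minimal sketch, assuming OrdinaryDiffEq.jl; `solve_infers` is a hypothetical helper built on `Base.return_types`, and its answers depend on the Julia and package versions installed, so no particular result is claimed here.

```julia
using OrdinaryDiffEq  # assumption: lighter alternative to `using DifferentialEquations`

function lorenz!(du, u, p, t)
    du[1] = 10.0 * (u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
    return nothing
end

# Hypothetical helper: true when inference finds a concrete return type for
# `solve(prob, alg)` (no keyword arguments). This only inspects types; it does
# not run the solver.
function solve_infers(prob, alg)
    rts = Base.return_types(solve, (typeof(prob), typeof(alg)))
    return !isempty(rts) && all(isconcretetype, rts)
end

prob = ODEProblem{true}(lorenz!, [1.0, 0.0, 0.0], (0.0, 50.0))
algs = (Tsit5(), Vern7(), Rosenbrock23(), Rodas5())

# Maps solver name => whether its solve call is inferred concretely.
results = Dict(string(nameof(typeof(alg))) => solve_infers(prob, alg) for alg in algs)
```

`Test.@inferred solve(prob, alg)` is an alternative that throws instead of returning a Bool.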

Open an issue. This was unknown and is likely due to the factorization. But note that none of this should affect your downstream code: just use a function barrier. The main reason to fix this is compile times.
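A function barrier, in the sense of the Julia performance tips, moves the work that uses the solution into a separate function: only one dynamic dispatch happens at the call boundary, and the callee is compiled for the concrete solution type it actually receives. A minimal sketch, assuming OrdinaryDiffEq.jl; `final_state_norm` and `lorenz_norm` are hypothetical names.

```julia
using OrdinaryDiffEq  # assumption: lighter alternative to `using DifferentialEquations`

function lorenz!(du, u, p, t)
    du[1] = 10.0 * (u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
    return nothing
end

# Hypothetical kernel: Julia compiles a specialized method for whatever concrete
# solution type is passed in, so this loop is type-stable on its own.
function final_state_norm(sol)
    s = 0.0
    for x in sol.u[end]
        s += x^2
    end
    return sqrt(s)
end

# Hypothetical caller: `sol` may be inferred as Any for some solvers, but the
# heavy work happens behind the barrier in `final_state_norm`.
function lorenz_norm(u0::Vector{Float64}, alg)
    prob = ODEProblem{true}(lorenz!, u0, (0.0, 50.0))
    sol = solve(prob, alg)
    return final_state_norm(sol)  # function barrier: one dynamic dispatch here
end
```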

Hi Chris,

I'm reviving this because I have a similar use case, and a function barrier doesn't seem to improve timing. Following the example above, an ODE solution is passed to the cost functions scalarmeasure and my_scalarmeasure.

Getting the gradients efficiently matters because this is then looped many, many times (and for a more complicated system), e.g. when doing gradient descent.

using DifferentialEquations, BenchmarkTools, ForwardDiff, Statistics

function lorenz!(du,u,p,t)
 du[1] = 10.0*(u[2]-u[1])
 du[2] = u[1]*(28.0-u[3]) - u[2]
 du[3] = u[1]*u[2] - (8/3)*u[3]
end

function lorenzSim!(u::Vector{T}) where {T<:Real}
    tspan = (0.0,100.0)
    prob1 = ODEProblem{true}(lorenz!,u,tspan)
    solve(prob1; save_idxs=[3])
end

function typestable_lorenzSim!(u::Vector{T}) where {T<:Real}
    tspan = (0.0, 100.0)
    prob1 = ODEProblem{true}(lorenz!, u, tspan)
    solve(prob1, Tsit5(); save_idxs=[3])
end

scalarmeasure(u::Vector{T}) where {T<:Real} = mean(lorenzSim!(u))
my_scalarmeasure(u::Vector{T}) where {T<:Real} = mean(typestable_lorenzSim!(u))

u = [1.0; 0.0; 0.0]

scalarmeasure(u)
my_scalarmeasure(u)
@time scalarmeasure(u)
@time my_scalarmeasure(u)

function grad(u::Vector{T}) where {T<:Real}
    out = similar(u)
    ForwardDiff.gradient!(out, scalarmeasure, u)
end

function my_grad(u::Vector{T}) where {T<:Real}
    out = similar(u)
    ForwardDiff.gradient!(out, my_scalarmeasure, u)
end

grad(u);
my_grad(u);
@btime grad(u);
@btime my_grad(u);

Output using Julia version 1.7.2:

  0.000522 seconds (5.70 k allocations: 421.531 KiB)
  0.000671 seconds (5.63 k allocations: 408.734 KiB)
  4.487 ms (28481 allocations: 1.65 MiB)
  1.197 ms (12160 allocations: 1.24 MiB)
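Separately from solver inference, `grad` and `my_grad` above call `ForwardDiff.gradient!` without a configuration object, in which case ForwardDiff builds a `GradientConfig` on every call. In a gradient-descent loop the configuration can be built once and reused. A minimal sketch, assuming OrdinaryDiffEq.jl, a hypothetical `cost` over a short time span, and a fixed step size; `descend` is likewise hypothetical.

```julia
using OrdinaryDiffEq, ForwardDiff, Statistics

function lorenz!(du, u, p, t)
    du[1] = 10.0 * (u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
    return nothing
end

# Hypothetical cost: mean of the third state over a short trajectory.
# Generic in T so that ForwardDiff's Dual numbers can flow through the solve.
function cost(u0::AbstractVector{T}) where {T<:Real}
    prob = ODEProblem{true}(lorenz!, u0, (0.0, 1.0))
    sol = solve(prob, Tsit5(); save_idxs = [3])
    return mean(x[1] for x in sol.u)
end

# Hypothetical plain gradient descent: the GradientConfig (Dual work buffers)
# is allocated once and reused on every step.
function descend(u0::Vector{Float64}; steps = 5, lr = 1e-3)
    u = copy(u0)
    g = similar(u)
    cfg = ForwardDiff.GradientConfig(cost, u)
    for _ in 1:steps
        ForwardDiff.gradient!(g, cost, u, cfg)
        u .-= lr .* g
    end
    return u
end
```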

How could I make functions like scalarmeasure type-stable for a general ODEProblem (again, in my case a more complex conductance-based neuron model) when I don't know beforehand which solver yields a type-stable solution?
I'm guessing this is why the gradient benchmark shows roughly 2.3× the allocations and 3.7× the runtime.

Thanks!
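One pattern for the question above, sketched under assumptions: pass the solver in as an argument, so the cost function is compiled separately for each solver type, and put the reduction behind a function barrier with a type assertion on the scalar it returns. The assertion pins the inferred return type of the cost to `T` even for solvers whose `solve` call is not inferred, although the body still pays for any instability inside `solve` itself. `scalarmeasure_alg` and `mean_third` are hypothetical names, and OrdinaryDiffEq.jl is assumed.

```julia
using OrdinaryDiffEq, ForwardDiff, Statistics

function lorenz!(du, u, p, t)
    du[1] = 10.0 * (u[2] - u[1])
    du[2] = u[1] * (28.0 - u[3]) - u[2]
    du[3] = u[1] * u[2] - (8 / 3) * u[3]
    return nothing
end

# Hypothetical helper: averages the single saved component. As a separate
# function it is compiled for the concrete element type of `us`.
mean_third(us) = mean(x[1] for x in us)

# Hypothetical variant of `scalarmeasure` that takes the solver as an argument,
# so a specialized method is compiled for each solver type.
function scalarmeasure_alg(u::Vector{T}, alg) where {T<:Real}
    prob = ODEProblem{true}(lorenz!, u, (0.0, 100.0))
    sol = solve(prob, alg; save_idxs = [3])
    return mean_third(sol.u)::T  # function barrier plus type assertion on the result
end

u = [1.0, 0.0, 0.0]
val = scalarmeasure_alg(u, Tsit5())
# A closure fixes the solver so ForwardDiff sees a function of `x` only.
g = ForwardDiff.gradient(x -> scalarmeasure_alg(x, Tsit5()), u)
```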

(a) this is completely unrelated to the thread, so please open a new thread for this kind of thing; (b) it's being solved today: "Redesign default ODE solver to be type-grounded and lazy" by oscardssmith (SciML/OrdinaryDiffEq.jl, Pull Request #2184).

Hey, I agree that the scope of my question goes a bit beyond that of the original thread. But essentially I thought the question was the same: how to know which solver to use, or in other words

Thanks for pointing to the PR. I couldn't find whether there was already an issue on this, but it's good to know it's flagged! I'll check once it's fixed 🙂