Type instability in DifferentialEquations.jl

Hi everyone, I’m trying to solve an ODE many times, each time with different initial conditions, and I would like to optimize my code for that.
I’ve tried storing the solver in a (mutable) struct, to save time and avoid having to create the solver every time.

My code looks something like this:

module MWE

using DifferentialEquations

"""
    SaveODE()

Mutable cache that holds a reusable ODE integrator together with the problem
settings and snapshots of the integrator state after each solve.

`SaveODE()` leaves every field undefined; `tspan` and `param` must be assigned
before calling `initilaize_ode`, which fills in the remaining fields.

NOTE(review): `OrdinaryDiffEq.ODEIntegrator` and `ODESolution` are non-concrete
(parametric) types, so reading `integrator` or `sol` from this struct is
type-unstable. `solve_ode` uses a function barrier to contain the cost. Making
the struct parametric in the integrator/solution types removes it entirely.
"""
mutable struct SaveODE
    integrator::OrdinaryDiffEq.ODEIntegrator
    sol::ODESolution
    tspan::Tuple{Float64,Float64}
    param::Tuple{Float64,Float64}
    result::Array{Float64,1}   # integrator state after `initilaize_ode`
    result2::Array{Float64,1}  # integrator state after the latest `solve_ode`
    SaveODE() = new()
end

"""
    initilaize_ode(save::SaveODE, psi0::Array{Float64,1})

Build an `ODEProblem` for `flow_eqs!` from `save.tspan` and `save.param` and
create a `Tsit5` integrator starting at `psi0`. Solve once, then store the
integrator, the solution and a copy of the resulting state in `save`.

Returns `save.result`.
"""
function initilaize_ode(save::SaveODE, psi0::Array{Float64,1})
    prob = ODEProblem(flow_eqs!, psi0, save.tspan, save.param)
    integrator = init(prob, Tsit5())
    save.integrator = integrator
    # NOTE(review): `solve!` appears to return the integrator's own solution
    # object, which later `reinit!`/`solve!` calls reuse. Copy it if the
    # initial trajectory must be preserved. TODO confirm.
    save.sol = solve!(integrator)
    # `integrator.u` is the integrator's mutable working state. Store a
    # snapshot: an alias would change on the next `solve_ode` call and end
    # up being the very same array as `save.result2`.
    save.result = copy(integrator.u)
    return save.result
end

"""
    solve_ode(save::SaveODE, psi0::Array{Float64,1})

Restart the cached integrator from `psi0`, solve again, store a copy of the
resulting state in `save.result2`, and return the solution.
"""
function solve_ode(save::SaveODE, psi0::Array{Float64,1})
    # Function barrier: `save.integrator` has a non-concrete declared type, so
    # pass it as an argument and let Julia specialize `_solve_ode!` on its
    # concrete runtime type.
    return _solve_ode!(save, save.integrator, psi0)::ODESolution
end

# Type-specialized kernel of `solve_ode`: re-initialize, solve, snapshot state.
function _solve_ode!(save::SaveODE, integrator, psi0)
    reinit!(integrator, psi0)
    sol = solve!(integrator)
    save.result2 = copy(integrator.u)
    return sol
end

"""
    flow_eqs!(dpsi, psi, param, t)

In-place right-hand side of the linear ODE `dpsi/dt = param[1] * psi`.
`param[2]` and `t` are unused. Arguments are left untyped so that views and
dual numbers (automatic differentiation in implicit solvers) are accepted.
"""
function flow_eqs!(dpsi, psi, param, t)
    dpsi .= param[1] .* psi
    return nothing
end

end
using Plots
using .MWE

# Configure the cache first: `initilaize_ode` reads tspan and param from it.
new_save = MWE.SaveODE()
new_save.tspan = (0.0, 5.0)
new_save.param = (-1.0, 0.0)

# First solve from all ones, then a re-solve from all twos.
MWE.initilaize_ode(new_save, ones(4))
sol = MWE.solve_ode(new_save, fill(2.0, 4))
plot(sol)

When I analyze my code, I get more allocations than I expected, and red warnings in `@code_warntype`:

julia> @time MWE.solve_ode(new_save,ones(Float64,4));
0.000029 seconds (111 allocations: 12.656 KiB)


julia> @code_warntype MWE.solve_ode(new_save,ones(Float64,4));
Variables
  #self#::Core.Const(Main.MWE.solve_ode)
  save::Main.MWE.SaveODE
  psi0::Vector{Float64}
  sol::SciMLBase.ODESolution

Body::SciMLBase.ODESolution
1 ─ %1 = Base.getproperty(save, :integrator)::OrdinaryDiffEq.ODEIntegrator
│        Main.MWE.reinit!(%1, psi0)
│   %3 = Base.getproperty(save, :integrator)::OrdinaryDiffEq.ODEIntegrator
│   %4 = Main.MWE.solve!(%3)::Any
│   %5 = Base.convert(Main.MWE.ODESolution, %4)::SciMLBase.ODESolution
│        (sol = Core.typeassert(%5, Main.MWE.ODESolution))
│   %7 = Base.getproperty(save, :integrator)::OrdinaryDiffEq.ODEIntegrator
│   %8 = Base.getproperty(%7, :u)::Any
│        Base.setproperty!(save, :result2, %8)
└──      return sol

Any suggestions on how to write my code more efficiently? Is my approach sound?
Thanks to the whole great Julia community!

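I’d say `ODESolution` (and likewise `ODEIntegrator`) is not a concrete type, so fields annotated with it cannot be inferred — note the `::Any` in your `@code_warntype` output. Perhaps you mean: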

"""
    SaveODE{Ti,Ts}()

Mutable cache for a reusable ODE integrator of concrete type `Ti` and its
solution of concrete type `Ts`, plus problem settings and state snapshots.
Because `integrator` and `sol` are typed by the parameters, field access is
type-stable. All fields start undefined and are assigned after construction.
"""
mutable struct SaveODE{Ti <: OrdinaryDiffEq.ODEIntegrator, Ts <: ODESolution}
    integrator::Ti
    sol::Ts
    tspan::Tuple{Float64,Float64}
    param::Tuple{Float64,Float64}
    result::Array{Float64,1}
    result2::Array{Float64,1}
    # The inner constructor of a parametric type must name its parameters:
    # a bare `SaveODE() = new()` is rejected because `new` cannot infer Ti/Ts.
    SaveODE{Ti,Ts}() where {Ti <: OrdinaryDiffEq.ODEIntegrator, Ts <: ODESolution} =
        new{Ti,Ts}()
end
3 Likes

Thank you!
Indeed, that solves the type-instability issue.

Is there any way to reduce the allocations further?

julia> @time MWE.solve_ode(new_save,ones(Float64,4));
  0.000018 seconds (102 allocations: 11.641 KiB)

(For completeness, here is the full code now:)

module MWE

using DifferentialEquations

"""
    SaveODE{Ti,Ts}()

Mutable cache holding a reusable ODE integrator (`Ti`), the solution object it
produced (`Ts`), the problem settings, and snapshots of the integrator state.
Parametrizing on the integrator and solution types makes field access
type-stable. Fields start undefined; `initilaize_ode` fills them in.
"""
mutable struct SaveODE{Ti <: OrdinaryDiffEq.ODEIntegrator, Ts <: ODESolution}
    integrator::Ti
    sol::Ts
    tspan::Tuple{Float64,Float64}
    param::Tuple{Float64,Float64}
    result::Array{Float64,1}   # integrator state after `initilaize_ode`
    result2::Array{Float64,1}  # integrator state after the latest `solve_ode`
    SaveODE{Ti,Ts}() where {Ti <: OrdinaryDiffEq.ODEIntegrator, Ts <: ODESolution} =
        new{Ti,Ts}()
end

"""
    initilaize_ode(tspan, param, psi0)

Create a `Tsit5` integrator for `flow_eqs!` with the given time span and
parameters, starting at `psi0`. Solve once and return a `SaveODE` holding the
integrator, the solution, the settings and a copy of the resulting state.
"""
function initilaize_ode(tspan::Tuple{Float64,Float64}, param::Tuple{Float64,Float64},
                        psi0::Array{Float64,1})
    prob = ODEProblem(flow_eqs!, psi0, tspan, param)
    integrator = init(prob, Tsit5())
    sol = solve!(integrator)

    save = SaveODE{typeof(integrator),typeof(sol)}()
    save.integrator = integrator
    # NOTE(review): `solve!` appears to return the integrator's own solution
    # object, which later `reinit!`/`solve!` calls reuse. Copy it if the
    # initial trajectory must be preserved. TODO confirm.
    save.sol = sol
    save.tspan = tspan
    save.param = param
    # `integrator.u` is the integrator's mutable working state. Store a
    # snapshot: an alias would change on the next `solve_ode` call and end up
    # being the very same array as `save.result2`.
    save.result = copy(integrator.u)
    return save
end

"""
    solve_ode(save::SaveODE, psi0::Array{Float64,1})

Restart the cached integrator from `psi0`, solve again, store a copy of the
resulting state in `save.result2`, and return the solution.
"""
function solve_ode(save::SaveODE, psi0::Array{Float64,1})
    integrator = save.integrator
    reinit!(integrator, psi0)
    # `save` is concretely parametrized, so `solve!` is inferred. An
    # `::ODESolution` annotation would only add a convert to a non-concrete type.
    sol = solve!(integrator)
    save.result2 = copy(integrator.u)
    return sol
end

"""
    flow_eqs!(dpsi, psi, param, t)

In-place right-hand side of the linear ODE `dpsi/dt = param[1] * psi`.
`param[2]` and `t` are unused. Arguments are left untyped so that views and
dual numbers (automatic differentiation in implicit solvers) are accepted.
"""
function flow_eqs!(dpsi, psi, param, t)
    dpsi .= param[1] .* psi
    return nothing
end

end

using Plots
using .MWE

# Time span and parameters; only param[1] enters flow_eqs!.
tspan = (0.0, 5.0)
param = (1.0, 0.0)

# Build the cache with a first solve from all ones, then re-solve from all twos.
new_save = MWE.initilaize_ode(tspan, param, ones(4));
sol = MWE.solve_ode(new_save, fill(2.0, 4))
plot(sol)

Use static arrays if the system is small. Otherwise, you can’t really avoid these allocations, because the solvers need caches.