Hi everyone, I'm trying to solve the same ODE many times, each time with a different initial condition, and I would like to optimize my code for that.
I've tried storing the solver (the integrator) in a mutable struct, so that I save time and don't need to create a new solver every time.
My code looks something like this:
```julia
module MWE

using DifferentialEquations

# Container that keeps the integrator around so it can be reused
mutable struct SaveODE
    integrator::OrdinaryDiffEq.ODEIntegrator
    sol::ODESolution
    tspan::Tuple{Float64,Float64}
    param::Tuple{Float64,Float64}
    result::Array{Float64,1}
    result2::Array{Float64,1}
    SaveODE() = new()   # fields are filled in later
end

# Build the problem and integrator once and solve for the first initial condition
function initilaize_ode(save::SaveODE, psi0::Array{Float64,1})
    prob = ODEProblem(flow_eqs!, psi0, save.tspan, save.param)
    integrator = init(prob, Tsit5())
    save.integrator = integrator
    save.sol = solve!(integrator)
    save.result = integrator.u   # shares memory with the integrator's state array
end

# Reuse the stored integrator for a new initial condition
function solve_ode(save::SaveODE, psi0::Array{Float64,1})
    reinit!(save.integrator, psi0)
    sol::ODESolution = solve!(save.integrator)
    save.result2 = save.integrator.u
    sol
end

# Right-hand side: dpsi/dt = param[1] * psi
function flow_eqs!(dpsi::Array{Float64,1}, psi::Array{Float64,1}, param::Tuple{Float64,Float64}, t::Float64)
    dpsi .= param[1] .* psi
end

end

using .MWE
using Plots

new_save = MWE.SaveODE()
new_save.param = (-1.0, 0.0)
new_save.tspan = (0.0, 5.0)
MWE.initilaize_ode(new_save, ones(Float64, 4))
sol = MWE.solve_ode(new_save, 2*ones(Float64, 4))
plot(sol)
```
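For reference, a minimal sketch of a concretely typed variant of this container, assuming the same DifferentialEquations.jl calls used above (`ODEProblem`, `init`, `reinit!`, `solve!`, `Tsit5`). The names `TypedSaveODE` and `solve_ode!` are hypothetical. The idea is that `OrdinaryDiffEq.ODEIntegrator` has many type parameters, so a field annotated with it is not concretely typed; making the integrator's type a parameter `I` of the struct lets the compiler see its full type.

```julia
using DifferentialEquations

# Same right-hand side as above: dpsi/dt = param[1] * psi
function flow_eqs!(dpsi, psi, param, t)
    dpsi .= param[1] .* psi
    return nothing
end

# Hypothetical container: the type parameter I is the integrator's concrete type
struct TypedSaveODE{I}
    integrator::I
end

# Hypothetical convenience constructor: build the problem and integrator once
function TypedSaveODE(psi0, tspan, param)
    prob = ODEProblem(flow_eqs!, psi0, tspan, param)
    return TypedSaveODE(init(prob, Tsit5(); save_everystep = false))
end

# Reset the stored integrator to a new initial condition and integrate to tspan[2]
function solve_ode!(s::TypedSaveODE, psi0)
    reinit!(s.integrator, psi0)
    solve!(s.integrator)
    return copy(s.integrator.u)   # copy, because the next reinit! overwrites integrator.u
end

s = TypedSaveODE(ones(4), (0.0, 5.0), (-1.0, 0.0))
u_end = solve_ode!(s, 2 .* ones(4))   # approximately 2 * exp(-5) in every component
```

Here `save_everystep = false` is an assumption that only the final state is needed; drop it if the full trajectory is to be plotted.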
When I analyze the code, I get more allocations than I expected, and `@code_warntype` shows red warnings (the `::Any` entries):
```
julia> @time MWE.solve_ode(new_save,ones(Float64,4));
  0.000029 seconds (111 allocations: 12.656 KiB)

julia> @code_warntype MWE.solve_ode(new_save,ones(Float64,4));
Variables
  #self#::Core.Const(Main.MWE.solve_ode)
  save::Main.MWE.SaveODE
  psi0::Vector{Float64}
  sol::SciMLBase.ODESolution

Body::SciMLBase.ODESolution
1 ─ %1 = Base.getproperty(save, :integrator)::OrdinaryDiffEq.ODEIntegrator
│        Main.MWE.reinit!(%1, psi0)
│   %3 = Base.getproperty(save, :integrator)::OrdinaryDiffEq.ODEIntegrator
│   %4 = Main.MWE.solve!(%3)::Any
│   %5 = Base.convert(Main.MWE.ODESolution, %4)::SciMLBase.ODESolution
│        (sol = Core.typeassert(%5, Main.MWE.ODESolution))
│   %7 = Base.getproperty(save, :integrator)::OrdinaryDiffEq.ODEIntegrator
│   %8 = Base.getproperty(%7, :u)::Any
│        Base.setproperty!(save, :result2, %8)
└──      return sol
```
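The `::Any` entries are consistent with reading a field whose declared type is abstract: the compiler only knows the annotation, not the actual type stored in it, so calls on that value are dispatched at run time, which likely contributes to the extra allocations. A minimal base-Julia sketch of the effect, using hypothetical types `LooseHolder` and `TightHolder` unrelated to DifferentialEquations.jl:

```julia
# Field annotated with an abstract type: the stored value's concrete type is unknown to the compiler
struct LooseHolder
    v::AbstractVector
end

# Parametric field: concrete once V is fixed, e.g. V = Vector{Float64}
struct TightHolder{V<:AbstractVector}
    v::V
end

first_elem(h) = h.v[1]

loose = LooseHolder([1.0, 2.0])
tight = TightHolder([1.0, 2.0])

isconcretetype(fieldtype(LooseHolder, :v))     # false
isconcretetype(fieldtype(typeof(tight), :v))   # true

# Inferred return types: Any for the loose version, Float64 for the tight one
Base.return_types(first_elem, (LooseHolder,))
Base.return_types(first_elem, (typeof(tight),))
```

Running `@code_warntype first_elem(loose)` should show the same kind of red `::Any` as in the output above, while `first_elem(tight)` infers `Float64`.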
Are there any suggestions for writing this code more efficiently? Is my approach a good one?
Thanks to all of the great Julia community!
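For comparison with the stored-integrator approach, a minimal sketch of a commonly used alternative: build the `ODEProblem` once and swap only the initial condition with `remake(prob; u0 = ...)` before each `solve`. This assumes `remake` and the `save_everystep` keyword from DifferentialEquations.jl; `initial_conditions` and `finals` are hypothetical names. DifferentialEquations.jl also offers `EnsembleProblem` for this kind of batch.

```julia
using DifferentialEquations

# Same right-hand side as in the question: dpsi/dt = param[1] * psi
function flow_eqs!(dpsi, psi, param, t)
    dpsi .= param[1] .* psi
    return nothing
end

prob = ODEProblem(flow_eqs!, ones(4), (0.0, 5.0), (-1.0, 0.0))

# Hypothetical batch of initial conditions: 1, 2 and 3 times ones(4)
initial_conditions = [k .* ones(4) for k in 1:3]

# remake copies the problem with a new u0; only the final state of each solve is kept
finals = map(initial_conditions) do psi0
    sol = solve(remake(prob; u0 = psi0), Tsit5(); save_everystep = false)
    sol.u[end]   # state at t = 5.0
end
```

Whether this is faster than `reinit!` on a stored integrator depends on the problem size, so it is worth timing both with `@time` or BenchmarkTools.jl.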