Nicer way to plug an R function into the RHS of an ODEProblem?

Hi All,

I’m trying to plug an R function into an ODEProblem in OrdinaryDiffEq. The code below works, but is there a simple way to make it neater, e.g. by defining the ODEProblem directly in terms of an RObject?

using OrdinaryDiffEq
using RCall
using Plots

R"""
sir_ode_r <- function(u,p,t){
    S <- u[1]
    I <- u[2]
    R <- u[3]
    N <- S+I+R
    beta <- p[1]
    cee <- p[2]
    gamma <- p[3]
    dS <- -beta*cee*I/N*S
    dI <- beta*cee*I/N*S - gamma*I
    dR <- gamma*I
    return(c(dS,dI,dR))
}
"""

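# Wrap the R function so it has the out-of-place signature f(u, p, t) that ODEProblem expects
function sir_ode_jl(u,p,t)
    robj = rcall(:sir_ode_r, u, p, t)   # call the R function by name; returns an RObject
    return convert(Array,robj)          # copy the R numeric vector back into a Julia array
end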

δt = 0.1
tmax = 40.0
tspan = (0.0,tmax)
u0 = [990.0,10.0,0.0] # S,I,R
p = [0.05,10.0,0.25] # β,c,γ
prob = ODEProblem(sir_ode_jl, u0, tspan, p)
sol = solve(prob, Tsit5(), dt = δt)  # Tsit5 is adaptive: dt only sets the initial step; saveat = δt would fix the output spacing
plot(sol)

For this kind of function, you can use ModelingToolkit to trace it into a symbolic form and generate a native Julia function from it. That would also fix the performance, since the solver would no longer cross into R on every right-hand-side evaluation. That’s how the diffeqr/diffeqpy JIT compiler works.

Otherwise, I think this is how I would expect it to look.
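For comparison, here is a hand-written native Julia version of the same right-hand side, in both the out-of-place form (what `ODEProblem{false}` uses) and the in-place form. This is only a sketch of the kind of function that ends up being called once the model is native Julia; the code ModelingToolkit actually generates differs in detail, and the names `sir_ode` and `sir_ode!` are hypothetical.

```julia
# Hypothetical native Julia equivalent of sir_ode_r (illustration only).

# Out-of-place form: returns a new vector on every call.
function sir_ode(u, p, t)
    S, I, R = u
    β, c, γ = p
    N = S + I + R
    infection = β * c * I / N * S
    return [-infection, infection - γ * I, γ * I]
end

# In-place form: writes into a preallocated du, avoiding a new
# array allocation at every right-hand-side evaluation.
function sir_ode!(du, u, p, t)
    S, I, R = u
    β, c, γ = p
    N = S + I + R
    infection = β * c * I / N * S
    du[1] = -infection
    du[2] = infection - γ * I
    du[3] = γ * I
    return nothing
end

u0 = [990.0, 10.0, 0.0]   # S, I, R
p  = [0.05, 10.0, 0.25]   # β, c, γ

du_oop = sir_ode(u0, p, 0.0)
du_ip = similar(u0)
sir_ode!(du_ip, u0, p, 0.0)
println(du_oop)   # ≈ [-4.95, 2.45, 2.5]
```

With OrdinaryDiffEq loaded, either function should drop straight into `ODEProblem(sir_ode!, u0, tspan, p)`; the in-place form is detected automatically from the four-argument signature.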

Cool, can you show a simple example (modelingtoolkitize?)

Would this also work with my PythonCall example?

using OrdinaryDiffEq
using PythonCall
using Plots

@pyexec """
def sir_ode_py(u,p,t):
    S = u[0]
    I = u[1]
    R = u[2]
    N = S+I+R
    beta = p[0]
    c = p[1]
    gamma = p[2]
    dS = -beta*c*I/N*S
    dI = beta*c*I/N*S - gamma*I
    dR = gamma*I
    return [dS,dI,dR]
""" => sir_ode_py

sir_ode_jl(u,p,t) = pyconvert(Array{Float64}, sir_ode_py(u, p, t))

δt = 0.1
tmax = 40.0
tspan = (0.0,tmax)
u0 = [990.0,10.0,0.0] # S,I,R
p = [0.05,10.0,0.25] # β,c,γ
prob = ODEProblem{false}(sir_ode_jl, u0, tspan, p)
sol = solve(prob, Tsit5(), dt = δt)
plot(sol)

It should. Try just calling modelingtoolkitize(prob). If you convert into an Array instead of forcing Array{Float64}, you should be good: during tracing the values are symbolic, so they can’t be converted to Float64.

For Python, this code works (after converting to Array rather than Array{Float64}):

using ModelingToolkit
@named sys = modelingtoolkitize(prob)
prob_mtk = ODEProblem(sys, u0, tspan, p)
sol_mtk = solve(prob_mtk, Tsit5(), dt = δt)
plot(sol_mtk)

However, modelingtoolkitize breaks for the R version: ERROR: MethodError: no method matching sexpclass(::Num)

(yes, I know I’m not using MTK9 🙂)
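As far as I understand it, `modelingtoolkitize` calls the right-hand-side function with symbolic values instead of numbers and records the arithmetic. Through PythonCall, the Python function only applies `+`, `-`, `*` and `/` to wrapped Julia objects, which dispatch back into Julia, so the trace succeeds. `rcall`, on the other hand, must first convert every argument into an R object, and RCall has no conversion for `Num`, which would explain the `sexpclass(::Num)` error. The toy tracer below illustrates the idea with a hypothetical `Sym` type that records operations as strings; it is not how Symbolics.jl is implemented.

```julia
# Toy tracer (hypothetical, for illustration only): each operation on `Sym`
# builds an expression string instead of computing a number.
struct Sym
    ex::String
end

str(x::Sym) = x.ex

binop(op::String, a::Sym, b::Sym) = Sym(string("(", str(a), " ", op, " ", str(b), ")"))

# Overload the four binary arithmetic operators for Sym arguments.
for (f, s) in ((:+, "+"), (:-, "-"), (:*, "*"), (:/, "/"))
    @eval Base.$f(a::Sym, b::Sym) = binop($s, a, b)
end
# Unary minus, needed for the -β * ... term.
Base.:-(a::Sym) = Sym(string("(-", str(a), ")"))

# A generic RHS: nothing in it assumes the inputs are Float64.
function sir_ode(u, p, t)
    S, I, R = u
    β, c, γ = p
    N = S + I + R
    return [-β * c * I / N * S, β * c * I / N * S - γ * I, γ * I]
end

u_sym = [Sym("S"), Sym("I"), Sym("R")]
p_sym = [Sym("β"), Sym("c"), Sym("γ")]
du = sir_ode(u_sym, p_sym, Sym("t"))
foreach(d -> println(str(d)), du)   # the last line printed is (γ * I)
```

Because a Julia (or Python) function only needs the operators, the symbolic values flow through it untouched. An R function never sees them: the arguments are converted to R vectors before it runs, so there is nothing for a tracer to record.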

PS. Is there any way to pass the names of states and parameters when using modelingtoolkitize so that the resulting equations are easier to read?

No, but it wouldn’t take more than 15 minutes to add. Please open an issue so I don’t forget; I’m at a conference right now, but that way it will end up in my email.