Solving multiple ODEs concurrently

Hi,

I would like to solve several ODEs concurrently, each with a contribution from a neural network.

I have the following (not working) example.

```julia
using StableRNGs
using ComponentArrays
using Zygote, Lux, OrdinaryDiffEq, SciMLSensitivity
using LinearAlgebra, Statistics

rng = StableRNG(1111);

rbf(x) = exp.(-(x .^ 2))

const U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
              Lux.Dense(5, 2))

p, st = Lux.setup(rng, U)

const _st = st;
ps = ComponentVector{Float64}(p);

size_t = 10

# Initial conditions: one (x, y) pair per ODE
ICs = ComponentArray(x = 10.0*abs.(randn(rng, size_t)), y = 4.0*abs.(randn(rng, size_t)));

# Per-ODE parameters
param = ComponentArray(alpha = abs.(randn(rng, size_t)),
                    beta = abs.(randn(rng, size_t)),
                    gamma = abs.(randn(rng, size_t)),
                    delta = abs.(randn(rng, size_t)));

function ude_dynamics!(du, u, p, t, p_true)
    # Network input: 2×size_t matrix, one column (x_i, y_i) per ODE
    û = U(reshape(u[:], size_t, 2)', p, _st)[1]
    @. du.x = p_true.alpha * u.x + û[1, :]
    @. du.y = -p_true.delta * u.y + û[2, :]
    return nothing
end

nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, param);
tspan = (0.0, 5.0);
ts = LinRange(0.0, 5.0, 200);

prob_nn = ODEProblem(nn_dynamics!, ICs, tspan, ps);

function predict(θ, IC, T)
    # Pass θ to the problem so the gradient can flow to the network weights
    _prob = remake(prob_nn, u0 = IC, p = θ, tspan = (T[1], T[end]))
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
end

function loss(θ, IC, T)
    X̂ = predict(θ, IC, T)
    mean(abs2, 0.999*X̂ .- X̂) # Placeholder loss, just something to make it run
end

l, back = pullback(loss, ps, ICs, ts);
grads = back(one(l))          # Fails here (seed with one(l), not l)
```

When the last line (the call to `back`) executes, I get `ERROR: type Array has no field x`, and the stack trace points to the line `@. du.x = p_true.alpha * u.x + û[1, :]`.

How can I fix this?
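The error most likely arises because, in the backward (adjoint) pass, the sensitivity algorithm calls the right-hand side with plain `Array`s instead of `ComponentArray`s, so `du.x` and `u.x` no longer exist. A minimal sketch of one workaround, under these assumptions: the state is a flat vector (first `n` entries are `x`, last `n` are `y`) accessed by position, the right-hand side is out-of-place, and `ZygoteVJP()` is swapped in for `ReverseDiffVJP(true)`. `ComponentArray` is kept only for the network weights. It also passes `θ` to `remake` (the original `predict` never does, so no gradient could reach `θ`) and seeds the pullback with `one(l)`. The names `n`, `alpha`, `delta` and `nn_dynamics` are illustrative.

```julia
using StableRNGs, ComponentArrays, Lux, Zygote, OrdinaryDiffEq, SciMLSensitivity, Statistics

rng = StableRNG(1111)
const n = 10                                    # number of (x, y) pairs (`size_t` in the post)

rbf(x) = exp.(-(x .^ 2))
const U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
                    Lux.Dense(5, 2))
p, st = Lux.setup(rng, U)
const _st = st
ps = ComponentVector{Float64}(p)                # ComponentArray only for the NN weights

# Flat state: u[1:n] holds x, u[n+1:2n] holds y
u0 = vcat(10.0 .* abs.(randn(rng, n)), 4.0 .* abs.(randn(rng, n)))
const alpha = abs.(randn(rng, n))
const delta = abs.(randn(rng, n))

# Out-of-place RHS: only reshape/indexing on `u`, no `u.x` field access
function nn_dynamics(u, p, t)
    X = reshape(u, n, 2)                        # column 1 = x, column 2 = y
    û = U(permutedims(X), p, _st)[1]            # 2×n output, one column per ODE pair
    dx = alpha .* X[:, 1] .+ û[1, :]
    dy = .-delta .* X[:, 2] .+ û[2, :]
    return vcat(dx, dy)
end

tspan = (0.0, 5.0)
ts = LinRange(0.0, 5.0, 200)
prob_nn = ODEProblem(nn_dynamics, u0, tspan, ps)

function predict(θ, IC, T)
    # θ must be passed here, otherwise the solve does not depend on it
    _prob = remake(prob_nn; u0 = IC, p = θ, tspan = (T[1], T[end]))
    Array(solve(_prob, Tsit5(); saveat = T, abstol = 1e-6, reltol = 1e-6,
                sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())))
end

loss(θ, IC, T) = mean(abs2, predict(θ, IC, T))  # placeholder loss

l, back = Zygote.pullback(θ -> loss(θ, u0, ts), ps)
gθ = back(one(l))[1]                            # gradient with respect to the NN weights
```

The out-of-place form is chosen because Zygote cannot differentiate code that mutates `du`; with `ReverseDiffVJP` an in-place version that indexes `du[i]` and `u[i]` directly should also be possible.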

I haven’t looked at your code in detail, and in principle it can be made to work, but note that the documented way to do this is the ensemble interface.

https://docs.sciml.ai/SciMLSensitivity/stable/tutorials/data_parallel/

Have you given that a try yet?
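To illustrate the suggestion, here is a minimal sketch of the original neural-network model written with the ensemble interface. It loosely follows the pattern of the linked tutorial: build the problem inside the function that receives the parameters, then differentiate through `Array(solve(...))`. Assumptions not taken from the thread: each trajectory is one (x, y) pair; its `alpha[i]` and `delta[i]` are carried as two extra state entries with zero derivative (a workaround chosen so that every trajectory shares the same parameter vector `θ`); and `ude` and `ensemble_predict` are hypothetical names.

```julia
using StableRNGs, ComponentArrays, Lux, Zygote, OrdinaryDiffEq, SciMLSensitivity, Statistics

rng = StableRNG(1111)
const n = 10                                  # number of trajectories

rbf(x) = exp.(-(x .^ 2))
const U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
                    Lux.Dense(5, 2))
p, st = Lux.setup(rng, U)
const _st = st
ps = ComponentVector{Float64}(p)

const ICs = hcat(10.0 .* abs.(randn(rng, n)), 4.0 .* abs.(randn(rng, n)))  # n×2
const alpha = abs.(randn(rng, n))
const delta = abs.(randn(rng, n))
const ts = LinRange(0.0, 5.0, 200)

# State u = [x, y, alpha_i, delta_i]; the last two entries stay constant
function ude(u, θ, t)
    û = U(u[1:2], θ, _st)[1]                  # network sees only (x, y)
    return [u[3] * u[1] + û[1], -u[4] * u[2] + û[2], zero(u[3]), zero(u[4])]
end

function ensemble_predict(θ)
    prob = ODEProblem{false}(ude, zeros(4), (ts[1], ts[end]), θ)
    # Trajectory i gets its own initial condition and constants
    prob_func(prob, i, repeat) = remake(prob; u0 = [ICs[i, 1], ICs[i, 2], alpha[i], delta[i]])
    sim = solve(EnsembleProblem(prob; prob_func = prob_func), Tsit5(), EnsembleThreads();
                trajectories = n, saveat = ts,
                sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))
    return Array(sim)[1:2, :, :]              # (x, y) × time points × trajectories
end

loss(θ) = mean(abs2, ensemble_predict(θ))     # placeholder loss
l, back = Zygote.pullback(loss, ps)
gθ = back(one(l))[1]
```

Carrying the per-trajectory constants in the state keeps a single parameter vector to differentiate; passing them through `p` instead would require merging them with the network weights.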

Hi,
Thanks for the suggestion! I was not familiar with the ensemble interface. I will give it a go and post the code here (hopefully shortly).

I would also like to make this approach work. Any suggestions on how to solve `ERROR: type Array has no field x`? It appears the `.x` field access works in the forward pass but is lost in the backward pass. Am I correct about this?

Hi,
The ensemble interface has been a challenging journey for someone new to Julia.

Here is some code that solves the ODEs with the ensemble interface, without the neural network component (just the ODEs).

```julia
using StableRNGs
using ComponentArrays
using Zygote, Lux, OrdinaryDiffEq, SciMLSensitivity
using LinearAlgebra, Statistics

rng = StableRNG(1111);

size_t = 10

# Initial conditions: row i holds (x_i, y_i)
ICs = [10.0*abs.(randn(rng, size_t))  4.0*abs.(randn(rng, size_t))];

# Parameters: row i holds the four parameters of ODE i
param = [abs.(randn(rng, size_t)) abs.(randn(rng, size_t)) abs.(randn(rng, size_t)) abs.(randn(rng, size_t))];

function ude_dynamics!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[2] * u[1]
    du[2] = p[3] * u[1] * u[2] - p[4] * u[2]
    return nothing
end

# Builds the problem for trajectory i from row i of ICs and param
function prob_func(prob, i, repeat)
    remake(prob, u0 = ICs[i, :], p = param[i, :])
end

tspan = (0.0, 5.0);
ts = LinRange(0.0, 5.0, 200);

# u0 and p are placeholders; prob_func replaces them for every trajectory
prob = ODEProblem(ude_dynamics!, zeros(2), tspan, zeros(4));

ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)

sim = solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = size_t, saveat = ts);
```

But I do not know how to retrieve the results or how to compute the gradient through the `solve` call. I tried `l, back = pullback(solve(ensemble_prob, Tsit5(), EnsembleThreads(), trajectories = 10))`, which does not work. Most likely I am missing something.
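On retrieving results: `sim.u[i]` is the `ODESolution` of trajectory `i`, and `Array(sim)` stacks all trajectories into a states × time points × trajectories array when they share the same save points. On the gradient: `pullback` needs a function plus the inputs to differentiate, whereas `pullback(solve(...))` passes an already computed solution; and because `prob_func` reads the global `param`, nothing would depend on a differentiable input anyway. A minimal sketch, assuming the goal is the gradient of a placeholder loss with respect to the parameter matrix; `lv!` and `ensemble_predict` are hypothetical names, and `InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))` is one sensitivity choice among several.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote, Statistics, StableRNGs

rng = StableRNG(1111)
const n = 10
const ICs = hcat(10.0 .* abs.(randn(rng, n)), 4.0 .* abs.(randn(rng, n)))  # n×2
param = abs.(randn(rng, n, 4))                                              # n×4
const ts = LinRange(0.0, 5.0, 200)

function lv!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[2] * u[1]
    du[2] = p[3] * u[1] * u[2] - p[4] * u[2]
    return nothing
end

# The parameters enter as the argument P, so Zygote can trace them
function ensemble_predict(P)
    prob = ODEProblem(lv!, zeros(2), (ts[1], ts[end]), zeros(4))
    prob_func(prob, i, repeat) = remake(prob; u0 = ICs[i, :], p = P[i, :])
    sim = solve(EnsembleProblem(prob; prob_func = prob_func), Tsit5(), EnsembleThreads();
                trajectories = n, saveat = ts,
                sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))
    return Array(sim)                        # states × time points × trajectories
end

X = ensemble_predict(param)
first_traj = X[:, :, 1]                      # 2 × 200 matrix for trajectory 1

loss(P) = mean(abs2, ensemble_predict(P))    # placeholder loss
gP = Zygote.gradient(loss, param)[1]         # same shape as param
```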

Does the ensemble interface solve each ODE in its own thread? And does `prob_func(prob, i, repeat)` control how each trajectory is set up?
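As far as I understand the ensemble interface, `prob_func(prob, i, repeat)` is called once per trajectory with its index `i` and returns the problem that trajectory solves. `repeat` is 1 unless an `output_func` asks for a rerun. `EnsembleThreads()` spreads the trajectories over the threads Julia was started with (e.g. `julia --threads 4`), so with fewer threads than trajectories each thread solves several ODEs. The sketch below records, in a hypothetical `solved_on` array, the thread on which `prob_func` ran for each trajectory. That this is also the thread that then solves it is my assumption about the implementation. It then checks that `sim.u[i]` starts from row `i` of `ICs`.

```julia
using OrdinaryDiffEq, StableRNGs

rng = StableRNG(1111)
const n = 10
ICs = hcat(10.0 .* abs.(randn(rng, n)), 4.0 .* abs.(randn(rng, n)))  # n×2
param = abs.(randn(rng, n, 4))                                        # n×4

function lv!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[2] * u[1]
    du[2] = p[3] * u[1] * u[2] - p[4] * u[2]
    return nothing
end

solved_on = zeros(Int, n)                 # hypothetical bookkeeping: thread id per trajectory

function prob_func(prob, i, repeat)
    solved_on[i] = Threads.threadid()     # each trajectory writes only its own slot
    remake(prob; u0 = ICs[i, :], p = param[i, :])
end

prob = ODEProblem(lv!, zeros(2), (0.0, 5.0), zeros(4))
sim = solve(EnsembleProblem(prob; prob_func = prob_func), Tsit5(), EnsembleThreads();
            trajectories = n)

println("Julia threads: ", Threads.nthreads())
println("thread per trajectory: ", solved_on)
```

With a single thread all entries are 1; the ordering of `sim.u` follows the trajectory index regardless of which thread did the work.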

What am I missing in the original post to make the ComponentArrays approach work?