Training a UODE over multiple species

I am trying to model the Lotka-Volterra system for multiple species.

The system is defined as:

$$
\begin{cases}
\dfrac{\mathrm{d} x}{\mathrm{d} t} = \alpha x - \beta x y\\
\dfrac{\mathrm{d} y}{\mathrm{d} t} = -\delta y + \gamma x y
\end{cases}
$$

I know the parameters $\alpha$, $\beta$, $\gamma$ and $\delta$ of each predator-prey system.

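I want to model the system as a UODE in which a neural network replaces the interaction term:

$$
\begin{cases}
\dfrac{\mathrm{d} x}{\mathrm{d} t} = \alpha x - \beta\, NN_1(x, y)\\
\dfrac{\mathrm{d} y}{\mathrm{d} t} = -\delta y + \gamma\, NN_2(x, y)
\end{cases}
$$

so that the network outputs $NN_1$ and $NN_2$ capture the $xy$ interaction term.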
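Matching the hybrid model term by term against the true system shows what a perfectly trained network would have to return (assuming $\beta \neq 0$ and $\gamma \neq 0$). Because the known coefficients are factored out, the target is the same function for every predator-prey pair, which is what makes one network shared across all species a reasonable choice:

$$
\begin{aligned}
\alpha x - \beta\, NN_1(x, y) &= \alpha x - \beta x y &&\Rightarrow\quad NN_1(x, y) = x y,\\
-\delta y + \gamma\, NN_2(x, y) &= -\delta y + \gamma x y &&\Rightarrow\quad NN_2(x, y) = x y.
\end{aligned}
$$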

For this toy case, I have generated synthetic data for 3 predator-prey systems and saved each system to its own file.
I want to train a single network on any number of predator-prey systems.

For this I am envisioning:
1. Define the network, the UODE problem, the loss function and the optimization setup.
2. Define an outer loop over epochs.
3. Define an inner loop over the files with species data.
4. Train the network for N epochs with ADAM and afterwards refine it with LBFGS.
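The plan above (outer loop over epochs, inner loop over files) only trains if every step starts from the parameters produced by the previous step. A minimal, dependency-free sketch of that pattern on a toy problem, assuming a hypothetical two-parameter model `w*x*y + b` in place of the network and plain gradient descent in place of Adam; the names `make_dataset`, `model`, `mse`, `gd_step!` and the learning rate are illustrative, not part of the original code:

```julia
# Toy stand-in for "one file of species data": samples of (x, y) and the
# interaction term x*y that the network is supposed to learn.
function make_dataset(scale)
    xs = [scale * i / 10 for i in 1:10]
    ys = [scale * (11 - i) / 10 for i in 1:10]
    return (x = xs, y = ys, target = xs .* ys)
end

# Hypothetical two-parameter model used in place of the neural network
model(θ, x, y) = θ[1] .* x .* y .+ θ[2]

mse(θ, d) = sum(abs2, model(θ, d.x, d.y) .- d.target) / length(d.target)

# One plain gradient-descent step on the MSE of a single dataset (updates θ in place)
function gd_step!(θ, d; lr = 0.05)
    r = model(θ, d.x, d.y) .- d.target      # residuals
    n = length(r)
    g1 = 2 * sum(r .* d.x .* d.y) / n       # ∂mse/∂θ[1]
    g2 = 2 * sum(r) / n                     # ∂mse/∂θ[2]
    θ[1] -= lr * g1
    θ[2] -= lr * g2
    return θ
end

datasets = [make_dataset(s) for s in (1.0, 1.5, 2.0)]   # three "files"
θ = [0.0, 0.5]                                           # initial guess

for epoch in 1:2000          # outer loop: epochs
    for d in datasets        # inner loop: one step per file
        gd_step!(θ, d)       # continues from where the previous step stopped
    end
end

println(θ)
```

In Optimization.jl terms, the same idea means starting each `Optimization.solve` call from the previous result (for example a new `OptimizationProblem` built from `res.u`) rather than from the same initial problem every time.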

However, the code I am using does not produce the expected results, and I would really appreciate some help understanding what is wrong and how to fix it.

The code is listed below:


```julia
# Packages
using Pkg
Pkg.activate(".")

# SciML Tools
using OrdinaryDiffEq, ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity, DataDrivenSparse
using Optimization, OptimizationOptimisers, OptimizationOptimJL

# Standard Libraries
using LinearAlgebra, Statistics

# External Libraries
using ComponentArrays, Lux, Zygote, Plots, StableRNGs

# Set a random seed for reproducible behaviour
rng = StableRNG(1111)
```

Generate the necessary data:

```julia
# True Lotka-Volterra system, p = [α, β, γ, δ]
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
    return nothing
end

# Time grid shared by all systems
t_true = LinRange(0.0, 5.0, 300)
tspan = (0.0, 5.0)

# Wolf-rabbit system
u0_wolf_rabit = 5.0f0 * rand(rng, 2)
p_wolf_rabit = [1.3, 0.9, 0.8, 1.8]
prob_wolf_rabit = ODEProblem(lotka!, u0_wolf_rabit, tspan, p_wolf_rabit)
solution_wolf_rabit = solve(prob_wolf_rabit, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = t_true)

# Cat-mouse system
u0_cat_mouse = 4.0f0 * rand(rng, 2)
p_cat_mouse = [1.5, 1.1, 1.0, 2.0]
prob_cat_mouse = ODEProblem(lotka!, u0_cat_mouse, tspan, p_cat_mouse)
solution_cat_mouse = solve(prob_cat_mouse, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = t_true)

# Human-mammoth system
u0_human_mamuth = 3.0f0 * rand(rng, 2)
p_human_mamuth = [1.1, 0.7, 0.6, 1.6]
prob_human_mamuth = ODEProblem(lotka!, u0_human_mamuth, tspan, p_human_mamuth)
solution_human_mamuth = solve(prob_human_mamuth, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = t_true)

# Plot the three systems
plt = plot(solution_wolf_rabit, alpha = 0.75, color = :black, label = ["wolf-rabbit" nothing])
plot!(plt, solution_cat_mouse, alpha = 0.75, color = :red, label = ["cat-mouse" nothing])
plot!(plt, solution_human_mamuth, alpha = 0.75, color = :blue, label = ["human-mammoth" nothing])
```

Functions to save and read the data:


```julia
using DataFrames
using CSV

# Save one predator-prey system: time grid, known parameters and both populations
function write_data(data, file_name)
    t, p, res = data
    α, β, γ, δ = p
    # res is a 2×N matrix (row 1 = x, row 2 = y); index the rows explicitly,
    # because `x, y = res` would only bind the first two matrix entries
    x = res[1, :]
    y = res[2, :]

    n_points = length(t)
    df = DataFrame(
        "t" => t,
        "alpha" => fill(α, n_points),
        "beta" => fill(β, n_points),
        "gamma" => fill(γ, n_points),
        "delta" => fill(δ, n_points),
        "x" => x,
        "y" => y,
    )
    CSV.write(file_name, df)
end

# Read one system back: time grid, parameter vector [α, β, γ, δ] and an N×2 data matrix
function read_data(file_name)
    data = CSV.File(file_name)
    time = data.t
    p = [data.alpha[1], data.beta[1], data.gamma[1], data.delta[1]]
    res = [data.x data.y]
    return time, p, res
end

write_data([t_true, p_wolf_rabit, Array(solution_wolf_rabit)], "wolf_rabit.csv")
write_data([t_true, p_cat_mouse, Array(solution_cat_mouse)], "cat_mouse.csv")
write_data([t_true, p_human_mamuth, Array(solution_human_mamuth)], "human_mamuth.csv")
```
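A note on splitting the solution matrix inside `write_data`: destructuring a 2×N matrix with `x, y = res` iterates over its elements in column-major order, so it binds `res[1, 1]` and `res[2, 1]` rather than the two rows, and `ones(n_points) .* x` then writes a constant column. A minimal check in plain Julia (the small matrix is made up for illustration):

```julia
# A 2×3 matrix laid out like Array(solution): row 1 = x(t), row 2 = y(t)
res = [1.0 2.0 3.0;
       10.0 20.0 30.0]

# Destructuring iterates over elements in column-major order,
# so this binds res[1, 1] and res[2, 1], not the two rows
x_wrong, y_wrong = res
println((x_wrong, y_wrong))   # → (1.0, 10.0)

# Index the rows explicitly to get the two time series
x = res[1, :]
y = res[2, :]
println(x, " ", y)
```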
```julia
# Define the hybrid model (U, p and st come from the network definition, not shown)
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, st)[1] # Network prediction
    du[1] = p_true[1] * u[1] - p_true[2] * û[1]    # αx - β NN_1
    du[2] = -p_true[4] * u[2] + p_true[3] * û[2]   # -δy + γ NN_2
end

# Closure with the known parameters of the current system
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_true)

function predict(θ, X = X_true[:, 1], T = t_true)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6))
end

function loss(θ)
    X̂ = predict(θ)
    mean(abs2, X_true .- X̂)
end

losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

# Training loop
epochs = 100
files = ["wolf_rabit.csv"]

for epoch in 1:epochs
    for file in files
        # Make the current system visible to nn_dynamics!, predict and loss
        global t_true, p_true, X_true
        t_true, p_true, X_true = read_data(file)
        X_true = Matrix(X_true')   # 2 × N: one column per time point

        tspan = (t_true[1], t_true[end])

        # Define the problem for the current system
        global prob_nn = ODEProblem(nn_dynamics!, X_true[:, 1], tspan, p)

        res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(), callback = callback, maxiters = 1)
    end

    if epoch % 10 == 0
        println("Training loss after $(epoch) epochs: $(losses[end])")
    end
end
```

However, even with a single file I do not get the same behavior as I would without the loop. I am also not sure whether defining `p_true` and `prob_nn` as global variables is a good idea.
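One alternative to the globals, sketched here under assumptions, is to capture each system's known parameters in a closure, so every predator-prey pair gets its own right-hand side and nothing depends on a global being reassigned inside the loop. The example is plain Julia; `make_ude_rhs` and `perfect_nn` are hypothetical names, and `perfect_nn` simply returns the exact interaction term so the hybrid right-hand side can be checked against `lotka!`:

```julia
# True Lotka-Volterra right-hand side (as in the question)
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
    return nothing
end

# Build a hybrid right-hand side for one species system.
# `nn(u, θ)` must return [NN_1(x, y), NN_2(x, y)];
# `p_true` = [α, β, γ, δ] is captured by the closure instead of read from a global.
function make_ude_rhs(nn, p_true)
    α, β, γ, δ = p_true
    return function (du, u, θ, t)
        û = nn(u, θ)
        du[1] = α * u[1] - β * û[1]
        du[2] = -δ * u[2] + γ * û[2]
        return nothing
    end
end

# Stand-in "network" returning the exact interaction term x*y for both outputs
perfect_nn(u, θ) = [u[1] * u[2], u[1] * u[2]]

systems = Dict("wolf_rabit" => [1.3, 0.9, 0.8, 1.8],
               "cat_mouse" => [1.5, 1.1, 1.0, 2.0])
rhs = Dict(name => make_ude_rhs(perfect_nn, p) for (name, p) in systems)

u = [2.0, 3.0]
du_true, du_hybrid = zeros(2), zeros(2)
lotka!(du_true, u, systems["cat_mouse"], 0.0)
rhs["cat_mouse"](du_hybrid, u, nothing, 0.0)
println(du_true, " ", du_hybrid)
```

With a real network, a closure like this could be passed to `ODEProblem` in place of `nn_dynamics!`, with one problem per system.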

Can you give me a couple of hints on what to improve?
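One possible restructuring, sketched under several assumptions: the network architecture, learning rate and iteration counts are placeholders; the helper names (`make_ude_rhs`, `data`, `probs`, `known_params`) are hypothetical; the data are regenerated in memory instead of read with `read_data`; and the callback signature follows recent Optimization.jl releases (older releases pass the parameters as the first argument). It sums the losses of all systems into a single objective, so one `Optimization.solve` call sees every system at every step, and it starts the LBFGS stage from the Adam result:

```julia
using OrdinaryDiffEq, SciMLSensitivity
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays, Lux, Zygote, StableRNGs, Statistics

rng = StableRNG(1111)

# True system, used here only to regenerate three data sets in memory
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
    return nothing
end

t_data = collect(range(0.0, 5.0; length = 300))
known_params = [[1.3, 0.9, 0.8, 1.8], [1.5, 1.1, 1.0, 2.0], [1.1, 0.7, 0.6, 1.6]]
data = map(known_params) do p_true
    u0 = 3.0 .* rand(rng, 2)
    prob = ODEProblem(lotka!, u0, (t_data[1], t_data[end]), p_true)
    X = Array(solve(prob, Vern7(); abstol = 1e-12, reltol = 1e-12, saveat = t_data))
    (p_true = p_true, t = t_data, X = X)
end

# Hypothetical network (architecture is an assumption): (x, y) -> (NN_1, NN_2), shared by all systems
const U = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 16, tanh), Lux.Dense(16, 2))
ps, st = Lux.setup(rng, U)
const _st = st
θ0 = ComponentVector{Float64}(ps)

# One right-hand side per system: α, β, γ, δ are captured, θ are the network weights
function make_ude_rhs(p_true)
    α, β, γ, δ = p_true
    return function (du, u, θ, t)
        û = first(U(u, θ, _st))
        du[1] = α * u[1] - β * û[1]
        du[2] = -δ * u[2] + γ * û[2]
        return nothing
    end
end

probs = [ODEProblem(make_ude_rhs(d.p_true), d.X[:, 1], (d.t[1], d.t[end]), θ0) for d in data]

function predict(θ, prob, T)
    _prob = remake(prob, p = θ)
    Array(solve(_prob, Vern7(); saveat = T, abstol = 1e-6, reltol = 1e-6,
                sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
end

# Sum of per-system errors: every optimizer step sees all systems at once
function loss(θ, _)
    total = 0.0
    for i in eachindex(data)
        total += mean(abs2, data[i].X .- predict(θ, probs[i], data[i].t))
    end
    return total
end

losses = Float64[]
callback = function (state, l)
    push!(losses, l)
    length(losses) % 50 == 0 && println("iteration $(length(losses)): loss = $(l)")
    return false
end

optf = Optimization.OptimizationFunction(loss, Optimization.AutoZygote())

# Stage 1: Adam from the initial weights
optprob = Optimization.OptimizationProblem(optf, θ0)
res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(0.01); callback = callback, maxiters = 500)

# Stage 2: LBFGS, started from the Adam result rather than from θ0
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2, LBFGS(); callback = callback, maxiters = 200)
```

With this layout, adding another predator-prey system only means appending its parameters and data to `data`; the loss, the problem setup and the two optimization stages stay unchanged.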