Learning Epidemic Model With Neural Differential Equations

I’m trying to use neural networks with DiffEqFlux to learn an SIR model with vaccination dynamics. The equations for this model are

function SIRx!(du, u, p, t)
    β, μ, γ, a, b = p   # p = Float32[280, 1/50, 365/22, 100, 0.05]
    S, I, x = u
    du[1] = μ*(100 - x) - (β/100)*S*I - μ*S   # dS/dt
    du[2] = (β/100)*S*I - (μ + γ)*I           # dI/dt
    du[3] = a*I - b*x                         # dx/dt (vaccination dynamics)
    return nothing
end

The time span I’m using to train is from 0 to 10 with initial condition u0 = 100*Float32[0.062047128, 1.3126149f-7, 0.9486445]. I’m saving the data every 0.1 for a total of 101 points.
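To sanity-check this setup without DiffEqFlux, here is a plain-Julia sketch. It rests on a few assumptions of mine: a hand-written fixed-step RK4 integrator instead of Rosenbrock23, Float64 instead of Float32, and hypothetical helper names `sirx_rhs!` and `rk4_saveat`. The sketch does three things:

- evaluates the right-hand side at u0, where the components of du/dt differ by several orders of magnitude;
- checks the sign of dI/dt against the threshold S > 100(μ + γ)/β;
- regenerates a 101-point training set on [0, 10].

```julia
# Hypothetical sketch: the true SIRx system in plain Julia.
function sirx_rhs!(du, u, p, t)
    β, μ, γ, a, b = p
    S, I, x = u
    du[1] = μ * (100 - x) - (β / 100) * S * I - μ * S
    du[2] = (β / 100) * S * I - (μ + γ) * I
    du[3] = a * I - b * x
    return nothing
end

# Classic fixed-step RK4 (not the solver used in the thread); keeps u0 plus
# every `save_every`-th state, returned as a states × time matrix.
function rk4_saveat(f!, u0, p, tspan, dt, save_every)
    u = copy(u0)
    k1 = similar(u); k2 = similar(u); k3 = similar(u); k4 = similar(u); tmp = similar(u)
    nsteps = round(Int, (tspan[2] - tspan[1]) / dt)
    saved = [copy(u)]
    for n in 0:nsteps-1
        t = tspan[1] + n * dt
        f!(k1, u, p, t)
        @. tmp = u + dt / 2 * k1
        f!(k2, tmp, p, t + dt / 2)
        @. tmp = u + dt / 2 * k2
        f!(k3, tmp, p, t + dt / 2)
        @. tmp = u + dt * k3
        f!(k4, tmp, p, t + dt)
        @. u += dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        (n + 1) % save_every == 0 && push!(saved, copy(u))
    end
    return reduce(hcat, saved)
end

p  = [280, 1/50, 365/22, 100, 0.05]
u0 = 100 .* [0.062047128, 1.3126149e-7, 0.9486445]

# dI/dt = I * ((β/100)*S - (μ + γ)), so I grows exactly when S exceeds this threshold.
du0 = similar(u0)
sirx_rhs!(du0, u0, p, 0.0)
S_threshold = 100 * (p[2] + p[3]) / p[1]
println("du/dt at u0 = ", du0, "; S = ", u0[1], " vs threshold ", S_threshold)

# dt = 1e-3 on [0, 10], saving every 100 steps, i.e. every 0.1 time units.
data = rk4_saveat(sirx_rhs!, u0, p, (0.0, 10.0), 1e-3, 100)
println(size(data))   # (3, 101)
```

The spread in derivative magnitudes at u0 (dI/dt is tiny compared with dx/dt) is one reason the per-state scaling described below matters.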

I want to use neural networks to approximate these equations while using as little of the true model information as possible. I’ve been having difficulty with the training process getting stuck in a local minimum: specifically, the neural network tends to fit a straight line through the middle of the data and stay there.

I’ve tried the main strategies for escaping local minima listed in the DiffEqFlux documentation (multiple shooting, smoothed collocation, and iteratively growing the fit), but without much success. In every case, the network fails to produce anything more complex than a straight line (or a series of straight lines, in the case of multiple shooting).
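Smoothed collocation is one of the strategies mentioned above. Its core idea can be sketched in plain Julia under a few assumptions of mine: an Epanechnikov kernel, a hand-picked bandwidth `h`, and a hypothetical helper name.

1. Estimate the state and its time derivative at every sample by kernel-weighted local linear regression.
2. Fit the network to the resulting (u, du/dt) pairs without solving the ODE.

This is not DiffEqFlux's implementation.

```julia
# Epanechnikov kernel: zero outside |z| <= 1.
epanechnikov(z) = abs(z) <= 1 ? 0.75 * (1 - z^2) : 0.0

# Hypothetical helper: for each sample time t[k], fit y ≈ a + b*(t - t[k]) by
# kernel-weighted least squares; a estimates u(t[k]) and b estimates du/dt(t[k]).
function local_linear_estimate(data::AbstractMatrix, t::AbstractVector, h)
    n_states, n = size(data)
    u_hat  = similar(data, Float64)
    du_hat = similar(data, Float64)
    for k in 1:n
        d = t .- t[k]
        w = epanechnikov.(d ./ h)
        # Normal equations [S0 S1; S1 S2] * [a; b] = [T0; T1]
        S0, S1, S2 = sum(w), sum(w .* d), sum(w .* d .^ 2)
        D = S0 * S2 - S1^2
        for i in 1:n_states
            y = view(data, i, :)
            T0, T1 = sum(w .* y), sum(w .* d .* y)
            u_hat[i, k]  = (S2 * T0 - S1 * T1) / D
            du_hat[i, k] = (S0 * T1 - S1 * T0) / D
        end
    end
    return u_hat, du_hat
end

# Example on y = t^2: the interior slope estimate should be close to 2t.
t = collect(0.0:0.1:2.0)
u_hat, du_hat = local_linear_estimate(reshape(t .^ 2, 1, :), t, 0.35)
println(du_hat[1, 11])   # estimate of dy/dt at t = 1.0
```

The pairs `(u_hat[:, k], du_hat[:, k])` can then serve as regression targets for the network right-hand side. No ODE solve, and no gradient through a long trajectory, is needed in that first phase. The bandwidth must cover at least two samples so that `D > 0`.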

The training framework I’m using is based on the one in the paper “Stiff Neural Ordinary Differential Equations” for learning the ROBER equations, since those equations are roughly as complex as mine, if not more so.

Here is a summary of the things I’ve tried.

Network structure:

  • 6 hidden layers of 5, 6, or 7 neurons each

  • 1 hidden layer of 30, 40, or 50 neurons

Activation functions (the output layer is always linear, and the outputs are scaled by (ymax - ymin)/(tf - t0)):

  • tanh

  • softplus

  • gelu

Optimizer (I tested 500 iterations in each case):

  • ADAM with learning rate 0.005 or 0.05 (anything larger finds divergent solutions and crashes the program)

Differential equation solver:

  • AutoTsit5(Rosenbrock23())

Loss function: mean absolute error or mean squared error, with both the network prediction and the true data scaled by the range of the data values
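The two scalings in this summary can be sketched in plain Julia: the output scale (ymax - ymin)/(tf - t0) and the range-scaled losses. The helper names are hypothetical and the numbers are made up. The point of dividing by each state's range is that states with very different magnitudes (here state 2 has a 10x larger range than state 1) contribute comparably to the loss.

```julia
# Hypothetical helper: per-state output scale (ymax - ymin)/(tf - t0) from a
# states × time data matrix; vec(...) returns a Vector rather than a 3×1 matrix.
function derivative_scale(data::AbstractMatrix, times::AbstractVector)
    yrange = vec(maximum(data, dims=2) .- minimum(data, dims=2))
    return yrange ./ (times[end] - times[1])
end

# Range-scaled losses: errors divided by each state's range, summed over
# states, and averaged over time points (same form as the posted `loss`).
function scaled_mae(data, pred)
    yscale = maximum(data, dims=2) .- minimum(data, dims=2)   # states × 1
    return sum(abs, (data .- pred) ./ yscale) / size(data, 2)
end

function scaled_mse(data, pred)
    yscale = maximum(data, dims=2) .- minimum(data, dims=2)
    return sum(abs2, (data .- pred) ./ yscale) / size(data, 2)
end

data = [0.0 1.0; 0.0 10.0]
pred = [0.5 1.0; 5.0 10.0]
println(derivative_scale(data, [0.0, 2.0]))   # [0.5, 5.0]
println(scaled_mae(data, pred))               # 0.5
println(scaled_mse(data, pred))               # 0.25
```

Both states are off by exactly half their range at the first time point, so they contribute equally even though their absolute errors differ by a factor of 10.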

The larger learning rates tend to find unstable solutions and crash. The runs that don’t crash end up fitting a straight line and then oscillating its slope up and down slightly until training ends. Here is the best result I have so far (7 hidden layers, gelu activation, ADAM(0.005)):
[Figure: training test 2]

I know the general approach works because I can use the same algorithm to fit a Lotka-Volterra model fairly well. From what I can tell, something about the SIRx model (number of dimensions, stiffness, qualitative behaviour, etc.) must be preventing it from working here. It could also be that I just need more patience, or a faster CPU to train longer. However, with so many variables and parameters to tweak, I don’t want to spend hours blindly trying new combinations. I would therefore appreciate any advice on how best to diagnose the problem and which strategies are worth investing in.

Thank you for reading, and thank you in advance for any insight.

Can you show what code you tried? As written this isn’t reproducible.

The documentation suggests that you shouldn’t use ADAM alone in these cases. Instead, run ADAM with its default settings for about 300 iterations and then switch to BFGS. Without the BFGS stage you won’t get good convergence, and I suspect that could be your issue here. You may also need to tighten the solver tolerances (abstol/reltol) to get more accurate gradients.
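The two-stage schedule can be made concrete with a toy sketch in plain Julia. The functions `adam` and `bfgs` here are hand-written hypothetical helpers run on the Rosenbrock function, not the DiffEqFlux or Optim.jl implementations. The first-order phase moves the parameters into a reasonable region, and the quasi-Newton phase then uses curvature information to converge more tightly. In the thread's setting the same pattern is two consecutive `sciml_train` calls, one with ADAM and one with BFGS.

```julia
using LinearAlgebra

# Toy objective standing in for the training loss.
rosen(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
rosen_grad(x) = [-2 * (1 - x[1]) - 400 * x[1] * (x[2] - x[1]^2),
                 200 * (x[2] - x[1]^2)]

# Adam with bias-corrected first and second moment estimates.
function adam(g, x0; η=0.005, β1=0.9, β2=0.999, ϵ=1e-8, iters=300)
    x = copy(x0); m = zero(x); v = zero(x)
    for t in 1:iters
        gt = g(x)
        @. m = β1 * m + (1 - β1) * gt
        @. v = β2 * v + (1 - β2) * gt^2
        @. x -= η * (m / (1 - β1^t)) / (sqrt(v / (1 - β2^t)) + ϵ)
    end
    return x
end

# Minimal BFGS: inverse-Hessian update with Armijo backtracking line search.
function bfgs(f, g, x0; iters=200, gtol=1e-10)
    x = copy(x0)
    H = Matrix{Float64}(I, length(x), length(x))   # inverse Hessian estimate
    gx = g(x)
    for _ in 1:iters
        norm(gx) < gtol && break
        d = -H * gx                                 # descent direction (H stays positive definite)
        fx, slope = f(x), dot(gx, d)
        α = 1.0
        while f(x + α * d) > fx + 1e-4 * α * slope  # Armijo sufficient-decrease test
            α /= 2
            α < 1e-12 && return x                   # line search failed: stop here
        end
        s = α * d
        xnew = x + s
        gnew = g(xnew)
        y = gnew - gx
        sy = dot(s, y)
        if sy > 1e-12                               # skip update if curvature condition fails
            ρ = 1 / sy
            H = (I - ρ * s * y') * H * (I - ρ * y * s') + ρ * s * s'
        end
        x, gx = xnew, gnew
    end
    return x
end

x_adam = adam(rosen_grad, [-1.2, 1.0])              # stage 1: 300 Adam steps
x_bfgs = bfgs(rosen, rosen_grad, x_adam)            # stage 2: BFGS from Adam's endpoint
println("after Adam: f = ", rosen(x_adam), "; after BFGS: f = ", rosen(x_bfgs))
```

Every accepted BFGS step satisfies the Armijo condition, so the second stage can only lower the objective relative to where Adam stopped.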

It sounds like this new paper could help: [2006.01681] Neural Power Units. The authors demonstrate solving an SIR model with a new type of neural unit, built on Flux.jl and DifferentialEquations.jl, and their source code is available.
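The SIRx right-hand side contains the bilinear term (β/100)·S·I. As a rough illustration of why power-style units may suit this kind of model, consider a unit that computes products of powers of its inputs: it can represent S·I exactly with a single weight row. That reasoning is my own reading, not a claim from the paper. The sketch below is a heavily simplified, hypothetical `power_unit` that assumes strictly positive inputs. The actual layer in the paper is more involved (it also has to cope with zero and negative inputs), so see the authors' code for the real thing.

```julia
# Simplified power unit for strictly positive inputs:
#   y_j = exp(Σ_i W[j, i] * log(x_i)) = Π_i x_i ^ W[j, i]
power_unit(W, x) = exp.(W * log.(x))

W = [1.0 1.0 0.0]        # one output with exponents (1, 1, 0) on (S, I, x)
u = [6.2, 0.5, 94.9]
y = power_unit(W, u)     # ≈ [3.1], i.e. S * I
println(y)
```

With ordinary dense layers and tanh/gelu activations, the network has to approximate this product; with learnable exponents it is a single weight configuration.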


After some tweaking, I can get the following code to do well over a time span from 0 up to about 12. After that, I get the local minimum problem. The code for the true differential equation system is as written above.

using DiffEqFlux, OrdinaryDiffEq, Flux, Optim
using DifferentialEquations
using LinearAlgebra
using DiffEqSensitivity
using GalacticOptim
using Plots
using JLD2, FileIO

# Generate the data
u0 = 100*Float32[0.062047128, 1.3126149f-7, 0.9486445];
tspan = Float32[0, 10];
tsteps = range(tspan[1], tspan[2], step=0.02)
p = Float32[280, 1/50, 365/22, 100, 0.05];
prob = ODEProblem(SIRx!, u0, tspan, p);
sol = solve(prob, Rosenbrock23(), saveat = tsteps);
times = sol.t;
data = Array(sol);

yscale = vec(maximum(data, dims=2) .- minimum(data, dims=2));  # per-state data range
tscale = times[end] - times[1];
scale = yscale ./ tscale;  # output scaling (ymax - ymin)/(tf - t0), one entry per state

ann = FastChain(FastDense(3, 5, gelu),
	FastDense(5, 5, gelu),
	FastDense(5, 5, gelu),
	FastDense(5, 5, gelu),
	FastDense(5, 5, gelu),
	FastDense(5, 5, gelu),
	FastDense(5, 3));

# Neural ODE right-hand side: network output rescaled per state
function nde(du, u, p, t)
	û = ann(u, p) .* scale
	du[1] = û[1]
	du[2] = û[2]
	du[3] = û[3]
	return nothing
end

losses = []
function train(p, data, times, opt, maxiters)
	# Per-state data range used to normalise the loss
	yscale = maximum(data, dims=2) .- minimum(data, dims=2)
	prob_nn = ODEProblem(nde, u0, tspan, p)

	function predict(θ)
		Array(solve(prob_nn, Rosenbrock23(), p=θ, saveat=times,
			sensealg=ForwardDiffSensitivity()))
	end

	# Range-scaled absolute error, summed over states, averaged over time points
	function loss(θ)
		pred = predict(θ)
		return sum(abs, (data .- pred) ./ yscale) / size(data, 2)
	end

	function callback(θ, l)
		push!(losses, l)
		if length(losses) % 50 == 0
			println("Loss after $(length(losses)) iterations: $(losses[end])")
		end
		return false  # returning true would stop training early
	end

	res = DiffEqFlux.sciml_train(loss, p, opt, cb=callback, maxiters=maxiters,
		allow_f_increases=true)
	return res
end

p0 = initial_params(ann)
res1 = train(p0, data, times, ADAM(0.005), 10000)
res2 = train(res1.minimizer, data, times, BFGS(initial_stepnorm=0.001), 7500)

300 iterations only reduces the loss by about 0.1. It takes about 5000-10000 iterations to get to a small enough loss for BFGS to be helpful. Increasing the learning rate actually makes the loss increase over time.

I’m not sure what the default tolerances are or how low I should set them, but I tried the same code using Rodas5() as the solver with abstol = 1e-7 and reltol = 1e-7. This didn’t make an appreciable difference.

At this point, the main thing I’d like to know is how to make this work for longer time spans: I would like to train on times from 0 to at least 30. As mentioned, I’ve had difficulty implementing multiple shooting. I’m also curious why learning with ADAM is so much slower for me than the recommendation of about 300 iterations would suggest.
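For the 0 to 30 goal, the "iteratively growing the fit" strategy trains on a short window first, then extends it, warm-starting each stage from the previous parameters. The sketch below shows only the bookkeeping: which samples belong to each stage. `growing_windows` and the stage end times are hypothetical choices. The posted `train` function would also need to take the time span as an argument instead of using the global `tspan`.

```julia
# Hypothetical helper: indices of the samples inside each successively longer
# training window [times[1], t_end], for a growing-fit schedule.
function growing_windows(times::AbstractVector, t_ends::AbstractVector)
    return [findall(t -> t <= t_end, times) for t_end in t_ends]
end

times  = collect(0.0:0.1:30.0)
stages = growing_windows(times, [10.0, 15.0, 20.0, 25.0, 30.0])

for idx in stages
    window = (times[first(idx)], times[last(idx)])
    println("stage: t in ", window, ", ", length(idx), " samples")
    # Real training (not shown): fit data[:, idx] over `window`, starting
    # from the parameters found in the previous stage.
end
```

Each stage only asks the optimizer to explain a little more of the trajectory than it already fits, which is the intended guard against the straight-line minimum.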

For reference, here’s the training function I’m using for multiple shooting. All the rest of the code is the same.

function train_multiple_shoot(p, data, times, opt, maxiters,
			groupsize, continuityterm)
	# Per-state data range used to normalise both loss terms
	yscale = maximum(data, dims=2) .- minimum(data, dims=2)
	prob = ODEProblem(nde, data[:,1], (times[1], times[end]), p)

	# Range-scaled absolute error within each group
	function ms_loss(data, pred)
		return sum(abs, (data .- pred) ./ yscale) / size(data, 2)
	end

	# Range-scaled mismatch between the end of one group and the start of the next
	function continuity_loss(u1, u2)
		return sum(abs, (u1 .- u2) ./ yscale)
	end

	function loss(θ)
		return multiple_shoot(θ, data, times, prob, ms_loss, continuity_loss,
			Rosenbrock23(), groupsize; continuity_term=continuityterm)
	end

	callback = function (θ, l, preds)
		push!(losses, l)
		if length(losses) % 50 == 0
			println("Loss after $(length(losses)) iterations: $(losses[end])")
		end
		return false
	end

	res = DiffEqFlux.sciml_train(loss, p, opt, cb = callback, maxiters = maxiters)
	return res, losses
end
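The multiple-shooting loss is easier to tune with its two ingredients written out. These are how the time points are split into overlapping groups, and how the continuity penalty between groups is computed. The sketch below is a plain-Julia illustration of that idea with hypothetical helper names, not the source of DiffEqFlux's `multiple_shoot`. One plausible reading of the "series of straight lines" symptom: each short group can be matched by a nearly linear segment, and a small continuity term lets those segments disagree at the boundaries. Larger groups or a larger `continuityterm` may be worth trying.

```julia
# Hypothetical helper: split n time points into groups of `groupsize` points
# (groupsize >= 2) where consecutive groups share one boundary point.
function shooting_groups(n::Integer, groupsize::Integer)
    return [i:min(n, i + groupsize - 1) for i in 1:(groupsize - 1):(n - 1)]
end

# Continuity penalty: range-scaled mismatch between the last predicted state
# of each group and the first predicted state of the next group.
function continuity_penalty(preds, yscale)
    total = 0.0
    for k in 1:length(preds) - 1
        total += sum(abs, (preds[k][:, end] .- preds[k + 1][:, 1]) ./ yscale)
    end
    return total
end

println(shooting_groups(10, 4))   # UnitRange{Int64}[1:4, 4:7, 7:10]

# Two made-up group predictions (2 states × 2 times) that disagree by 0.5
# in state 1 at their shared boundary point.
preds = [[1.0 2.0; 0.0 1.0], [2.5 3.0; 1.0 2.0]]
println(continuity_penalty(preds, [1.0, 2.0]))   # 0.5
```

In the full loss, this penalty is multiplied by the continuity weight and added to the per-group data loss.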

Interesting! I’ll definitely look into this. Thanks!

Did you try iteratively growing the fit, i.e. training on a short time span first and then extending it?