I’m trying to use neural networks with DiffEqFlux to learn an SIR model with vaccination dynamics. The model equations are:

```julia
function SIRx!(du, u, p, t)
    # Parameters are hard-coded here, so `p` is unused
    β, μ, γ, a, b = Float32.([280, 1/50, 365/22, 100, 0.05])
    S, I, x = u
    du[1] = μ*(100 - x) - (β/100)*S*I - μ*S   # susceptible
    du[2] = (β/100)*S*I - (μ + γ)*I           # infected
    du[3] = a*I - b*x
    nothing
end
```
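For reference, here is the same system written as equations, read directly off the function above, with the parameter values hard-coded there (β = 280, μ = 1/50, γ = 365/22, a = 100, b = 0.05):

$$
\begin{aligned}
\frac{dS}{dt} &= \mu\,(100 - x) - \frac{\beta}{100}\,S I - \mu S,\\
\frac{dI}{dt} &= \frac{\beta}{100}\,S I - (\mu + \gamma)\,I,\\
\frac{dx}{dt} &= a I - b x.
\end{aligned}
$$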
The time span I’m using for training is from 0 to 10, with initial condition `u0 = 100*Float32[0.062047128, 1.3126149f-7, 0.9486445]`. I save the data every 0.1 time units, for a total of 101 points.
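A minimal sketch of how the training data described above could be generated, assuming OrdinaryDiffEq.jl with `Tsit5()` (the post does not say which solver produced the data) and the `du[1]`, `du[2]`, `du[3]` indexing. The right-hand side is repeated so the block runs on its own.

```julia
using OrdinaryDiffEq  # provides ODEProblem, solve, Tsit5

# SIRx right-hand side, repeated from above so this block is self-contained
function SIRx!(du, u, p, t)
    β, μ, γ, a, b = Float32.([280, 1/50, 365/22, 100, 0.05])
    S, I, x = u
    du[1] = μ*(100 - x) - (β/100)*S*I - μ*S
    du[2] = (β/100)*S*I - (μ + γ)*I
    du[3] = a*I - b*x
    nothing
end

u0 = 100*Float32[0.062047128, 1.3126149f-7, 0.9486445]
tspan = (0.0f0, 10.0f0)
tsteps = range(tspan[1], tspan[2]; length = 101)  # every 0.1 -> 101 points

prob = ODEProblem(SIRx!, u0, tspan)
# Assumption: Tsit5 with default tolerances; swap in a stiff solver if needed
sol = solve(prob, Tsit5(); saveat = tsteps)
ode_data = Array(sol)  # 3×101 matrix: rows are S, I, x
```

Passing `saveat = tsteps` pins the output to exactly 101 points. Since I starts around 1e-5 while the default `abstol` is 1e-6, it may be worth tightening `abstol` when generating the data so that I is resolved accurately.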
I want to use neural networks to approximate these equations using as little of the true information as possible. I’ve been having difficulty with training getting stuck in a local minimum.
Specifically, the neural network tends to fit a straight line through the middle of the data and get stuck there.
I’ve tried all the main strategies for escaping local minima listed in the DiffEqFlux documentation (multiple shooting, smoothed collocation, iteratively growing the fit), but without much success. In every case, the network fails to produce anything more complex than a straight line (or a series of straight lines, in the case of multiple shooting).
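The smoothed-collocation idea can be sketched as follows: estimate du/dt from the saved data, then train the network's vector field to match those derivatives pointwise, with no ODE solve inside the loss. DiffEqFlux's own version (`collocate_data`) estimates the derivatives with kernel smoothing; this sketch swaps in plain finite differences, and `fd_derivative` and `collocation_loss` are hypothetical helper names. The toy usage checks it on a harmonic oscillator rather than on SIRx.

```julia
# Hypothetical helper: central finite differences on a uniform time grid,
# one-sided at the two endpoints. `data` is n_states × n_times.
function fd_derivative(data::AbstractMatrix, dt::Real)
    n = size(data, 2)
    du = similar(data)
    du[:, 1] = (data[:, 2] .- data[:, 1]) ./ dt
    du[:, n] = (data[:, n] .- data[:, n-1]) ./ dt
    for k in 2:n-1
        du[:, k] = (data[:, k+1] .- data[:, k-1]) ./ (2dt)
    end
    return du
end

# Hypothetical helper: mean squared mismatch between the model's vector field
# f(u) and the estimated derivatives at every saved point.
function collocation_loss(f, data::AbstractMatrix, du_est::AbstractMatrix)
    total = zero(eltype(data))
    for k in axes(data, 2)
        total += sum(abs2, f(data[:, k]) .- du_est[:, k])
    end
    return total / length(data)
end

# Toy check: u = (sin t, cos t) satisfies du/dt = (u₂, -u₁)
t = 0:0.1:10
data = [sin.(t)'; cos.(t)']
du_est = fd_derivative(data, step(t))
good = collocation_loss(u -> [u[2], -u[1]], data, du_est)  # correct vector field
bad  = collocation_loss(u -> [u[2],  u[1]], data, du_est)  # wrong vector field
```

With the SIRx data, `f` would be the neural network evaluated at a state. Because every saved point contributes its own derivative target, a field that only reproduces a straight-line trajectory (constant derivative) is penalized wherever the true derivatives vary.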
The training framework I’m using is based on the one in the *Stiff Neural Ordinary Differential Equations* paper for the ROBER (Robertson) equations, since those equations are roughly as complex as mine, if not more so.
Here is a summary of what I’ve tried:

- 6 hidden layers of 5, 6, or 7 neurons each
- 1 hidden layer of 30, 40, or 50 neurons
- Activation functions (the output layer is always linear) and the outputs are scaled by
- Optimizer (I tested 500 iterations in each case):
  - ADAM with learning rate 0.005 or 0.05 (larger rates just find divergent solutions and crash the program)
- Differential equation solver:
- ADAM() with learning rate 0.05, 0.05, and 0.005
- Loss function: mean absolute error or squared error, with both the network prediction and the true data scaled by the range of the data values
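A sketch of the range-scaled loss from the last item, assuming the scaling is per variable (the post does not say whether it uses one range per state or a single global range). Per-variable scaling matters here because I starts around 1e-5 while x starts near 95, so with a single global range the residuals in I would barely register. `scaled_losses` is a hypothetical helper name.

```julia
# Hypothetical helper: MAE and MSE with each state's residuals divided by that
# state's range (max - min) in the true data. Assumes per-variable scaling and
# that no state is constant (a zero range would divide by zero).
function scaled_losses(pred::AbstractMatrix, data::AbstractMatrix)
    # data is n_states × n_times, so `scale` holds one range per state (row)
    scale = vec(maximum(data; dims = 2) .- minimum(data; dims = 2))
    # Each time column is divided elementwise by `scale`
    r = (pred .- data) ./ scale
    mae = sum(abs, r) / length(r)
    mse = sum(abs2, r) / length(r)
    return mae, mse
end
```

In training, `pred` would be the solved neural ODE as a 3×101 array and `data` the saved true trajectory.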
The larger learning rates tend to find unstable solutions and crash. The runs that don’t crash end up fitting a straight line and then oscillating its slope up and down slightly until training ends. Here is the best result I have so far: 7 hidden layers,
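Regarding the larger learning rates that crash: one generic safeguard is to reject any update that makes the loss non-finite or larger, and shrink the step size instead. The sketch below shows this with plain gradient descent on a toy quadratic. `guarded_descent` is a hypothetical helper, not a DiffEqFlux or Optimization.jl feature; with ADAM the same accept/reject check would wrap each optimizer update, assuming the loss returns `Inf` when an ODE solve fails so that the check can catch it.

```julia
# Hypothetical helper: gradient descent that rejects a step whose loss is
# non-finite or larger than the current loss, then shrinks the step size.
function guarded_descent(loss, grad, θ0; lr = 0.05, iters = 500, shrink = 0.5)
    θ = copy(θ0)
    L = loss(θ)
    for _ in 1:iters
        θ_new = θ .- lr .* grad(θ)
        L_new = loss(θ_new)
        if isfinite(L_new) && L_new <= L
            θ, L = θ_new, L_new  # accept the step
        else
            lr *= shrink         # reject it and retry with a smaller step
        end
    end
    return θ, L, lr
end

# Toy usage: loss(θ) = |θ|², whose gradient is 2θ
θ, L, lr = guarded_descent(θ -> sum(abs2, θ), θ -> 2 .* θ, [3.0, -4.0]; lr = 1.5)
```

Here the initial step size of 1.5 overshoots, so the first step is rejected and the run continues with 0.75, after which the loss decreases on every step.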
I know the general approach works, because I can use the same algorithm to fit a Lotka-Volterra model fairly well. From what I can tell, something about the SIRx model (number of dimensions, stiffness, qualitative behaviour, etc.) must be preventing it from working. It could also be that I just need more patience, or a faster CPU to train for longer. However, with so many variables and parameters to tweak, I don’t want to spend hours blindly trying new combinations. I would therefore appreciate any advice on how best to diagnose the problem and which strategies are worth investing in.
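On the stiffness question, one quick diagnostic is to look at the eigenvalues of the Jacobian of the right-hand side at the initial condition, since widely separated real parts are a common sign of stiffness. This is a sketch assuming the parameter values hard-coded in `SIRx!` (here in Float64); `sirx_rhs` and `sirx_jacobian` are hypothetical helper names. It only probes the dynamics at u0, so it would need re-checking at other points along the trajectory.

```julia
using LinearAlgebra: eigvals  # import only eigvals so `I` stays free as a state name

# Hypothetical helper: out-of-place SIRx right-hand side, parameters copied from SIRx!
function sirx_rhs(u; β = 280.0, μ = 1/50, γ = 365/22, a = 100.0, b = 0.05)
    S, I, x = u
    k = β / 100
    return [μ*(100 - x) - k*S*I - μ*S,
            k*S*I - (μ + γ)*I,
            a*I - b*x]
end

# Hypothetical helper: analytic Jacobian ∂f/∂u of sirx_rhs
function sirx_jacobian(u; β = 280.0, μ = 1/50, γ = 365/22, a = 100.0, b = 0.05)
    S, I, x = u
    k = β / 100
    return [-k*I - μ   -k*S            -μ;
             k*I        k*S - (μ + γ)   0.0;
             0.0        a              -b]
end

u0 = 100 .* [0.062047128, 1.3126149e-7, 0.9486445]
λ = eigvals(sirx_jacobian(u0))

# Largest over smallest |Re λ|: a rough, local stiffness indicator at u0 only
stiffness_ratio = maximum(abs.(real.(λ))) / minimum(abs.(real.(λ)))
```

A very large ratio would usually point toward a stiff solver; a modest one would suggest looking elsewhere, for example at the very different magnitudes of I and x in the data.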
Thank you for reading, and thank you in advance for any insight.