Constraints in NeuralPDE.NNODE?

Hi everyone,

I’m trying to implement a toy example of running NeuralPDE.NNODE to add to my repository of SIR models. Here’s what I have, which doesn’t work well.

using OrdinaryDiffEq
using NeuralPDE
using Flux
using OptimizationOptimisers

function sir_ode(u,p,t)
    (S, I, R) = u
    (β, γ) = p
    dS = -β*I*S
    dI = β*I*S - γ*I
    dR = γ*I
    [dS, dI, dR]
end

tspan = (0.0,40.0)
u0 = [0.99, 0.01, 0.0];
p = [0.5, 0.25];
prob_ode = ODEProblem(sir_ode, u0, tspan, p);
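δt = 1.0  # assumed value; δt must be defined before it is used
sol_ode = solve(prob_ode, Tsit5(), saveat=δt);

num_hid = 8
chain = Flux.Chain(Dense(1, num_hid, σ), Dense(num_hid, 3))
opt = OptimizationOptimisers.Adam(0.1)
alg = NeuralPDE.NNODE(chain, opt)
# Keyword values here are assumed, mirroring the later examples
sol_nnode = solve(prob_ode, alg; dt = δt, verbose = true, maxiters = 5000)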

out = Array(sol_nnode)

This doesn’t give a good solution, as my state vector u = [S,I,R] is (a) constrained to be non-negative and (b) constrained to sum to 1. My understanding was that the Flux.Chain represents a mapping from t to u such that u'(t) matches the output of the sir_ode function. However, when I add a softmax layer to the chain above, the resulting out still does not satisfy S(t)+I(t)+R(t)=1.0. Any suggestions for a neural net architecture that works?
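One possible explanation, assuming NNODE builds its trial solution as u(t) = u0 + (t - t0) * N(t) with N the network (worth checking against the installed NeuralPDE version): a final softmax constrains N(t), not u(t), so the state sums to sum(u0) + (t - t0) rather than 1. The sketch below is plain Julia with a made-up vector z standing in for the raw network output; softmax_vec, trial, and zero_sum are hypothetical helper names, not NeuralPDE functions.

```julia
# Numerically stable softmax over a vector
softmax_vec(z) = (e = exp.(z .- maximum(z)); e ./ sum(e))

# Trial-solution form ASSUMED for NNODE: u(t) = u0 + (t - t0) * N(t)
trial(u0, t0, t, n_out) = u0 .+ (t - t0) .* n_out

u0 = [0.99, 0.01, 0.0]
t0, t = 0.0, 2.0
z = [0.3, -1.2, 0.7]   # stand-in for raw network outputs at time t

# Softmax head: N(t) sums to 1, so u(t) sums to sum(u0) + (t - t0)
u_softmax = trial(u0, t0, t, softmax_vec(z))
println(sum(u_softmax))

# Zero-sum head: subtracting the mean makes N(t) sum to 0,
# so u(t) keeps the same sum as u0
zero_sum(z) = z .- sum(z) / length(z)
u_zero_sum = trial(u0, t0, t, zero_sum(z))
println(sum(u_zero_sum))
```

Under that assumption, a zero-sum output layer would preserve S+I+R, but it does not by itself keep each component non-negative.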

In the NeuralPDE.jl docs there is a section on how to add constraints. Can this also be done with the simplified NNODE interface too, and is it possible to implement the above constraints using this approach?
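As a rough illustration of what such a constraint could look like as a soft penalty, here is a plain-Julia sketch. The function name constraint_penalty and the 3×n matrix layout are assumptions made for the example. Newer NeuralPDE releases are understood to accept an additional_loss keyword in NNODE, but how a penalty like this would be wired in depends on the installed version and is not shown here.

```julia
# Penalty on violations of S + I + R = 1 and of non-negativity.
# `U` is a 3×n matrix: one column of predicted [S, I, R] per time point.
function constraint_penalty(U)
    n = size(U, 2)
    # Mean squared deviation of each column sum from 1
    sum_violation = sum(abs2, sum(U, dims = 1) .- 1) / n
    # Mean squared magnitude of negative entries (zero if none)
    neg_violation = sum(abs2, min.(U, 0)) / n
    return sum_violation + neg_violation
end

U_good = [0.9 0.5; 0.1 0.3; 0.0 0.2]    # columns sum to 1, no negatives
U_bad  = [1.1 0.5; -0.1 0.3; 0.2 0.2]   # first column violates both

println(constraint_penalty(U_good))
println(constraint_penalty(U_bad))
```

A penalty like this only encourages the constraints; unlike a structural change to the network output, it does not guarantee them.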

Here’s a simpler example that is much better scaled: it has a single variable, the time domain is [0,1], and the variable lies between 0 and 1. However, NNODE still fails spectacularly. I’m sure I’m making a simple mistake, but can anyone spot why?

using OrdinaryDiffEq
using NeuralPDE
using Flux
using OptimizationOptimisers
using Plots

function si_ode(u,p,t)
    I = u[1]
    S = 1.0 - I
    dI = 20.0*S*I
    [dI]
end

prob_ode = ODEProblem(si_ode, [0.01], (0.0, 1.0), []);
sol_ode = solve(prob_ode, Tsit5(), saveat=0.025);

numhid = 4
chain = Flux.Chain(Dense(1, numhid, σ), Dense(numhid, numhid, σ), Dense(numhid, 1))
opt = OptimizationOptimisers.Adam(0.1)
alg = NeuralPDE.NNODE(chain, opt)
sol_pinn = solve(prob_ode, alg; dt = 0.025, verbose=true, abstol=1f-10, maxiters=50000)

plot(sol_ode, label="ODE", xlabel="t",ylabel="I")
plot!(sol_pinn.t, Array(sol_pinn)',label="PINN")


I also tried a simpler (but basically equivalent) problem using PINNs directly - still no success. Any idea why this code can’t fit a simple logistic curve?

using ModelingToolkit
using OrdinaryDiffEq
using NeuralPDE
using DomainSets
using Flux
using Optimization
using OptimizationOptimJL
using OptimizationOptimisers
using Plots

@parameters t
@variables x(..)
Dt = Differential(t)
eqs = [Dt(x(t)) ~ x(t)*(1-x(t))]

@named ode_sys = ODESystem(eqs)
ode_prob = ODEProblem(ode_sys, [0.01], (0.0,10.0), [])
ode_sol = solve(ode_prob, Tsit5(), saveat=0.1)

bcs = [x(0) ~ 0.01]
domains = [t ∈ Interval(0.0, 10.0)];
@named pde_sys = PDESystem(eqs, bcs, domains, [t], [x(t)])

numhid = 32
chain = [Flux.Chain(Flux.Dense(1, numhid, Flux.σ), Flux.Dense(numhid, 1))]
grid_strategy = NeuralPDE.GridTraining(0.1)
discretization = NeuralPDE.PhysicsInformedNN(chain, grid_strategy)
pde_prob = NeuralPDE.discretize(pde_sys, discretization)

global i=1
callback = function (p,l)
    println("Epoch $i: Current loss is: $l")
    global i += 1
    return false
end

res = Optimization.solve(pde_prob, OptimizationOptimisers.Adam(0.1); callback=callback, maxiters=5000, abstol = 1e-10, reltol = 1e-10)
pde_prob = remake(pde_prob, u0 = res.minimizer)
res = Optimization.solve(pde_prob, OptimizationOptimJL.BFGS(); callback=callback, maxiters=5000, abstol = 1e-10, reltol = 1e-10)
phi = discretization.phi
ts = [infimum(d.domain):0.1:supremum(d.domain) for d in domains][1]
xpred  = hcat([phi[1]([t],res.u) for t in ts]...)'
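For checking any of these fits, the logistic equation dx/dt = r*x*(1 - x) has a closed-form solution, so the prediction can be compared against it directly rather than only against the numerical ODE solution. The sketch below is plain Julia; logistic and rmse are helper names introduced for illustration, with r = 1 matching the PINN example and r = 20 matching the SI example.

```julia
# Closed-form solution of dx/dt = r*x*(1 - x) with x(0) = x0
logistic(t; x0 = 0.01, r = 1.0) = x0 * exp(r * t) / (1 - x0 + x0 * exp(r * t))

# Exact curve on the same grid as the PINN example
ts = 0.0:0.1:10.0
x_exact = logistic.(ts)

# Root-mean-square error between a prediction vector and the exact curve
rmse(pred, exact) = sqrt(sum(abs2, pred .- exact) / length(exact))

# Exact curve for the SI example (r = 20 on [0, 1])
ts_si = 0.0:0.025:1.0
i_exact = [logistic(t; r = 20.0) for t in ts_si]
```

For example, rmse(vec(xpred), x_exact) would give a single number for how far the PINN prediction is from the true curve, assuming xpred is evaluated on the same grid.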

Can you open an issue? We need to turn the better fitting/modeling methods from "Push nnode further" (Issue #46, SciML/NeuralPDE.jl on GitHub) into tutorials / standard usage of NNODE (for example, it should try fitting on a small interval and then grow the interval, etc.). If you open an issue, we can use your example as a test case in the next GSoC.

Hi Chris,

Done! Here’s the issue.