Using Neural ODEs to learn a family of ODEs (with Automatic Differentiation)

In case anyone runs into the same problem, here’s the solution I settled on.

The best solution I could come up with is to make the time and the control parameter of the original ODE part of the state in the NeuralODE. I also switched to DiffEqFlux in the hope of avoiding errors related to differentiating the neural network, but in principle this should also work with plain OrdinaryDiffEq.

I replace the original ODE with an “extended” ODE that includes the time t and the parameter λ as part of the state vector (I don’t think time is strictly necessary here, but it will be in the application I have in mind). This way, the neural network has to learn the constant parameter λ and the linearly growing time alongside the actual dynamics.
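
Concretely, the extended system is

    d/dt [u, λ, t] = [λ*u, 0, 1],    u(0) = [u₀, λ, 0],

so λ and t just ride along as state components with trivial dynamics.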

I don’t love this solution, since the time and the parameter λ are known a priori and shouldn’t have to be learned, but if it works, it works.


import OrdinaryDiffEq as ODE
import DiffEqFlux as DEF
import Optimization as OPT
import OptimizationOptimisers as OPO
import Lux
import Zygote # loaded so OPT.AutoZygote() below has a backend to dispatch to
import SciMLSensitivity as SMS # provides the adjoint rules for differentiating through the ODE solve
import Random: MersenneTwister
import ComponentArrays: ComponentArray
import Plots


# Set up the Dahlquist test equation, then solve an ensemble of problems with different λ

rng = MersenneTwister(1)
u0 = Float32[1, 0, 0]
datasize = 11
tspan = (0.0f0, 1.0f0)
tsteps = range(tspan[1], tspan[2]; length=datasize)

num_training_samples = 3
training_λs = range(-1.0f0, 1.0f0, length=num_training_samples)


"""
Solve the dahlquist test equation,
    dw/dt = λ*w,
where u stores [u, λ, t]
"""
function dudt!(du, u, p, t)
    u_u, u_λ, u_t = u
    du[1] = u_λ * u_u # du/dt = λ*u
    du[2] = 0         # λ is constant in time
    du[3] = 1         # t advances linearly
    return du
end
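
# Quick sanity check (my addition, not part of the original solution): for
# λ = -1 the exact solution is u(t) = exp(-t), so a single solve of the
# augmented system should reproduce it at t = 1.
check_prob = ODE.ODEProblem(dudt!, Float32[1, -1, 0], tspan)
check_sol = ODE.solve(check_prob, ODE.Tsit5(); saveat = tsteps)
@assert isapprox(check_sol.u[end][1], exp(-1.0f0); rtol = 1.0f-2)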


dahlquist_prob = ODE.ODEProblem(dudt!, u0, tspan)

# Vector of initial conditions, suitable for ODE Ensemble
initial_conditions_vec = [Float32[1, λ, 0] for λ in training_λs]
# Hstacked matrix of initial conditions, suitable for NN input
initial_conditions_mat = reduce(hcat, initial_conditions_vec)
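# initial_conditions_mat has size (3, num_training_samples); each column is
# one augmented initial state [u₀, λ, 0].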

# Solve Ensemble of ODEs to get training data
function prob_func(prob, i, repeat)
    return ODE.remake(prob; u0=initial_conditions_vec[i])
end

dahlquist_ensemble = ODE.EnsembleProblem(dahlquist_prob, prob_func=prob_func)
dahlquist_sol = ODE.solve(dahlquist_ensemble, ODE.Tsit5(), ODE.EnsembleSerial(); trajectories=num_training_samples, saveat=tsteps)
dahlquist_sol_ary = permutedims(Array(dahlquist_sol), (1, 3, 2)) # reorder so trajectories vary along the 2nd axis, matching the NeuralODE's batched output
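# dahlquist_sol_ary now has size (3, num_training_samples, datasize):
# (state component, trajectory, save point).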


neural_dudt = Lux.Chain(
    Lux.Dense(3, 8, tanh), # 3 inputs: the augmented state [u, λ, t]
    Lux.Dense(8, 3)        # 3 outputs: d/dt of each state component
)
p, st = Lux.setup(rng, neural_dudt)

prob_neuralode = DEF.NeuralODE(neural_dudt, tspan, ODE.Tsit5(); saveat=tsteps)
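
# prob_neuralode integrates du/dt = NN(u) over tspan; given a matrix of
# initial states it solves the whole batch at once, so its output has shape
# (state, batch, time), matching dahlquist_sol_ary above.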

function predict_neuralode(p)
    return prob_neuralode(initial_conditions_mat, p, st) |> first |> Array
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, dahlquist_sol_ary .- pred)
    return loss
end
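
# Note that the loss compares all three state components, so the network is
# also penalized for mispredicting the trivial dynamics; this is what forces
# it to learn dλ/dt = 0 and dt/dt = 1.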


# Plot the current prediction against the data at each iteration
# (set doplot = false to disable plotting)
function callback(state, l; doplot = true)
    println(l)
    # plot current prediction against data
    if doplot
        pred = predict_neuralode(state.u)
        # compare the u component of the last training trajectory
        plt = Plots.scatter(tsteps, dahlquist_sol_ary[1, end, :]; label = "data")
        Plots.scatter!(plt, tsteps, pred[1, end, :]; label = "prediction")
        display(Plots.plot(plt))
    end
    return false # returning true would stop the optimization early
end

pinit = ComponentArray(p)
adtype = OPT.AutoZygote() # Zygote differentiates the loss; the ODE solve itself is handled by SciMLSensitivity's adjoints

# Run the callback once for initial plot
callback((; u = pinit), loss_neuralode(pinit); doplot = true)

optf = OPT.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = OPT.OptimizationProblem(optf, pinit)

result_neuralode = OPT.solve(
    optprob, OPO.Adam(0.05); callback = callback, maxiters = 300
)
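
To check that this really learns the family of ODEs and not just the three training trajectories, you can evaluate the trained model at a λ that was not in the training set and compare against the exact solution exp(λt). A rough sketch (these lines are my addition; the names trained_p and unseen_ic are made up):

trained_p = result_neuralode.u
unseen_ic = reshape(Float32[1, 0.5, 0], 3, 1) # λ = 0.5 was not a training value
pred = prob_neuralode(unseen_ic, trained_p, st) |> first |> Array
truth = exp.(0.5f0 .* tsteps)
Plots.plot(tsteps, truth; label = "exact exp(0.5t)")
Plots.scatter!(tsteps, pred[1, 1, :]; label = "NeuralODE prediction")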

