Using Neural ODEs to learn a family of ODEs (with Automatic Differentiation)

I am running into issues using reverse-mode automatic differentiation via Zygote when trying to learn a family of ODEs with a neural ODE.

Given u(t), we can train a neural ODE to find the underlying ODE du/dt = f(u, t).

For example, consider the Dahlquist equation du/dt = lambda*u, u(0) = u0. We replace the right-hand side by a neural network f(u, t). Then, for a variety of initial conditions, we numerically simulate du/dt = lambda*u to get u(t). To train the neural ODE, for each initial condition we simulate du/dt = f(u, t) to get a prediction u_pred(t), and perform a gradient-based search using |u_pred(t) - u(t)| as the loss function. In other words, the initial conditions are our features, and the resulting u(t) are our labels. Because we may have many parameters in our neural network, it is important to be able to use reverse-mode AD to compute the gradient.
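Schematically, with illustrative names (predict, u0s, and u_trues are placeholders for this sketch, not definitions used later):

# Sketch of the objective described above. predict(θ, u0) would solve
# du/dt = f(u, t) with network weights θ from initial condition u0
# and return u on the save grid.
function total_loss(θ, predict, u0s, u_trues)
    return sum(sum(abs2, predict(θ, u0) .- u_true)
               for (u0, u_true) in zip(u0s, u_trues))
end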

Here is an example along those lines from the SciML documentation. They use OrdinaryDiffEq.remake to change the initial condition of the neural ODE problem.

Instead of varying the initial condition, I want to vary lambda (or both). That is, I want to find du/dt = f(u, t, lambda), where the values of lambda are our features and, again, the various u(t) are our labels (in my real work, lambda will be some control vector that changes du/dt in a more complicated way).

However, I am having trouble using Zygote to compute gradients of the loss function. The neural network f is parameterized by both lambda and the neural network weights, but only the weights are trainable, and the weights (as used by Lux.jl) are in NamedTuple format, which causes issues for Zygote. To work around this, I tried using ComponentArrays.jl, but that has caused other issues, seemingly because when I gather lambda and the neural network weights to pass to f, Zygote is unable to infer the relationship between the gathered ComponentArray and the input ComponentArray (which consists only of the neural network weights).
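Stripped of the ODE solve, the gather pattern I mean looks roughly like this (illustrative only; whether Zygote can differentiate through this ComponentArray construction is exactly what I am unsure about):

import Zygote
using ComponentArrays: ComponentArray

trainable = ComponentArray(NN_weights = (w = rand(Float32, 2),))
test = Zygote.gradient(trainable) do θ
    # Gather the non-trainable λ together with the trainable weights,
    # since the ODE right-hand side expects a single parameter object:
    p = ComponentArray(dahlquist_params = 0.5f0, NN_weights = θ.NN_weights)
    sum(abs2, p.NN_weights.w) * p.dahlquist_params
end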

In the following MWE, test_grad is (nothing,). I would appreciate any advice on how I could solve the problem with a similar setup, but computing the gradient correctly.

import OrdinaryDiffEq as ODE
import Optimization as OPT
import OptimizationOptimisers as OPO
import Lux
import Zygote
import SciMLSensitivity
using Random: MersenneTwister
using ComponentArrays: ComponentArray

sensealg = SciMLSensitivity.BacksolveAdjoint(autojacvec = SciMLSensitivity.ZygoteVJP())

# Set up the Dahlquist equation; solve an ensemble of problems with different λ

u0 = 1.0f0
tspan = (0.0f0, 1.0f0)
saveat = 0.1f0
dummy_λ = -1.0f0

"""
Compute du/dt for the Dahlquist test equation
    du/dt = λ * u
"""
function dudt_dahlquist(u, p, t)
    λ = p
    return λ * u
end

prob_dahlquist_template = ODE.ODEProblem(dudt_dahlquist, u0, tspan, dummy_λ)
num_training_samples = 10
# Randomly sampled λ values in (-0.5, 0.5]
training_λs = 0.5f0 .- rand(MersenneTwister(0), Float32, num_training_samples)

function dahlquist_prob_func(prob, i, repeat)
    return ODE.remake(prob, p = training_λs[i])
end

ensemble_prob = ODE.EnsembleProblem(prob_dahlquist_template, prob_func = dahlquist_prob_func)
solution_ensemble = ODE.solve(
    ensemble_prob,
    ODE.Tsit5(),
    sensealg = sensealg,
    abstol = 1e-6,
    reltol = 1e-6,
    saveat = saveat,
    trajectories = num_training_samples, # Determines how many problems to solve in ensemble
)
ensemble_u_true = Array(solution_ensemble)


# Set up the neural ODE to learn the Dahlquist equation

f_NN = Lux.Chain(
    Lux.Dense(3, 10, tanh), # 3 inputs: u, t, λ
    Lux.Dense(10, 1)
)
f_NN_weights_init, _st = Lux.setup(MersenneTwister(1), f_NN)

function dudt_dahlquist_NN(u, p, t)
    NN_input = vcat(u, t, p.dahlquist_params)
    NN_val, _ = Lux.apply(f_NN, NN_input, p.NN_weights, _st)
    val = NN_val[1]
    return val
end

dummy_params = ComponentArray(dahlquist_params=dummy_λ, NN_weights=f_NN_weights_init)
prob_dahlquist_NN_template = ODE.ODEProblem(dudt_dahlquist_NN, u0, tspan, dummy_params)


function loss_function(trainable_params)
    NN_weights = trainable_params.NN_weights
    function dahlquist_NN_prob_func(prob, i, repeat)
        p = ComponentArray(dahlquist_params=training_λs[i], NN_weights=NN_weights)
        return ODE.remake(prob, p = p)
    end

    ensemble_prob_NN = ODE.EnsembleProblem(prob_dahlquist_NN_template, prob_func = dahlquist_NN_prob_func)

    ensemble_sol_pred = ODE.solve(
        ensemble_prob_NN,
        ODE.Tsit5(),
        ODE.EnsembleSerial(),
        abstol = 1e-6,
        reltol = 1e-6,
        saveat = saveat,
        sensealg = sensealg,
        trajectories = num_training_samples, # Determines how many problems to solve in ensemble
    )
    ensemble_u_pred = Array(ensemble_sol_pred) # Ensemble solutions are concatenated along 3rd dimension

    loss_val = sum(abs2, ensemble_u_pred .- ensemble_u_true) / length(ensemble_u_true)
    return loss_val
end

test_grad = Zygote.gradient(loss_function, dummy_params)

This should just work? Are you testing on v1.11?


I am using v1.11; I'll try v1.12. I hadn't considered that.

By “this”, do you mean using component arrays, or are you saying it should be fine to just have the optimization input (the NN weights) be a NamedTuple?

No, I say this because v1.12 won't work 😅

Hmm then I don’t know what it is off the top of my head. Let me play with it.


Okay, thanks for considering it! If it helps, here is a slightly simpler MWE where the training data isn't generated using an ODE. I still get a gradient equal to (nothing,).

import OrdinaryDiffEq as ODE
import Lux
import Zygote
import SciMLSensitivity
using Random: MersenneTwister
using ComponentArrays: ComponentArray

sensealg = SciMLSensitivity.BacksolveAdjoint(autojacvec = SciMLSensitivity.ZygoteVJP())

# Set up the Dahlquist equation; generate training data for different λ

u0 = [1.0f0]
tspan = (0.0f0, 1.0f0)
saveat_step = 0.1f0
dummy_λ = -1.0f0
saveat_grid = range(tspan[1], tspan[2], step=saveat_step)
num_training_samples = 3
training_λs = range(-1.0f0, 1.0f0, length=num_training_samples)


"""
Analytically solve the Dahlquist test equation
    du/dt = λ * u, u(0)=1
"""
function u(t, λ)
    return exp(λ*t)
end

true_us = [u(t, λ) for _ in 1:1, t in saveat_grid, λ in training_λs]

# Set up the neural ODE to learn the Dahlquist equation
f_NN = Lux.Chain(
    Lux.Dense(3, 10, tanh), # 3 inputs: u, t, λ
    Lux.Dense(10, 1)
)
f_NN_weights_init, _st = Lux.setup(MersenneTwister(1), f_NN)

"""
Compute du/dt using
    du/dt = f(u, t, λ)
where f is a neural network.
"""
function dudt_NN!(du, u, p, t)
    NN_input = vcat(u, t, p.dahlquist_params)
    NN_val, _ = Lux.apply(f_NN, NN_input, p.NN_weights, _st)
    val = NN_val[1]
    du[1] = val
end

dummy_params = ComponentArray(dahlquist_params=dummy_λ, NN_weights=f_NN_weights_init)
prob_NN_template = ODE.ODEProblem(dudt_NN!, u0, tspan, dummy_params)

function loss_function(trainable_params)

    NN_weights = trainable_params.NN_weights

    function NN_prob_func(prob, i, repeat)
        p = ComponentArray(dahlquist_params=training_λs[i], NN_weights=NN_weights)
        return ODE.remake(prob, p = p)
    end

    ensemble_prob_NN = ODE.EnsembleProblem(prob_NN_template, prob_func = NN_prob_func)
    ensemble_sol_pred = ODE.solve(
        ensemble_prob_NN, ODE.Tsit5(), ODE.EnsembleSerial(),
        abstol = 1e-6, reltol = 1e-6, 
        saveat = saveat_grid,
        sensealg = sensealg,
        trajectories = num_training_samples, # Determines how many problems to solve in ensemble
    )
    ensemble_u_pred = Array(ensemble_sol_pred) # Ensemble solutions are concatenated along 3rd dimension

    loss_val = sum(abs2, ensemble_u_pred .- true_us) / length(true_us)
    return loss_val
end

test_grad = Zygote.gradient(loss_function, dummy_params)

Sort of solved.

I realized that although I do need the parameters of the ODE problem to be a ComponentArray, the input to the loss function does not need to be one. So I just pass the neural network weights to the loss function as a NamedTuple, and only create a ComponentArray when using remake to create an ensemble of neural ODE problems (see code block below).

When doing it this way, I seem to be able to compute the gradient (the output is a NamedTuple of the same type as the input weights).

This kind of solves the issue as originally stated, but my understanding is that in the context of an optimization I will want the weights input to be a ComponentArray, so that we can do broadcasted adds to perform the gradient descent. So this workaround wouldn't work when I am actually trying to optimize the ODE (though see the Optimisers.jl sketch after the code block below).

import OrdinaryDiffEq as ODE
import Lux
import Zygote
import SciMLSensitivity
using Random: MersenneTwister
using ComponentArrays: ComponentArray

sensealg = SciMLSensitivity.BacksolveAdjoint(autojacvec = SciMLSensitivity.ZygoteVJP())

# Set up the Dahlquist equation; generate training data for different λ

u0 = [1.0f0]
tspan = (0.0f0, 1.0f0)
saveat_step = 0.1f0
dummy_λ = -1.0f0
saveat_grid = range(tspan[1], tspan[2], step=saveat_step)
num_training_samples = 3
training_λs = range(-1.0f0, 1.0f0, length=num_training_samples)


"""
Analytically solve the Dahlquist test equation
    du/dt = λ * u, u(0)=1
"""
function u(t, λ)
    return exp(λ*t)
end

true_us = [u(t, λ) for _ in 1:1, t in saveat_grid, λ in training_λs]

# Set up the neural ODE to learn the Dahlquist equation
f_NN = Lux.Chain(
    Lux.Dense(3, 10, tanh), # 3 inputs: u, t, λ
    Lux.Dense(10, 1)
)
f_NN_weights_init, _st = Lux.setup(MersenneTwister(1), f_NN)

"""
Compute du/dt using
    du/dt = f(u, t, λ)
where f is a neural network.
"""
function dudt_NN!(du, u, p, t)
    NN_input = vcat(u, t, p.dahlquist_params)
    NN_val, _ = Lux.apply(f_NN, NN_input, p.NN_weights, _st)
    val = NN_val[1]
    du[1] = val
end

dummy_params = ComponentArray(dahlquist_params=dummy_λ, NN_weights=f_NN_weights_init)
prob_NN_template = ODE.ODEProblem(dudt_NN!, u0, tspan, dummy_params)

function loss_function(NN_weights)

    function NN_prob_func(prob, i, repeat)
        p = ComponentArray(dahlquist_params=training_λs[i], NN_weights=NN_weights)
        return ODE.remake(prob, p = p)
    end

    ensemble_prob_NN = ODE.EnsembleProblem(prob_NN_template, prob_func = NN_prob_func)
    ensemble_sol_pred = ODE.solve(
        ensemble_prob_NN, ODE.Tsit5(), ODE.EnsembleSerial(),
        abstol = 1e-6, reltol = 1e-6, 
        saveat = saveat_grid,
        sensealg = sensealg,
        trajectories = num_training_samples, # Determines how many problems to solve in ensemble
    )
    ensemble_u_pred = Array(ensemble_sol_pred) # Ensemble solutions are concatenated along 3rd dimension

    loss_val = sum(abs2, ensemble_u_pred .- true_us) / length(true_us)
    return loss_val
end

test_grad = Zygote.gradient(loss_function, f_NN_weights_init)
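
On the broadcasted-adds concern above: Optimisers.jl can update NamedTuple parameters directly (via Functors), so a hand-rolled loop should not need ComponentArrays. A minimal sketch, assuming the NamedTuple-input loss_function above:

import Optimisers

function train(weights, loss; iters = 100, lr = 0.01)
    opt_state = Optimisers.setup(Optimisers.Adam(lr), weights)
    for _ in 1:iters
        grads = Zygote.gradient(loss, weights)[1]
        opt_state, weights = Optimisers.update(opt_state, weights, grads)
    end
    return weights
end

trained_weights = train(f_NN_weights_init, loss_function)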

Yes, I was just going to post that. The issue is that you're using ZygoteVJP, which I wouldn't recommend for in-place ODE functions anyway; but if you do force it, then you must use an abstract array type, because it will make a Zygote.Buffer, which requires array properties.

I see. I’m new to the SciML ecosystem, so I just followed the first example I could find online. What would you recommend for in-place ODE functions?

FYI I am still at an impasse because if I do

adtype = OPT.AutoZygote()
opt = OPO.Adam(0.1)
optf = OPT.OptimizationFunction((x,p) -> loss_function(x), adtype)
optprob = OPT.OptimizationProblem(optf, f_NN_weights_init)
optsol = OPT.solve(optprob, opt; verbose=true, maxiters = 5)

I get an error saying that NamedTuples (which f_NN_weights_init is) cannot be copied. But if I wrap f_NN_weights_init in a ComponentArray, the gradient will be (nothing,). I guess for the time being, since I am capable of computing the gradient using Zygote.gradient, the most straightforward thing would be to do the optimization myself or with another package.

Update: actually it turns out that even when the input is a NamedTuple, the gradient is not being computed correctly. Although the output is a length-1 tuple whose sole entry is a NamedTuple with the same structure as f_NN_weights_init, the data is all zeros:

julia> test_grad = Zygote.gradient(loss_function, f_NN_weights_init)
((layer_1 = (weight = Float32[0.0 0.0 0.0; 0.0 0.0 0.0; … ; 0.0 0.0 0.0; 0.0 0.0 0.0], bias = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), layer_2 = (weight = Float32[0.0 0.0 … 0.0 0.0], bias = Float32[0.0])),)

What tutorial was it?

If you just don’t pass the sensealg it’ll likely give you a much better choice.

The example was: Faster Neural Ordinary Differential Equations with SimpleChains · SciMLSensitivity.jl

On second look, they do mention that they used QuadratureAdjoint(autojacvec=ZygoteVJP()) because "All the adjoint methods make heavy use of in-place mutation to be performant with the heap allocated normal arrays. For our statically sized, stack allocated StaticArrays, in order to be able to compute the ODE adjoint we need to do everything out of place." I changed QuadratureAdjoint to BacksolveAdjoint after QuadratureAdjoint did not work.

So it is on me for not reading the documentation thoroughly. Still, other examples I have seen use adtype = OPT.AutoZygote(), so I defaulted to Zygote. I don’t know why, but I was under the impression that sensealg had to be provided manually.

Ahh that’s kind of a special example because it’s using SimpleChains, most cases shouldn’t do that. In fact, SimpleChains in the recent versions probably doesn’t even need that anymore.

That’s the outer AD, there’s two ADs and usually the inner one (which has more of an effect on performance) is auto-chosen.


Is there a way to check which inner AD method is used when I don't specify it? As I scale up the number of parameters in the neural network, I would like to make sure that reverse-mode automatic differentiation is being used.

In case anyone runs into the same problem, I settled on this solution.

The best solution I could think of is to have the time and control parameter of the original ODE be part of the state in the NeuralODE. I also switched to using DiffEqFlux in the hopes that I would be less likely to encounter any errors related to differentiating the neural network, but I think in principle this will also work for OrdinaryDiffEq.

I replace the original ODE with an "extended" ODE which includes the time and lambda as part of the state vector (I don't think time is necessary, but it will be in the application I have in mind). This way, the neural network learns the constant-in-time parameter lambda and the linearly increasing time.

I don’t love this solution, since the time and parameter lambda are known a priori and shouldn’t have to be learned, but if it works, it works.


import OrdinaryDiffEq as ODE
import DiffEqFlux as DEF
import Optimization as OPT
import Optimisers
import OptimizationOptimisers as OPO
import Lux
import Zygote
import Optim
import SciMLSensitivity as SMS
import Random: MersenneTwister
import ComponentArrays: ComponentArray
import ComponentArrays
import Plots


# Set up the Dahlquist equation; solve an ensemble of problems with different λ

rng = MersenneTwister(1)
u0 = Float32[1,0,0]
datasize=11
tspan = (0.0f0, 1.0f0)
tsteps = range(tspan[1], tspan[2]; length=datasize)

num_training_samples = 3
training_λs = range(-1.0f0, 1.0f0, length=num_training_samples)


"""
Compute du/dt for the Dahlquist test equation,
    du/dt = λ*u,
where the state vector u stores [u, λ, t]
"""
function dudt!(du, u, p, t)
    u_u, u_λ, u_t = u
    du[1] = u_λ*u_u
    du[2] = 0
    du[3] = 1
    return du
end


dahlquist_prob = ODE.ODEProblem(dudt!, u0, tspan)

# Vector of initial conditions, suitable for ODE Ensemble
initial_conditions_vec = [Float32[1, λ, 0] for λ in training_λs]
# Hstacked matrix of initial conditions, suitable for NN input
initial_conditions_mat = reduce(hcat, initial_conditions_vec)

# Solve Ensemble of ODEs to get training data
function prob_func(prob, i, repeat)
    return ODE.remake(prob; u0=initial_conditions_vec[i])
end

dahlquist_ensemble = ODE.EnsembleProblem(dahlquist_prob, prob_func=prob_func)
dahlquist_sol = ODE.solve(dahlquist_ensemble, ODE.Tsit5(), ODE.EnsembleSerial(); trajectories=num_training_samples, saveat=tsteps)
dahlquist_sol_ary = permutedims(Array(dahlquist_sol), (1,3,2)) # Different initial conditions should change with 2nd axis for NN compatibility


neural_dudt = Lux.Chain(
    Lux.Dense(3, 8, tanh), # 3 inputs: the extended state [u, λ, t]
    Lux.Dense(8, 3)
)
p, st = Lux.setup(rng, neural_dudt)

prob_neuralode = DEF.NeuralODE(neural_dudt, tspan, ODE.Tsit5(); saveat=tsteps)

function predict_neuralode(p)
    return prob_neuralode(initial_conditions_mat, p, st) |> first |> Array
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, dahlquist_sol_ary .- pred)
    return loss
end


# Plot the current prediction against the data at each callback
# (change doplot = false to disable plotting)
function callback(state, l; doplot = true)
    println(l)
    # plot current prediction against data
    if doplot
        pred = predict_neuralode(state.u)
        plt = Plots.scatter(tsteps, dahlquist_sol_ary[1, end, :]; label = "data")
        Plots.scatter!(plt, tsteps, pred[1, end, :]; label = "prediction")
        display(Plots.plot(plt))
    end
    return false
end

pinit = ComponentArray(p)
adtype = OPT.AutoZygote() # This does *not* mean Zygote is used for the NeuralODE

# Run the callback once for initial plot
callback((; u = pinit), loss_neuralode(pinit); doplot = true)

optf = OPT.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = OPT.OptimizationProblem(optf, pinit)

result_neuralode = OPT.solve(
    optprob, OPO.Adam(0.05); callback = callback, maxiters = 300
)

