I am running into issues using reverse-mode automatic differentiation via Zygote when trying to solve families of parameterized ODE problems with a neural ODE.
Given u(t), we can train a neural ODE to find the underlying ODE du/dt = f(u, t).
For example, consider the Dahlquist equation du/dt = lambda*u, u(0) = u0. We replace the right-hand side by a neural network f(u,t). Then, for a variety of initial conditions, we numerically simulate du/dt = lambda*u to get u(t). To train the neural ODE, for each initial condition we simulate du/dt = f(u,t) to get a prediction u_pred(t), and perform a gradient-based search using |u_pred(t) - u(t)| as the loss function. In other words, the initial conditions are our features, and the resulting u(t) are our labels. Because we may have many parameters in the neural network, it is important to be able to use reverse-mode AD to compute the gradient.
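For one training sample, the objective is just the mean squared mismatch between the two trajectories. A minimal sketch (with the neural-ODE prediction faked by a perturbed copy of the exact solution, only to show the shape of the loss):

ts = 0.0f0:0.1f0:1.0f0
u_true = exp.(-ts)          # exact solution of du/dt = -u, u(0) = 1
u_pred = u_true .+ 0.01f0   # stand-in for the neural-ODE prediction
loss = sum(abs2, u_pred .- u_true) / length(u_true)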
Here is an example along those lines from the SciML documentation; they use OrdinaryDiffEq.remake to change the initial condition of the neural ODE problem.
Instead of varying the initial condition, I want to vary lambda (or both). That is, I want to find du/dt = f(u,t,lambda), where the values of lambda are our features and the resulting u(t) are again our labels (in my real work, lambda will be a control vector that changes du/dt in a more complicated way).
However, I am having trouble using Zygote to compute gradients of the loss function. The neural network f is parameterized by both lambda and the network weights, but only the weights are trainable, and the weights (as used by Lux.jl) are stored in a NamedTuple, which causes issues for Zygote. To work around this, I tried ComponentArrays.jl, but that has caused other issues: when I gather lambda and the network weights into one object to pass to f, Zygote is seemingly unable to relate the gathered ComponentArray back to the input ComponentArray (which contains only the network weights).
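Concretely, the "gathering" step looks like this (a standalone sketch of what dahlquist_NN_prob_func does in the MWE below; the weights and the λ value are stand-ins):

using ComponentArrays: ComponentArray

# Sketch of the gathering step that Zygote does not seem to trace back to the
# trainable input: λ is a fixed feature, NN_weights are the trainables.
trainable_params = ComponentArray(NN_weights = (w = rand(Float32, 2),))  # stand-in weights
λ_i = -1.0f0                                                             # stand-in feature
p_full = ComponentArray(dahlquist_params = λ_i, NN_weights = trainable_params.NN_weights)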
In the following MWE, test_grad is (nothing,). I would appreciate any advice on how to keep a similar setup but compute the gradient correctly.
import OrdinaryDiffEq as ODE
import Optimization as OPT
import OptimizationOptimisers as OPO
import Lux
import Zygote
import SciMLSensitivity
using Random: MersenneTwister
using ComponentArrays: ComponentArray
sensealg = SciMLSensitivity.BacksolveAdjoint(autojacvec = SciMLSensitivity.ZygoteVJP())
# Set up the Dahlquist equation and solve an ensemble of problems with different λ
u0 = 1.0f0
tspan = (0.0f0, 1.0f0)
saveat = 0.1f0
dummy_λ = -1.0f0
"""
Compute du/dt for the Dahlquist test equation
du/dt = λ * u
"""
function dudt_dahlquist(u, p, t)
    λ = p
    return λ * u
end
prob_dahlquist_template = ODE.ODEProblem(dudt_dahlquist, u0, tspan, dummy_λ)
num_training_samples = 10
# Randomly sampled λ values in (-0.5, 0.5]
training_λs = 0.5f0 .- rand(MersenneTwister(0), Float32, num_training_samples)
function dahlquist_prob_func(prob, i, repeat)
    return ODE.remake(prob, p = training_λs[i])
end
ensemble_prob = ODE.EnsembleProblem(prob_dahlquist_template, prob_func = dahlquist_prob_func)
solution_ensemble = ODE.solve(
    ensemble_prob,
    ODE.Tsit5(),
    sensealg = sensealg,
    abstol = 1e-6,
    reltol = 1e-6,
    saveat = saveat,
    trajectories = num_training_samples, # Determines how many problems to solve in ensemble
)
ensemble_u_true = Array(solution_ensemble)
# Set up the neural ODE to learn the Dahlquist equation
f_NN = Lux.Chain(
    Lux.Dense(3, 10, tanh), # 3 inputs: u, t, λ
    Lux.Dense(10, 1)
)
f_NN_weights_init, _st = Lux.setup(MersenneTwister(1), f_NN)
function dudt_dahlquist_NN(u, p, t)
    # The network takes [u, t, λ] and returns a scalar du/dt
    NN_input = vcat(u, t, p.dahlquist_params)
    NN_val, _ = Lux.apply(f_NN, NN_input, p.NN_weights, _st)
    val = NN_val[1]
    return val
end
dummy_params = ComponentArray(dahlquist_params=dummy_λ, NN_weights=f_NN_weights_init)
prob_dahlquist_NN_template = ODE.ODEProblem(dudt_dahlquist_NN, u0, tspan, dummy_params)
function loss_function(trainable_params)
    NN_weights = trainable_params.NN_weights
    # Per-trajectory parameters: the fixed λ feature gathered with the shared trainable weights
    function dahlquist_NN_prob_func(prob, i, repeat)
        p = ComponentArray(dahlquist_params = training_λs[i], NN_weights = NN_weights)
        return ODE.remake(prob, p = p)
    end
    ensemble_prob_NN = ODE.EnsembleProblem(prob_dahlquist_NN_template, prob_func = dahlquist_NN_prob_func)
    ensemble_sol_pred = ODE.solve(
        ensemble_prob_NN,
        ODE.Tsit5(),
        ODE.EnsembleSerial(),
        abstol = 1e-6,
        reltol = 1e-6,
        saveat = saveat,
        sensealg = sensealg,
        trajectories = num_training_samples, # Determines how many problems to solve in ensemble
    )
    ensemble_u_pred = Array(ensemble_sol_pred) # Ensemble solutions are concatenated along 3rd dimension
    loss_val = sum(abs2, ensemble_u_pred .- ensemble_u_true) / length(ensemble_u_true)
    return loss_val
end
test_grad = Zygote.gradient(loss_function, dummy_params)
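For reference, a finite-difference check of the same loss (a sketch only, assuming FiniteDiff.jl is available; I have not verified it runs cleanly through the ensemble solve) would look like this, just to confirm that the loss actually depends on the weights even though Zygote returns nothing:

import FiniteDiff

# Sanity check, not a fix: finite differences should give a nonzero gradient
# with respect to NN_weights if the loss depends on them.
fd_grad = FiniteDiff.finite_difference_gradient(loss_function, dummy_params)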