Hi all,
I am implementing regularization penalties inside Universal Differential Equations (this also applies to physics-informed neural networks), where I need to differentiate a loss function that itself contains a gradient. For a UDE, consider the case where the outputs of the neural network are differentiated with respect to the input layer and added to the loss function (this represents a physical derivative that I want to constrain for some application). A second (computational) derivative with respect to the weights of the NN then needs to be computed on top of this in order to optimize the weights.
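Schematically, the structure I need looks like the following toy sketch (a made-up scalar function standing in for the network; all names here are hypothetical):
using Zygote

# f plays the role of the neural network: x is the physical input, θ the weights
f(x, θ) = θ[1] * sin(θ[2] * x)

# "Physical" derivative ∂f/∂x, which enters the loss as a regularization term
dfdx(x, θ) = Zygote.gradient(x -> f(x, θ), x)[1]

# Loss containing that derivative
toy_loss(θ) = sum(abs2, dfdx(x, θ) for x in 0.0:0.1:1.0)

# "Computational" derivative ∂loss/∂θ, needed to optimize the weights;
# this is the nested (reverse-over-reverse) differentiation that causes trouble below
Zygote.gradient(toy_loss, [1.0, 2.0])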
Consider the simple example based on the Lotka-Volterra equations from the SciML tutorial, with the following (slightly modified) loss function (the full MWE is provided at the end of this post):
function loss(θ)
    # Empirical error
    X̂ = predict(θ)
    l_emp = mean(abs2, Xₙ .- X̂)
    # Regularization based on first derivatives
    steps_reg = collect(0.0:0.1:10.0)
    dUdx = map(x -> Zygote.jacobian(first ∘ U, [x, 1.0], θ, st)[1], steps_reg)
    norm_dUdx = norm.(dUdx).^2.0
    l_reg = sum(norm_dUdx)
    return l_emp + l_reg
end
The loss function here is the combination of the empirical error and a regularization term involving derivatives of the neural network. It can be evaluated without any problem, so the following works:
loss(p)
# Return: 116666.31176556791
However, when applying the solve step to optimize the weights of the neural network, I obtain ERROR: LoadError: Mutating arrays is not supported, even though I am fairly confident the problem is not coming from array mutation in my own code but from the nested Zygote.jacobian call. For example, if I change the loss function so that it does not compute this extra gradient, everything works fine:
function loss(θ)
    # Empirical error
    X̂ = predict(θ)
    l_emp = mean(abs2, Xₙ .- X̂)
    # Regularization
    steps_reg = collect(0.0:0.1:10.0)
    # dUdx = map(x -> Zygote.jacobian(first ∘ U, [x, 1.0], θ, st)[1], steps_reg) # previous code
    dUdx = map(x -> U([x, 1.0], θ, st)[1], steps_reg) # new code
    norm_dUdx = norm.(dUdx).^2.0
    l_reg = sum(norm_dUdx)
    return l_emp + l_reg
end
A second implementation uses finite differences to compute the derivative involved in the regularization:
function loss(θ)
    # Empirical error
    X̂ = predict(θ)
    l_emp = mean(abs2, Xₙ .- X̂)
    # Regularization computed with finite differences
    steps_reg = collect(0.0:0.1:10.0)
    Ux = map(x -> (first ∘ U)([x, 1.0], θ, st), steps_reg)
    dUdx = diff(Ux) ./ diff(steps_reg)
    l_reg = sum(norm.(dUdx).^2.0)
    return l_emp + l_reg
end
This works, of course, although I noticed that it is slower than expected (especially compared to the adjoint differentiation used in the numerical solver). However, it is far from an ideal solution, since I would like to compute the derivatives inside my loss function with some AD tool.
The big picture problem. I think the really interesting thing here is how different differentiation methods interact with each other. This is also relevant for physics-informed neural networks, where most implementations I have seen rely purely on reverse AD to compute higher-order derivatives (both with respect to the input layer and to the weights). However, this is quite inefficient and scales exponentially with the order of the derivative. On the other hand, I noticed that NeuralPDE.jl instead uses the method of lines to discretize derivatives in space (code here), which is essentially the same as what I am doing with finite differences in this example.
The ideal solution. I would like to have something that allows me to compute (physical) derivatives using some form of AD (for derivatives with respect to the input layer of a NN, probably a forward-mode method) and then perform gradient-based optimization on top of it using reverse AD. This would enable very interesting SciML applications with physical regularization (I have a few examples in my own research), and it is becoming increasingly important for deep learning methods used to solve differential equations.
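For instance, something along these lines is roughly what I have in mind. This is only a sketch: it uses ForwardDiff (not part of the MWE below) for the inner derivative, loss_mixed is a hypothetical name, and I do not know whether Zygote can actually differentiate through the ForwardDiff call with respect to θ here:
using ForwardDiff   # not in the MWE below; only needed for this sketch

function loss_mixed(θ)
    # Empirical error, as before
    X̂ = predict(θ)
    l_emp = mean(abs2, Xₙ .- X̂)
    # Physical derivative computed with forward-mode AD (dual numbers)
    # with respect to the input layer of the network
    steps_reg = 0.0:0.1:10.0
    dUdx = map(x -> ForwardDiff.jacobian(u -> U(u, θ, st)[1], [x, 1.0]), steps_reg)
    l_reg = sum(norm.(dUdx) .^ 2)
    return l_emp + l_reg
end

# The hope is that reverse-mode AD (Zygote) can then be applied on top of this
# to obtain the gradient with respect to θ for the optimizer.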
I would appreciate any insights on this, both on
- how to compute the gradients inside the loss function (the regularization term) using some form of AD; and
- the big-picture question of how different AD tools can be integrated in real SciML problems (e.g., I would like to apply reverse AD on top of something that computes a forward derivative using dual numbers).
Thanks!
Complete minimal working example:
using Pkg; Pkg.activate(".")
# SciML Tools
using OrdinaryDiffEq, SciMLSensitivity
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using LinearAlgebra, Statistics
using ComponentArrays, Lux, Zygote #, StableRNGs
# Set a random seed for reproducible behaviour
using Random
rng = Random.default_rng()
Random.seed!(rng, 666)
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
end
# Define the experimental parameter
tspan = (0.0, 5.0)
u0 = 5.0f0 * rand(rng, 2)
p_ = [1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.25)
# Add noise in terms of the mean
X = Array(solution)
t = solution.t
x̄ = mean(X, dims = 2)
noise_magnitude = 5e-3
Xₙ = X .+ (noise_magnitude * x̄) .* randn(rng, eltype(X), size(X))
rbf(x) = exp.(-(x .^ 2))
# Multilayer FeedForward
const U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)
const _st = st
# Define the hybrid model
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, _st)[1] # Network prediction
    du[1] = p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
end
# Closure with the known parameter
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
# Define the problem
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)
function predict(θ, X = Xₙ[:, 1], T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
end
"""
Works but it is clearly a little bit slow!
"""
# function loss(θ)
# # Empirical error
# X̂ = predict(θ)
# l_emp = mean(abs2, Xₙ .- X̂)
# # regularization
# steps_reg = collect(0.0:0.1:10.0)
# Ux = map(x -> (first ∘ U)([x, 1.0], θ, st), steps_reg)
# dUdx = diff(Ux) ./ diff(steps_reg)
# l_reg = sum(norm.(dUdx).^2.0)
# return l_emp + l_reg
# end
"""
Need to solve problem with mutating arrays!
"""
function loss(θ)
    # Empirical error
    X̂ = predict(θ)
    l_emp = mean(abs2, Xₙ .- X̂)
    # Regularization
    steps_reg = collect(0.0:0.1:10.0)
    dUdx = map(x -> Zygote.jacobian(first ∘ U, [x, 1.0], θ, st)[1], steps_reg)
    # dUdx = map(x -> U([x, 1.0], θ, st)[1], steps_reg)
    norm_dUdx = norm.(dUdx).^2.0
    l_reg = sum(norm_dUdx)
    return l_emp + l_reg
end
losses = Float64[]
callback = function (p, l)
    push!(losses, l)
    if length(losses) % 10 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end
adtype = Optimization.AutoZygote()
# adtype = Optimization.AutoReverseDiff()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))
res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 500)
println("Training loss after $(length(losses)) iterations: $(losses[end])")