Nesting different AD methods: how to compute AD derivatives inside the loss function of a neural differential equation?

Hi all,

I am implementing regularization penalties for Universal Differential Equations (UDEs; the same applies to physics-informed neural networks) that require differentiating a loss function whose evaluation itself involves a gradient. For a UDE, consider the case where the outputs of the neural network are differentiated with respect to its inputs and the result is added to the loss function (this represents a physical derivative that I want to constrain in some applications). On top of this, a second (computational) derivative with respect to the weights of the network is needed to optimize those weights.
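Written generically, the regularization term and its gradient look as follows. The notation ($U_\theta$ for the network, $J_\theta$ for its input Jacobian, $u_k$ for the regularization points) is introduced here for illustration and is not taken from the code below. The gradient shows why mixed second derivatives of the network, with respect to both the weights and the inputs, are unavoidable:

```latex
R(\theta) = \sum_{k} \bigl\lVert J_\theta(u_k) \bigr\rVert_F^2,
\qquad
J_\theta(u) = \frac{\partial U_\theta}{\partial u}(u),
\qquad
\frac{\partial R}{\partial \theta}
  = 2 \sum_{k} \sum_{i,j} \bigl[J_\theta(u_k)\bigr]_{ij}\,
    \frac{\partial^2 U_{\theta,i}}{\partial \theta\,\partial u_j}(u_k)
```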

Consider the simple example based on the Lotka-Volterra equations from the SciML tutorial, with the following (slightly modified) loss function (a complete MWE is provided at the end of this post):

function loss(θ)
    # Empirical error
    X̂ = predict(θ)
    l_emp = mean(abs2, Xₙ .- X̂)
    # Regularization based on first derivatives 
    steps_reg = collect(0.0:0.1:10.0)
    dUdx = map(x -> Zygote.jacobian(first ∘ U, [x, 1.0], θ, st)[1], steps_reg)
    norm_dUdx = norm.(dUdx).^2.0
    l_reg = sum(norm_dUdx) 
    return l_emp + l_reg
end

The loss function combines the empirical error with a regularization term involving derivatives of the neural network. Evaluating it works without problems:

loss(p)
# Return: 116666.31176556791
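For comparison, the input Jacobian on its own can also be computed in forward mode. The sketch below is a minimal example under assumptions: `toy_U`, `input_jacobian`, `reg_term` and the parameter values are hypothetical, and a hand-written one-hidden-layer network stands in for the Lux chain `U`. It only evaluates the regularizer; it does not differentiate it with respect to the weights.

```julia
using ForwardDiff, LinearAlgebra

# Hypothetical stand-in for the Lux network U: 2 inputs -> 3 hidden (rbf) -> 2 outputs.
rbf_scalar(z) = exp(-z^2)
toy_U(u, θ) = θ.W2 * rbf_scalar.(θ.W1 * u .+ θ.b1) .+ θ.b2

# Jacobian of the network output with respect to its input u = [x, 1.0] only.
# θ is captured by the closure, so no derivative with respect to the weights is taken.
input_jacobian(x, θ) = ForwardDiff.jacobian(u -> toy_U(u, θ), [x, 1.0])

# Same structure as l_reg in the post: squared (Frobenius) norm of the Jacobian, summed on a grid.
reg_term(θ) = sum(x -> norm(input_jacobian(x, θ))^2, 0.0:0.1:10.0)

θ0 = (W1 = [0.5 -0.3; 0.2 0.8; -0.6 0.1], b1 = [0.1, -0.2, 0.3],
      W2 = [1.0 -0.5 0.2; 0.3 0.4 -0.7], b2 = [0.0, 0.1])

J = input_jacobian(2.0, θ0)   # 2×2: J[i, j] = ∂U_i/∂u_j at u = [2.0, 1.0]
```

The first column of `J` is the ∂U/∂x term. Note that the regularizer in the post takes `norm` of the full 2×2 Jacobian, so it also penalizes the derivative with respect to the second input (held fixed at 1.0).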

However, when I call solve to optimize the weights of the neural network, I get ERROR: LoadError: Mutating arrays is not supported, even though I am fairly confident that the problem does not come from array mutation in my own code. For example, if I change the loss function so that it does not compute this extra gradient, everything works fine:

function loss(θ)
    # Empirical error
    X̂ = predict(θ)
    l_emp = mean(abs2, Xₙ .- X̂)
    # regularization 
    steps_reg = collect(0.0:0.1:10.0)
    # dUdx = map(x -> Zygote.jacobian(first ∘ U, [x, 1.0], θ, st)[1], steps_reg) # previous code
    dUdx = map(x -> U([x, 1.0], θ, st)[1], steps_reg) # new code
    norm_dUdx = norm.(dUdx).^2.0
    l_reg = sum(norm_dUdx) 
    return l_emp + l_reg
end

A second implementation uses finite differences to approximate the derivative in the regularization term:

function loss(θ)
    # Empirical error
    X̂ = predict(θ)
    l_emp = mean(abs2, Xₙ .- X̂)
    # Regularization computed with finite differences  
    steps_reg = collect(0.0:0.1:10.0)
    Ux = map(x -> (first ∘ U)([x, 1.0], θ, st), steps_reg)
    dUdx = diff(Ux) ./ diff(steps_reg)
    l_reg = sum(norm.(dUdx).^2.0)    
    return l_emp + l_reg
end

This works, although it is slower than I expected (especially compared to the adjoint differentiation performed by the ODE solver). In any case, it is far from an ideal solution, since I would like to compute the derivatives inside my loss function with an AD tool.
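Because the finite-difference regularizer penalizes an approximation of the derivative rather than the derivative itself, its accuracy may matter too. A small stdlib-only check, using `sin` (whose derivative `cos` is known) as a hypothetical stand-in for one network output on the same grid, compares the forward-difference quotient `diff(Ux) ./ diff(steps_reg)` with a central difference:

```julia
# Stand-in for one network output; its exact derivative is cos(x).
f(x) = sin(x)
xs = collect(0.0:0.1:10.0)
fx = f.(xs)

# Forward differences, as in the post: length(xs) - 1 values.
fwd = diff(fx) ./ diff(xs)
mid = (xs[1:end-1] .+ xs[2:end]) ./ 2

# Central differences at the interior grid points (second-order accurate).
cen = (fx[3:end] .- fx[1:end-2]) ./ (xs[3:end] .- xs[1:end-2])

err_fwd_at_left = maximum(abs.(fwd .- cos.(xs[1:end-1])))  # read as derivative at the left point
err_fwd_at_mid  = maximum(abs.(fwd .- cos.(mid)))          # read as derivative at the midpoint
err_central     = maximum(abs.(cen .- cos.(xs[2:end-1])))
```

With a step of 0.1, the forward quotient is off by roughly h/2 · |f''| when read as the derivative at the left grid point, but is second-order accurate when read at the midpoints; this is a property of the scheme, not of the network.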

The big picture. I think the really interesting question here is how different differentiation methods interact with each other. This is also relevant for physics-informed neural networks, where most implementations I have seen rely purely on reverse AD to compute higher-order derivatives (with respect to both the input layer and the weights). However, this is quite inefficient and scales exponentially with the order of the derivative. On the other hand, I noticed that NeuralPDE.jl instead uses the method of lines to discretize derivatives in space (code here), which is essentially what I am doing in this example.
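For PINN-style residuals that need higher-order input derivatives, nesting forward mode on a scalar input is one option. A minimal sketch, assuming a hypothetical scalar model `u(x, θ)` in place of a real network; it checks two nested `ForwardDiff.derivative` calls against the closed-form derivatives. This only illustrates forward over forward in the input and makes no claim about how NeuralPDE.jl computes derivatives.

```julia
using ForwardDiff

# Hypothetical scalar "network": u(x; θ) = θ3 * tanh(θ1 * x + θ2).
u(x, θ) = θ[3] * tanh(θ[1] * x + θ[2])

# First and second derivatives with respect to the input x, by nesting ForwardDiff.derivative.
# ForwardDiff tags each level of dual numbers, so the two perturbations are kept apart.
du(x, θ)  = ForwardDiff.derivative(z -> u(z, θ), x)
d2u(x, θ) = ForwardDiff.derivative(z -> du(z, θ), x)

θ = [1.5, -0.3, 2.0]
x = 0.7

# Closed form for comparison, with s = θ1*x + θ2:
# u' = θ3*θ1*sech(s)^2,  u'' = -2*θ3*θ1^2*sech(s)^2*tanh(s)
s = θ[1] * x + θ[2]
du_exact  = θ[3] * θ[1] * sech(s)^2
d2u_exact = -2 * θ[3] * θ[1]^2 * sech(s)^2 * tanh(s)
```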

The ideal solution. I would like something that computes the (physical) derivatives with some form of AD (for derivatives with respect to the input layer of a NN, probably forward mode) and then lets me run gradient-based optimization on top of it with reverse AD. This would enable very interesting SciML applications with physical regularization (I have a few examples in my own research), and it is becoming increasingly important for deep learning methods used to solve differential equations.
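Reverse over forward can be sketched without any dual numbers by writing the forward-mode (tangent) propagation through the network by hand and letting Zygote differentiate it. Assumptions: a hypothetical one-hidden-layer tanh network with its parameters packed into a plain vector `θ`; `unpack`, `forward_and_jvp` and `penalty` are illustrative names, not SciML or Lux API. The inner derivative is exact (not a finite difference), and since no array is mutated, Zygote can take the outer gradient.

```julia
using Zygote

# Hypothetical network: 2 inputs, 4 tanh hidden units, 2 outputs; 22 parameters in θ.
function unpack(θ)
    W1 = reshape(θ[1:8], 4, 2); b1 = θ[9:12]
    W2 = reshape(θ[13:20], 2, 4); b2 = θ[21:22]
    return W1, b1, W2, b2
end

# Hand-written forward mode: for inputs (one column per point) and tangent directions,
# returns the network output and the JVP ∂U/∂u · tangent. Nothing is mutated.
function forward_and_jvp(θ, inputs, tangents)
    W1, b1, W2, b2 = unpack(θ)
    h  = tanh.(W1 * inputs .+ b1)
    dh = (1 .- h .^ 2) .* (W1 * tangents)   # chain rule through tanh: tanh'(z) = 1 - tanh(z)^2
    return W2 * h .+ b2, W2 * dh
end

# ∂U/∂x at the points u = [x, 1.0]: the tangent direction is [1.0, 0.0] at every point.
function penalty(θ, xs)
    n = length(xs)
    inputs   = vcat(xs', ones(1, n))
    tangents = vcat(ones(1, n), zeros(1, n))
    _, dUdx = forward_and_jvp(θ, inputs, tangents)
    return sum(abs2, dUdx)
end

xs = collect(0.0:0.1:10.0)
θ0 = collect(range(-1.0, 1.0; length = 22))

# Reverse mode (Zygote) over the hand-written forward-mode derivative.
g = Zygote.gradient(θ -> penalty(θ, xs), θ0)[1]
```

The trade-off is that the tangent rule has to be written for each layer type, so this hand-rolled JVP does not carry over automatically to arbitrary Lux layers.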

I would appreciate any insights on two points:

  1. How can I compute the derivatives in the regularization term of the loss function with some form of AD?
  2. Big-picture perspectives on combining different AD tools in real SciML problems (e.g., I would like to apply reverse AD on top of something that computes a forward derivative using dual numbers).
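Regarding the first point, when the number of parameters is small one option is to stay in forward mode for both derivatives. This swaps in forward over forward rather than the reverse over forward asked about: `ForwardDiff.jacobian` computes the input Jacobian and `ForwardDiff.gradient` the outer derivative with respect to θ, with ForwardDiff's tags keeping the two levels apart. `toy_U`, `reg` and the parameter vector are hypothetical. The outer pass costs grow with the number of parameters, so this is unlikely to scale to large networks.

```julia
using ForwardDiff

# Hypothetical network: 2 inputs, 4 rbf hidden units, 2 outputs; 22 parameters in θ.
function toy_U(u, θ)
    W1 = reshape(θ[1:8], 4, 2); b1 = θ[9:12]
    W2 = reshape(θ[13:20], 2, 4); b2 = θ[21:22]
    z = W1 * u .+ b1
    return W2 * exp.(-z .^ 2) .+ b2   # rbf activation exp(-z^2), as in the post
end

# Squared Frobenius norm of ∂U/∂u at u = [x, 1.0], summed over the grid (as l_reg in the post).
reg(θ, xs) = sum(x -> sum(abs2, ForwardDiff.jacobian(u -> toy_U(u, θ), [x, 1.0])), xs)

xs = 0.0:0.1:10.0
θ0 = collect(range(-1.0, 1.0; length = 22))

# Outer derivative with respect to θ, also in forward mode (forward over forward).
g = ForwardDiff.gradient(θ -> reg(θ, xs), θ0)
```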

Thanks!

Complete minimal working example:

using Pkg; Pkg.activate(".")

# SciML Tools
using OrdinaryDiffEq, SciMLSensitivity
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using LinearAlgebra, Statistics
using ComponentArrays, Lux, Zygote #, StableRNGs

# Set a random seed for reproducible behaviour
using Random
rng = Random.default_rng()
Random.seed!(rng, 666)

function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
end

# Define the experimental parameter
tspan = (0.0, 5.0)
u0 = 5.0f0 * rand(rng, 2)
p_ = [1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.25)

# Add noise in terms of the mean
X = Array(solution)
t = solution.t

x̄ = mean(X, dims = 2)
noise_magnitude = 5e-3
Xₙ = X .+ (noise_magnitude * x̄) .* randn(rng, eltype(X), size(X))

rbf(x) = exp.(-(x .^ 2))
# Multilayer FeedForward
const U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
              Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)
const _st = st

# Define the hybrid model
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, _st)[1] # Network prediction
    du[1] = p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
end

# Closure with the known parameter
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
# Define the problem
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)

function predict(θ, X = Xₙ[:, 1], T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
end

"""
Works but it is clearly a little bit slow! 
"""
# function loss(θ)
#     # Empirical error
#     X̂ = predict(θ)
#     l_emp = mean(abs2, Xₙ .- X̂)
#     # regularization 
#     steps_reg = collect(0.0:0.1:10.0)
#     Ux = map(x -> (first ∘ U)([x, 1.0], θ, st), steps_reg)
#     dUdx = diff(Ux) ./ diff(steps_reg)
#     l_reg = sum(norm.(dUdx).^2.0)    
#     return l_emp + l_reg
# end

"""
Need to solve problem with mutating arrays!
"""
function loss(θ)
    # Empirical error
    X̂ = predict(θ)
    l_emp = mean(abs2, Xₙ .- X̂)
    # regularization 
    steps_reg = collect(0.0:0.1:10.0)
    dUdx = map(x -> Zygote.jacobian(first ∘ U, [x, 1.0], θ, st)[1], steps_reg)
    # dUdx = map(x -> U([x, 1.0], θ, st)[1], steps_reg)
    norm_dUdx = norm.(dUdx).^2.0
    l_reg = sum(norm_dUdx) 
    return l_emp + l_reg
end

losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 10 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

adtype = Optimization.AutoZygote()
# adtype = Optimization.AutoReverseDiff()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 500)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

Using forward differentiation for the inner derivative does not throw an error, but Zygote drops the gradient contribution that flows through the derivative term:

dUdx = map(x -> ForwardDiff.jacobian(z -> U([z[1], 1.0], θ, st)[1], [x]), steps_reg)

and produces the warning

Warning: `ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).

@ChrisRackauckas I noticed this is essentially the same problem reported in the post "Gradient of Gradient in Zygote", except that here I am interested in reverse-over-forward differentiation. There is also a similar thread, "Issue with Zygote over ForwardDiff-derivative". However, it is not clear to me what the recommended solution is for these cases, if there is one yet. Those posts are a bit old, so some of their contents may be outdated. Do I need to define a new rrule() for this problem to make it work?
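On the rrule() question, one possible pattern, sketched here under assumptions, is to wrap the derivative-based penalty in its own function and give it a custom `ChainRulesCore.rrule`, so that Zygote never traces into ForwardDiff. In this sketch the pullback obtains ∂penalty/∂θ by forward over forward, which is only reasonable for a small number of parameters; `model`, `penalty` and `total_loss` are hypothetical names, and a scalar model stands in for the network.

```julia
using ChainRulesCore, ForwardDiff, Zygote

# Hypothetical scalar model and its derivative with respect to the input (forward mode).
model(x, θ) = θ[3] * tanh(θ[1] * x + θ[2])
dmodel_dx(x, θ) = ForwardDiff.derivative(z -> model(z, θ), x)

# Regularization term on a grid of input points xs.
penalty(θ, xs) = sum(x -> dmodel_dx(x, θ)^2, xs)

# Custom reverse rule: Zygote uses this rrule instead of tracing into ForwardDiff.
# The gradient with respect to θ comes from ForwardDiff.gradient (forward over forward).
function ChainRulesCore.rrule(::typeof(penalty), θ::AbstractVector, xs)
    y = penalty(θ, xs)
    gθ = ForwardDiff.gradient(t -> penalty(t, xs), θ)
    penalty_pullback(ȳ) = (NoTangent(), unthunk(ȳ) .* gθ, NoTangent())
    return y, penalty_pullback
end

# Zygote can now differentiate a loss that mixes ordinary terms with the penalty.
xs = collect(0.0:0.1:10.0)
total_loss(θ) = sum(abs2, θ .- 1) + penalty(θ, xs)
θ0 = [0.8, -2.0, 1.5]
g = Zygote.gradient(total_loss, θ0)[1]
```

For a large network the outer ForwardDiff pass inside the pullback would become the bottleneck, so this pattern mainly shows where a custom rule can sit, not an efficient general solution.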