Nesting different AD methods: how can I add AD calculations inside my loss function when using neural differential equations?

Hi all,

I am implementing regularization penalties inside Universal Differential Equations (also applicable to Physics-Informed Neural Networks), where I need to differentiate a (loss) function that itself includes a gradient in its calculation. For a UDE, consider the case where the outputs of the neural network are differentiated with respect to the input layer and added to the loss function (this represents a physical derivative that I am interested in constraining for some application). A second (computational) derivative with respect to the weights of the NN then needs to be computed on top of this in order to optimize the weights of the NN.
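
Schematically, the loss I want to optimize has the form

$$
\mathcal{L}(\theta) = \underbrace{\frac{1}{N}\sum_{i} \left\| X_i - \hat{X}_i(\theta) \right\|^2}_{\text{empirical error}} + \underbrace{\sum_{x \in \text{grid}} \left\| \frac{\partial U_\theta(x)}{\partial x} \right\|^2}_{\text{regularization}}
$$

so the optimizer needs $\nabla_\theta \mathcal{L}(\theta)$, i.e. a (computational) derivative with respect to the weights of a quantity that already contains a (physical) derivative with respect to the inputs.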

Consider the simple example based on the Lotka-Volterra equations from the SciML tutorial with the following (slightly modified) loss function (MWE with all code is provided at the end of this post):

function loss(θ)
    # Empirical error
    X̂ = predict(θ)
    l_emp = mean(abs2, Xₙ .- X̂)
    # Regularization based on first derivatives 
    steps_reg = collect(0.0:0.1:10.0)
    dUdx = map(x -> Zygote.jacobian(first ∘ U, [x, 1.0], θ, st)[1], steps_reg)
    norm_dUdx = norm.(dUdx).^2.0
    l_reg = sum(norm_dUdx) 
    return l_emp + l_reg
end

The loss function here is the combination of the empirical error and a regularization term involving derivatives of the neural network. This can be easily evaluated, so the following works

loss(p)
# Return: 116666.31176556791

However, when applying the solve step to optimize the weights of the neural network I obtain ERROR: LoadError: Mutating arrays is not supported, but I doubt the problem is actually coming from mutating arrays in my own code; it seems to be triggered by the nested Zygote.jacobian call. For example, if I change the loss function so that it does not compute this extra gradient, everything works fine:

function loss(θ)
    # Empirical error
    X̂ = predict(θ)
    l_emp = mean(abs2, Xₙ .- X̂)
    # regularization 
    steps_reg = collect(0.0:0.1:10.0)
    # dUdx = map(x -> Zygote.jacobian(first ∘ U, [x, 1.0], θ, st)[1], steps_reg) # previous code
    dUdx = map(x -> U([x, 1.0], θ, st)[1], steps_reg) # new code
    norm_dUdx = norm.(dUdx).^2.0
    l_reg = sum(norm_dUdx) 
    return l_emp + l_reg
end

A second implementation of this uses finite differences to compute the derivative involved in the regularization.

function loss(θ)
    # Empirical error
    X̂ = predict(θ)
    l_emp = mean(abs2, Xₙ .- X̂)
    # Regularization computed with finite differences  
    steps_reg = collect(0.0:0.1:10.0)
    Ux = map(x -> (first ∘ U)([x, 1.0], θ, st), steps_reg)
    dUdx = diff(Ux) ./ diff(steps_reg)
    l_reg = sum(norm.(dUdx).^2.0)    
    return l_emp + l_reg
end

This works, of course, although I noticed that it is slower than expected (especially compared to the adjoint differentiation involved in the numerical solver). It is also far from an ideal solution to the problem, since I would like to be able to compute derivatives inside my loss function with some AD tool.

The big picture problem. I think the really interesting thing here is how different differentiation methods interact with each other. This is also relevant for physics-informed neural networks, where most implementations I have seen rely purely on reverse AD to compute higher-order derivatives (both with respect to the input layer and the weights). However, this is quite inefficient and scales exponentially with the order of the derivative. On the other hand, I noticed that NeuralPDE.jl instead uses the method of lines to discretize derivatives in space (code here), which is essentially what I am doing in this example.

The ideal solution. I would like something that allows me to compute (physical) derivatives using some form of AD (probably a forward method, for derivatives with respect to the input layer of a NN) and then perform gradient-based optimization on top of it using reverse AD. This would enable very interesting SciML applications with physical regularization (I have a few examples in my own research), and it is becoming increasingly important for deep learning methods used to solve differential equations.

I would appreciate it if someone had some insights on this, both on

  1. how to compute the gradients in the loss function (for the regularization term) using some form of AD; and
  2. the big-picture integration between different AD tools in real SciML problems (e.g., I would like to apply reverse AD on top of something that computes a forward derivative using dual numbers).

Thanks!

Complete minimal working example:

using Pkg; Pkg.activate(".")

# SciML Tools
using OrdinaryDiffEq, SciMLSensitivity
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using LinearAlgebra, Statistics
using ComponentArrays, Lux, Zygote #, StableRNGs

# Set a random seed for reproducible behaviour
using Random
rng = Random.default_rng()
Random.seed!(rng, 666)

function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
end

# Define the experimental parameter
tspan = (0.0, 5.0)
u0 = 5.0f0 * rand(rng, 2)
p_ = [1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.25)

# Add noise in terms of the mean
X = Array(solution)
t = solution.t

x̄ = mean(X, dims = 2)
noise_magnitude = 5e-3
Xₙ = X .+ (noise_magnitude * x̄) .* randn(rng, eltype(X), size(X))

rbf(x) = exp.(-(x .^ 2))
# Multilayer FeedForward
const U = Lux.Chain(Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
              Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)
const _st = st

# Define the hybrid model
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, _st)[1] # Network prediction
    du[1] = p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
end

# Closure with the known parameter
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
# Define the problem
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)

function predict(θ, X = Xₙ[:, 1], T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
end

"""
Works but it is clearly a little bit slow! 
"""
# function loss(θ)
#     # Empirical error
#     X̂ = predict(θ)
#     l_emp = mean(abs2, Xₙ .- X̂)
#     # regularization 
#     steps_reg = collect(0.0:0.1:10.0)
#     Ux = map(x -> (first ∘ U)([x, 1.0], θ, st), steps_reg)
#     dUdx = diff(Ux) ./ diff(steps_reg)
#     l_reg = sum(norm.(dUdx).^2.0)    
#     return l_emp + l_reg
# end

"""
Need to solve problem with mutating arrays!
"""
function loss(θ)
    # Empirical error
    X̂ = predict(θ)
    l_emp = mean(abs2, Xₙ .- X̂)
    # regularization 
    steps_reg = collect(0.0:0.1:10.0)
    dUdx = map(x -> Zygote.jacobian(first ∘ U, [x, 1.0], θ, st)[1], steps_reg)
    # dUdx = map(x -> U([x, 1.0], θ, st)[1], steps_reg)
    norm_dUdx = norm.(dUdx).^2.0
    l_reg = sum(norm_dUdx) 
    return l_emp + l_reg
end

losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 10 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

adtype = Optimization.AutoZygote()
# adtype = Optimization.AutoReverseDiff()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 500)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

Forward differentiation for the inner derivative does not break, but it silently ignores the gradients with respect to the derivative term. Replacing the Zygote.jacobian call with

dUdx = map(x -> ForwardDiff.jacobian(x -> U([x[1], 1.0], θ, st)[1], [x]), steps_reg)

gives

Warning: `ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
│ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).

@ChrisRackauckas I noticed this is pretty much the same problem reported in the post Gradient of Gradient in Zygote, but here I am interested in reverse-over-forward differentiation. There is also a similar thread, Issue with Zygote over ForwardDiff-derivative. However, it is not clear to me what the recommended solution for these cases is, if there is one yet. The posts are a little old, so some of their contents may be outdated. Do I need to define a new rrule() for this problem in order to make this work?

Have you made any progress in your search? I’m encountering a similar problem, using ForwardDiff inside the loss and getting the same warnings.

You can use BatchedRoutines.jl for this: replace ForwardDiff.jacobian with batched_jacobian(AutoForwardDiff(), ....). See BatchedRoutines.jl/test/autodiff_tests.jl at b1f3f8fde4e07df1ea45fc753a50b2de7ac20c9e · LuxDL/BatchedRoutines.jl · GitHub for an example. Documentation for this is non-existent and the package is currently unregistered (I am planning to register it around the start of the Summer semester), but we are currently using it in some internal SciML projects for exactly the use case you want.

I think this link might point to another package with the same name?

Oops, yeah: GitHub - LuxDL/BatchedRoutines.jl: Routines for batching regular code and make them fast!. It seems that name is already taken, so we might have to think of a different one.

Thank you @avikpal for your response!

I was able to make BatchedRoutines.jl work using ForwardDiff for both the inner and the outer derivative:

using BatchedRoutines

function loss(θ)
    # Empirical error
    X̂ = predict(θ)
    l_emp = mean(abs2, Xₙ .- X̂)
    # regularization 
    steps_reg = collect(0.0:0.1:10.0)
    dUdx = map(x -> batched_jacobian(AutoForwardDiff(), z -> U([z[1], 1.0], θ, st)[1], [x])[1], steps_reg) # evaluate at each grid point x; works with adtype = AutoForwardDiff()
    norm_dUdx = norm.(dUdx).^2.0
    l_reg = sum(norm_dUdx) 
    return l_emp + l_reg
end
...
adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

However, I noticed that the AutoReverseDiff() and AutoZygote() backends are not yet implemented for this. Am I right? Is there any way of doing some reverse AD here instead of forward + forward?

In any case, this is great progress! At least now there is a way to compute these double-gradients.

For Jacobians, only ForwardDiff and FiniteDiff are implemented. If you use ForwardDiff over batched_jacobian, it does forward over forward, which is extremely efficient for smaller inputs.

For larger inputs (neural network parameters), you would want to do reverse over ForwardDiff. If BatchedRoutines.jl sees a reverse-mode call on the outside (like Optimization.AutoZygote), it will automatically switch the internal ForwardDiff call so that the combination becomes forward over reverse mode, without any intervention.

But if you want reverse over reverse, that tends not to give much benefit and is not well supported by most Julia AD tools (Enzyme might work here, but I am not too sure).

(I will eventually add reverse-mode Jacobians for the m → n case where n << m; it is just a low priority.)


You can test which backends work well with one another by using DifferentiationInterface.jl with the SecondOrder struct.
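
For example, here is a minimal sketch (my own, assuming DifferentiationInterface re-exports the ADTypes backend types and that the hessian(f, backend, x) form is available) that checks a forward-over-reverse combination on a toy function before wiring it into the UDE loss:

using DifferentiationInterface   # provides SecondOrder and hessian; re-exports AutoForwardDiff, AutoZygote
import ForwardDiff, Zygote

f(x) = sum(abs2, x)     # simple scalar test function
x = rand(5)

# SecondOrder(outer, inner): here ForwardDiff is pushed over a Zygote pullback
backend = SecondOrder(AutoForwardDiff(), AutoZygote())
H = hessian(f, backend, x)     # expect ≈ 2I for this f

If a given pair of backends errors on a toy function like this, it will most likely also fail inside the full loss.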

Following my conversation with @avikpal and based on the nice example in the Lux documentation, I updated this example so that it now works with nested AD. Here is the complete code, again using the Lotka-Volterra equations from the SciML tutorial.

Here, regularization is imposed in the loss function by computing an extra Jacobian (with respect to the inputs) using ForwardDiff, and the gradient with respect to the parameters of the neural network is then computed on top of this using Zygote.

using OrdinaryDiffEq, SciMLSensitivity
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using LinearAlgebra, Statistics
using ComponentArrays, Lux, Zygote, StableRNGs, ForwardDiff

function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
end

# Define the experimental parameter
tspan = (0.0, 5.0)
u0 = 5.0f0 * rand(StableRNG(89), 2)
p_ = [1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.25)

# Add noise in terms of the mean
X = Array(solution)
t = solution.t

x̄ = mean(X, dims = 2)
noise_magnitude = 5e-3
Xₙ = X .+ (noise_magnitude * x̄) .* randn(StableRNG(1), eltype(X), size(X))

# plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
# scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])

# We define the neural network using Lux
rbf(x) = exp.(-(x .^ 2))
# Multilayer FeedForward
U = Lux.Chain(
        Lux.Dense(2, 5, rbf), 
        Lux.Dense(5, 5, rbf), 
        Lux.Dense(5, 5, rbf),
        Lux.Dense(5, 2)
    )

# Get the initial parameters and state variables of the model
p, st = Lux.setup(StableRNG(65), U)
p = ComponentArray(p)  # needs to be an AbstractArray for most jacobian functions

# const _st = st
# const nn_model = StatefulLuxLayer{true}(U, nothing, st)

# Define the hybrid model
function ude_dynamics!(du, u, p, t, p_true)
    smodel = StatefulLuxLayer{true}(U, p, st)
    û = smodel(u)
    du[1] =  p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
end

# Closure with the known parameter
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
# Define the problem
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)

function predict(θ, X = Xₙ[:, 1], T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6,
                sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
end

# Code based on the Lux example on nested AD

function loss(model, ps, st)

    # Compute predictions using model parameters
    X_pred = predict(ps)
    loss_emp = mean(abs2, Xₙ .- X_pred)

    # Make it a stateful layer
    smodel = StatefulLuxLayer{true}(model, ps, st)
    
    # Regularization: Jacobian of the (batched) network output with respect to the (batched) input
    J = ForwardDiff.jacobian(smodel, Xₙ)
    loss_reg = 0.01f0 * abs2(norm(J))
    return loss_emp + loss_reg
end


ℓ = loss(U, p, st)
println("Loss with regularization involving Jacobian calculation: ", ℓ)

# We see we can compute the gradient 
_, ∂p,  _ = Zygote.gradient(loss, U, p, st)


loss(ps) = loss(U, ps, st)
losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 10 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

adtype = Optimization.AutoZygote()
# adtype = Optimization.AutoReverseDiff()
# adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)

res = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 500)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

and the corresponding Project.toml file (with Lux = 1.0.5):

name = "ExampleNestedAD"
uuid = "f04fbdf8-173e-44da-8764-4928a0487dc3"
authors = ["Facundo Sapienza <fsapienza@berkeley.edu>"]
version = "0.1.0"

[deps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Lux = "1.0"

I will be playing with performance improvements on this in the future (and probably posting them as I go through it), but feel free to suggest any cool feature that can improve performance with nested AD :wink:
