Understanding the Lux.jl + ReverseDiff TrackedArray Warning in UDEs

Hello everyone,

I’m working with Universal Differential Equations (UDEs) using Lux.jl and Optimization.jl, and I consistently encounter the following warning when the training process begins. This also happens when running the official SciML tutorial on Missing Physics.

┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).

│ 1. If this was not the desired behavior overload the dispatch on m.

│ 2. This might have performance implications. Check which layer was causing this problem using Lux.Experimental.@debug_mode.
└ @ LuxCoreArrayInterfaceReverseDiffExt C:\Users\Gustavo\.julia\packages\LuxCore\Av7WJ\ext\LuxCoreArrayInterfaceReverseDiffExt.jl:9

I understand this is an automatic data type correction happening within the automatic differentiation pass. I’ve also seen it mentioned in other posts, but without a clear conclusion on its impact:

  • Error when a neural ODE is implemented
  • Parameters of the neural network not updating after training in a neural ODE problem

(The Julia Discourse forum is preventing me from posting the links.)

Here is a minimal code example that reproduces the warning:

# MWE to reproduce the Lux.jl + ReverseDiff.jl warning
using OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationOptimisers, Statistics
using Lux, ComponentArrays, Random, Zygote

# 1. Setup and Data Generation
rng = Random.default_rng()
p_true = [1.3, 0.9, 0.8, 1.8] # α, β, δ, γ (matching the destructuring order below)
u0 = Float64[0.44249296, 4.6280594]
tspan = (0.0, 3.0)
tsteps = range(tspan[1], tspan[2], length = 100)

function lotka_volterra!(du, u, p, t)
    α, β, δ, γ = p
    du[1] = α * u[1] - β * u[1] * u[2]
    du[2] = -γ * u[2] + δ * u[1] * u[2]
end

prob = ODEProblem(lotka_volterra!, u0, tspan, p_true)
data = Array(solve(prob, Tsit5(); saveat = tsteps))

# 2. UDE Definition
const nn = Lux.Chain(Lux.Dense(2, 64, sigmoid), Lux.Dense(64, 2))
p, st = Lux.setup(rng, nn)

# Known (mechanistic) part of the dynamics: the linear growth/decay terms; the NN learns the rest
function known_dynamics(u, p_true)
    dx = p_true[1] * u[1]
    dy = -p_true[4] * u[2]
    return [dx, dy]
end

function ude_dynamics!(du, u, p, t)
    known = known_dynamics(u, p_true)
    learned = nn(u, p, st)[1]  # Lux returns (output, state); take the output
    du .= known .+ learned
end

prob_nn = ODEProblem(ude_dynamics!, u0, tspan, p)

# 3. Loss and Optimization
function predict(θ)
    _prob = remake(prob_nn, p = θ)
    Array(solve(_prob, Vern7(), saveat = tsteps,
                sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
end

loss(θ) = mean(abs2, data .- predict(θ))

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

# The warning appears when this line is executed
res = Optimization.solve(optprob, Adam(0.01), maxiters = 10)
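
For what it's worth, the warning also points at Lux.Experimental.@debug_mode for locating the layer that triggers the conversion; based on the Lux docs I believe the usage is roughly the following (untested sketch, reusing nn and rng from the MWE):

# Untested sketch based on the Lux.jl docs: wrap the chain so each leaf layer is a DebugLayer
nn_debug = Lux.Experimental.@debug_mode nn
ps_dbg, st_dbg = Lux.setup(rng, nn_debug)
nn_debug(rand(Float32, 2), ps_dbg, st_dbg)  # forward pass; the debug wrappers report per-layer information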

My questions are:

  1. What is the practical performance impact of this warning? Is it negligible for most UDE problems, or can it cause a significant slowdown?
  2. Is there a recommended coding pattern to avoid this conversion? For example, should the predict or loss function be structured differently to prevent the warning from appearing?

Thanks for any insights!

For in-place ODEs, the Array{TrackedReal} form is required in order to support in-place updates. Lux.jl sees this and performs the array-of-structs to struct-of-arrays transformation, i.e. it converts the input into a TrackedArray, which makes the linear algebra of the neural network much faster than tracking every scalar operation. The other parts of the UDE really are scalar operations, so there is no performance hit in known_dynamics. The total effect is that you end up with pretty optimal code: the struct-of-arrays form in the linear algebra and the array-of-structs form in the other part. So you can safely ignore the warning here.
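
Concretely, for the MWE above the two representations show up like this (an illustrative sketch reusing the MWE's nn, st, p_true, and known_dynamics; the @show line is only for inspection):

# Illustrative sketch only (names taken from the MWE above). During the ReverseDiffVJP
# pull-back, `u` arrives as an array of structs (a Vector of ReverseDiff.TrackedReal).
function ude_dynamics_inspect!(du, u, p, t)
    @show typeof(u)                    # expect Vector{<:ReverseDiff.TrackedReal} inside the VJP
    known = known_dynamics(u, p_true)  # scalar operations: fine directly on TrackedReal elements
    learned = nn(u, p, st)[1]          # Lux converts `u` to a TrackedArray here -- that is the warning --
                                       # so each Dense layer runs as one array operation instead of
                                       # tracking every scalar multiply and add separately
    du .= known .+ learned
end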

Note that in most cases this warning would be worth heeding, but the specific details of how UDEs interact with AD make it actually fine in this case.

This is pretty inherent to the ReverseDiffVJP (and TrackerVJP) formulations, since it's a requirement for the in-place ODE form to work. You could use an out-of-place ODE definition, but you'd still end up scalarizing the dynamics part. The real solution is EnzymeVJP or MooncakeVJP, which do not use a type-based representation, so there is no AoS vs. SoA form to handle.
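
For example, swapping the VJP backend in the MWE's predict would look roughly like this (a sketch, not a drop-in fix: it assumes the in-place ude_dynamics! is Enzyme-compatible, which may require making captured globals such as p_true and st const):

# Sketch: same adjoint, different VJP backend; EnzymeVJP avoids the TrackedReal path entirely.
# Assumes ude_dynamics! only captures `const` globals, since Enzyme can reject non-const globals.
function predict_enzyme(θ)
    _prob = remake(prob_nn, p = θ)
    Array(solve(_prob, Vern7(), saveat = tsteps,
                sensealg = QuadratureAdjoint(autojacvec = EnzymeVJP())))
end

loss_enzyme(θ) = mean(abs2, data .- predict_enzyme(θ))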

Thanks for the clarification, Chris!