Understanding the Lux.jl + ReverseDiff TrackedArray Warning in UDEs

Hello everyone,

I’m working with Universal Differential Equations (UDEs) using Lux.jl and Optimization.jl, and I consistently encounter the following warning when the training process begins. This also happens when running the official SciML tutorial on Missing Physics.

┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).
│ 1. If this was not the desired behavior overload the dispatch on m.
│ 2. This might have performance implications. Check which layer was causing this problem using Lux.Experimental.@debug_mode.
└ @ LuxCoreArrayInterfaceReverseDiffExt C:\Users\Gustavo.julia\packages\LuxCore\Av7WJ\ext\LuxCoreArrayInterfaceReverseDiffExt.jl:9
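
For reference, point 2 of the warning refers to Lux's debug wrapper. Below is a minimal sketch of how I understand it is meant to be used, on a standalone network of the same shape rather than inside the ODE solve. The names model, model_debug, ps_demo and st_demo are mine, tanh stands in for the activation, and the plain Float32 input is an assumption made only to show the call pattern:

```julia
using Lux, Random

rng = Random.default_rng()
model = Lux.Chain(Lux.Dense(2 => 64, tanh), Lux.Dense(64 => 2))
ps_demo, st_demo = Lux.setup(rng, model)

# Wrap the layers so that Lux logs which layer is running and what input it gets.
model_debug = Lux.Experimental.@debug_mode model

# Plain (untracked) input, only to illustrate the call; the wrapped model is
# called with the same parameters and state as the original one.
y, _ = model_debug(rand(rng, Float32, 2), ps_demo, st_demo)
```

I have not yet confirmed whether the log output from this wrapper is readable when the network is called from inside the adjoint solve.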

I understand this is an automatic data type correction happening within the automatic differentiation pass. I’ve also seen it mentioned in other posts, but without a clear conclusion on its impact:

  • Error when a neural ODE is implemented
  • Parameters of the neural network not updating after training in a neural ODE problem

(The Julia Discourse forum prevents me from including the links in this post.)
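
To check that I am describing the correction correctly: my understanding is that the warning distinguishes an ordinary array whose elements are individually tracked scalars from a single array that is tracked as a whole. Here is a small sketch of the two types using ReverseDiff directly. ReverseDiff.track is used only to create tracked values outside a real gradient call, and the variable names are mine:

```julia
using ReverseDiff

# One array tracked as a whole: the type the warning says the input is corrected to.
tracked = ReverseDiff.track([1.0, 2.0])

# An ordinary Vector whose elements are individually tracked scalars:
# the input type named on the left-hand side of the warning.
elementwise = [tracked[1], tracked[2]]

@show typeof(tracked)
@show typeof(elementwise)
@show elementwise isa ReverseDiff.TrackedArray
```

If that picture is right, my question is essentially how costly it is to convert the second form into the first inside the adjoint solve.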

Here is a minimal code example that reproduces the warning:

# MWE to reproduce the Lux.jl + ReverseDiff.jl warning
using OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationOptimisers, Statistics
using Lux, ComponentArrays, Random, Zygote

# 1. Setup and Data Generation
rng = Random.default_rng()
p_true = [1.3, 0.9, 0.8, 1.8] # α, β, δ, γ (same order as the destructuring in lotka_volterra!)
u0 = Float64[0.44249296, 4.6280594]
tspan = (0.0, 3.0)
tsteps = range(tspan[1], tspan[2], length = 100)

function lotka_volterra!(du, u, p, t)
    α, β, δ, γ = p
    du[1] = α * u[1] - β * u[1] * u[2]
    du[2] = -γ * u[2] + δ * u[1] * u[2]
end

prob = ODEProblem(lotka_volterra!, u0, tspan, p_true)
data = Array(solve(prob, Tsit5(); saveat = tsteps))

# 2. UDE Definition
const nn = Lux.Chain(Lux.Dense(2, 64, sigmoid), Lux.Dense(64, 2))
p, st = Lux.setup(rng, nn)

# Known part of the dynamics: prey growth and predator decay
function known_dynamics(u, p_true)
    dx = p_true[1] * u[1]
    dy = -p_true[4] * u[2]
    return [dx, dy]
end

function ude_dynamics!(du, u, p, t)
    known = known_dynamics(u, p_true)
    learned = nn(u, p, st)[1]  # Lux returns (output, state); keep the output
    du .= known .+ learned
end

prob_nn = ODEProblem(ude_dynamics!, u0, tspan, p)

# 3. Loss and Optimization
function predict(θ)
    _prob = remake(prob_nn, p = θ)
    Array(solve(_prob, Vern7(), saveat = tsteps,
                sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
end

loss(θ) = mean(abs2, data .- predict(θ))

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

# The warning appears when this line is executed
res = Optimization.solve(optprob, Adam(0.01), maxiters = 10)

My questions are:

  1. What is the practical performance impact of this warning? Is it negligible for most UDE problems, or can it cause a significant slowdown?
  2. Is there a recommended coding pattern to avoid this conversion? For example, should the predict or loss function be structured differently to prevent the warning from appearing?

Thanks for any insights!