Hello everyone,
I’m working with Universal Differential Equations (UDEs) using Lux.jl and Optimization.jl, and I consistently encounter the following warning when the training process begins. This also happens when running the official SciML tutorial on Missing Physics.
┌ Warning: `Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st)` input was corrected to `Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st)`.
│
│ 1. If this was not the desired behavior overload the dispatch on `m`.
│
│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt C:\Users\Gustavo\.julia\packages\LuxCore\Av7WJ\ext\LuxCoreArrayInterfaceReverseDiffExt.jl:9
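My reading of suggestion 2 in the warning is that one wraps the model in debug mode to see which layer receives the problematic input type. A minimal sketch of how I understand that (the model/ps/st names here are mine; I have not confirmed this is the intended usage):

using Lux, Random
rng = Random.default_rng()
model = Lux.Chain(Lux.Dense(2, 64, sigmoid), Lux.Dense(64, 2))
ps, st = Lux.setup(rng, model)
# Wrap the chain so each layer reports the input/output types it sees:
model_debug = Lux.Experimental.@debug_mode model
# model_debug(x, ps, st) could then be substituted for the plain model during one
# forward/backward pass to find the layer receiving Array{<:TrackedReal} input.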
I understand this is an automatic data type correction happening within the automatic differentiation pass. I’ve also seen it mentioned in other posts, but without a clear conclusion on its impact:
- Error when a neural ODE is implemented
- Parameters of the neural network not updating after training in a neural ODE problem
(The Julia Discourse forum is preventing me from posting the links.)
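For reference, here is an isolated snippet (separate from the full MWE below) that, as far as I can tell, triggers the same conversion: building the network input element-wise inside a ReverseDiff tape yields a Vector of TrackedReal rather than a TrackedArray.

using Lux, Random, ReverseDiff
rng = Random.default_rng()
model = Lux.Chain(Lux.Dense(2, 4, tanh), Lux.Dense(4, 2))
ps, st = Lux.setup(rng, model)
ReverseDiff.gradient(rand(2)) do u
    v = [u[1], u[2]]               # Vector{<:ReverseDiff.TrackedReal}, not a TrackedArray
    sum(first(model(v, ps, st)))   # this call emits the Lux.apply correction warning
end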
Here is a minimal code example that reproduces the warning:
# MWE to reproduce the Lux.jl + ReverseDiff.jl warning
using OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationOptimisers, Statistics
using Lux, ComponentArrays, Random, Zygote
# 1. Setup and Data Generation
rng = Random.default_rng()
p_true = [1.3, 0.9, 0.8, 1.8] # alpha, beta, gamma, delta
u0 = Float64[0.44249296, 4.6280594]
tspan = (0.0, 3.0)
tsteps = range(tspan[1], tspan[2], length = 100)
function lotka_volterra!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[1] * u[2]
    du[2] = γ * u[1] * u[2] - δ * u[2]
end
prob = ODEProblem(lotka_volterra!, u0, tspan, p_true)
data = Array(solve(prob, Tsit5(); saveat = tsteps))
# 2. UDE Definition
const nn = Lux.Chain(Lux.Dense(2, 64, sigmoid), Lux.Dense(64, 2))
p, st = Lux.setup(rng, nn)
# Known (mechanistic) part of the dynamics: prey growth and predator decay
function known_dynamics(u, p_true)
    dx = p_true[1] * u[1]
    dy = -p_true[4] * u[2]
    return [dx, dy]
end
function ude_dynamics!(du, u, p, t)
    known = known_dynamics(u, p_true)
    learned = nn(u, p, st)[1]  # the Lux model returns (output, state); keep the output
    du .= known .+ learned
end
prob_nn = ODEProblem(ude_dynamics!, u0, tspan, p)
# 3. Loss and Optimization
function predict(θ)
    _prob = remake(prob_nn, p = θ)
    Array(solve(_prob, Vern7(), saveat = tsteps,
        sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
end
loss(θ) = mean(abs2, data .- predict(θ))
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))
# The warning appears when this line is executed
res = Optimization.solve(optprob, Adam(0.01), maxiters = 10)
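In case concrete timings are useful, the cost of one loss evaluation and of one gradient can be measured like this (a sketch; assumes BenchmarkTools is installed and reuses loss and p from above):

using BenchmarkTools
θ0 = ComponentVector{Float64}(p)
@btime loss($θ0)                   # forward pass: ODE solve + MSE
@btime Zygote.gradient(loss, $θ0)  # reverse pass, where the warning appears for me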
My questions are:
- What is the practical performance impact of this warning? Is it negligible for most UDE problems, or can it cause a significant slowdown?
- Is there a recommended coding pattern to avoid this conversion? For example, should the `predict` or `loss` function be structured differently to prevent the warning from appearing? (One restructuring I was considering is sketched right after these questions.)
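Regarding the second question, one restructuring I was considering (purely a sketch, reusing the names from the MWE; I have not confirmed it changes anything about the warning) is an out-of-place formulation of the dynamics with the same sensealg:

# Out-of-place variant of the UDE dynamics; known_dynamics, nn, st, p_true, u0,
# tspan, tsteps are as defined in the MWE above. The *_oop names are mine.
function ude_dynamics(u, p, t)
    known = known_dynamics(u, p_true)
    learned = first(nn(u, p, st))
    return known .+ learned
end

prob_nn_oop = ODEProblem{false}(ude_dynamics, u0, tspan, p)

function predict_oop(θ)
    _prob = remake(prob_nn_oop, p = θ)
    Array(solve(_prob, Vern7(), saveat = tsteps,
        sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
end

Would something along these lines be the recommended pattern, or is the warning unrelated to in-place vs. out-of-place dynamics?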
Thanks for any insights!