Avoiding closure in uODE with ComponentArrays causes minor numerical differences


I would like to pass additional parameters to a uODE without wrapping the dynamics in a closure.

Here is an example with the Lotka-Volterra system.

# SciML Tools
using OrdinaryDiffEq, ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity, DataDrivenSparse
using Optimization, OptimizationOptimisers, OptimizationOptimJL

# Standard Libraries
using LinearAlgebra, Statistics

# External Libraries
using ComponentArrays, Lux, Zygote, Plots, LaTeXStrings, StableRNGs

rng = StableRNG(1111)

function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
end

# Define the experimental parameters
t_true = LinRange(0.0, 5.0, 300)
tspan = (0.0, 5.0)
u0 = 5.0f0 * rand(rng, 2)
p_ = [1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = t_true)

Xₙ = Array(solution)
t = solution.t

rbf(x) = exp.(-(x .^ 2))

# Multilayer FeedForward
U = Lux.Chain(  Lux.Dense(2, 5, rbf), 
                Lux.Dense(5, 5, rbf), 
                Lux.Dense(5, 5, rbf),
                Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)

# Define the hybrid model
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, st)[1] # Network prediction
    du[1] = p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
end

# Closure with the known parameter
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
# Define the problem
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)

function predict(θ, X = Xₙ[:, 1], T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6))
end

function loss(θ)
    X̂ = predict(θ)
    mean(abs2, Xₙ .- X̂)
end

losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 10 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

@time res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 300)

The output is:

Current loss after 10 iterations: 83557.60178180611
Current loss after 20 iterations: 74198.89503059449
Current loss after 30 iterations: 64595.88341112899
Current loss after 40 iterations: 54859.53190439518
Current loss after 50 iterations: 45346.21378820202
Current loss after 60 iterations: 36610.279113608194
Current loss after 70 iterations: 29419.779445334116
Current loss after 80 iterations: 23800.77339733237
Current loss after 90 iterations: 19367.222694468925
Current loss after 100 iterations: 15861.289470984184
Current loss after 110 iterations: 13100.148799774146
Current loss after 120 iterations: 10911.005882126306
Current loss after 130 iterations: 9166.457697107462
Current loss after 140 iterations: 7763.650392348302
Current loss after 150 iterations: 6622.354885298846
Current loss after 160 iterations: 5681.885681152856
Current loss after 170 iterations: 4897.13955236694
Current loss after 180 iterations: 4234.733638699881
Current loss after 190 iterations: 3670.1218825294704
Current loss after 200 iterations: 3185.1665116268246
Current loss after 210 iterations: 2766.1979936540934
Current loss after 220 iterations: 2402.6724654449445
Current loss after 230 iterations: 2086.2767977928993
Current loss after 240 iterations: 1810.3245428643088
Current loss after 250 iterations: 1569.3399237512847
Current loss after 260 iterations: 1358.7663259512062
Current loss after 270 iterations: 1174.7585727568623
Current loss after 280 iterations: 1014.0327395585466
Current loss after 290 iterations: 873.75608704358
Current loss after 300 iterations: 751.4648729247339
 6.475962 seconds (33.13 M allocations: 5.752 GiB, 23.48% gc time, 10.93% compilation time) # Results from 2nd run

Here is the same example using ComponentArrays to avoid the closure:

rng = StableRNG(1111)

function lotka!(du, u, p, t)
    α, β, γ, δ = [1.3, 0.9, 0.8, 1.8]
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
end

# Define the experimental parameters
t_true = LinRange(0.0, 5.0, 300)
tspan = (0.0, 5.0)
u0 = 5.0f0 * rand(rng, 2)
p_ = zeros(4)
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = t_true)

Xₙ = Array(solution);
t = solution.t;

rbf(x) = exp.(-(x .^ 2))

# Multilayer FeedForward
U = Lux.Chain(  Lux.Dense(2, 5, rbf), 
                Lux.Dense(5, 5, rbf), 
                Lux.Dense(5, 5, rbf),
                Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
ps, st = Lux.setup(rng, U)

p = ComponentArray{Float64}(p=ps, p_true=p_) # Put the network and known parameters in one ComponentArray

# Define the hybrid model
function nn_dynamics!(du, u, p, t)
    û = U(u, p.p, st)[1] # Network prediction
    du[1] = p.p_true[1] * u[1] + û[1]
    du[2] = -p.p_true[4] * u[2] + û[2]
end

# Define the problem
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)

function predict(θ, X = Xₙ[:, 1], T = t)
    # p_true can be changed in `remake`
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]),
                   p = ComponentArray(p = θ, p_true = [1.3, 0.9, 0.8, 1.8]))
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6))
end

function loss(θ)
    X̂ = predict(θ)
    mean(abs2, Xₙ .- X̂)
end

losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 10 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p.p)

@time res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 300)

The results are pretty similar:

Current loss after 10 iterations: 83557.60178179643
Current loss after 20 iterations: 74198.89503059752
Current loss after 30 iterations: 64595.88341116582
Current loss after 40 iterations: 54859.531904407406
Current loss after 50 iterations: 45346.213788260924
Current loss after 60 iterations: 36610.27911370193
Current loss after 70 iterations: 29419.779445399872
Current loss after 80 iterations: 23800.773397475943
Current loss after 90 iterations: 19367.22269463254
Current loss after 100 iterations: 15861.289471134369
Current loss after 110 iterations: 13100.148799867218
Current loss after 120 iterations: 10911.00588226756
Current loss after 130 iterations: 9166.457697222619
Current loss after 140 iterations: 7763.650392506924
Current loss after 150 iterations: 6622.354885502134
Current loss after 160 iterations: 5681.885681344036
Current loss after 170 iterations: 4897.139552551027
Current loss after 180 iterations: 4234.733638870311
Current loss after 190 iterations: 3670.121882687468
Current loss after 200 iterations: 3185.1665117731727
Current loss after 210 iterations: 2766.197993783185
Current loss after 220 iterations: 2402.6724655506787
Current loss after 230 iterations: 2086.2767978719853
Current loss after 240 iterations: 1810.3245429143208
Current loss after 250 iterations: 1569.3399237740252
Current loss after 260 iterations: 1358.7663259609906
Current loss after 270 iterations: 1174.7585727624441
Current loss after 280 iterations: 1014.0327395591854
Current loss after 290 iterations: 873.7560870350225
Current loss after 300 iterations: 751.4648729027014

6.268178 seconds (31.68 M allocations: 5.991 GiB, 23.36% gc time) # Results from 2nd run

Maybe I am being too picky, but why are there numerical differences if the results are Float64 and conceptually the same?

For an optimization problem, and from a numerical point of view, I would say those results are identical. The differences could be explained by a simple reordering of operations caused by the different Julia code (or maybe one path uses SIMD and one doesn’t? I’m just guessing here…). The code is doing gazillions of floating-point operations and you still match to six or seven digits… I’m pretty sure the ADAM optimizers use random numbers to generate guesses, so I would be surprised if even runs with the exact same Julia code matched to 16 digits…
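
As a small illustration of the reordering argument (a generic sketch, not taken from the code in this thread), floating-point addition is not associative, so grouping the same three Float64 values differently can change the last bit of the result:

```julia
# Floating-point addition is not associative: regrouping changes the rounding.
a, b, c = 0.1, 0.2, 0.3

left  = (a + b) + c   # add a and b first
right = a + (b + c)   # add b and c first

println(left == right)        # the two groupings do not compare equal
println(abs(left - right))    # they differ by roughly one unit in the last place
```

Over the many operations inside an ODE solve and its gradient, tiny discrepancies like this can accumulate, which is one plausible (but unconfirmed) source of the differences in the trailing digits of the two logs.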

Since you’re getting 7 digits, that could perhaps indicate (another guess) that the ADAM optimizer is using Float32 for speed somewhere?
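
A rough way to check this guess is to compare the relative difference between the two reported final losses with the machine epsilon of each type. The sketch below only uses the iteration-300 values copied from the two logs above; it says nothing about where the difference originates.

```julia
# Losses at iteration 300, copied from the two logs above.
loss_closure   = 751.4648729247339   # closure version
loss_component = 751.4648729027014   # ComponentArray version

# Relative difference between the two runs.
rel_diff = abs(loss_closure - loss_component) / abs(loss_closure)

println("relative difference: ", rel_diff)
println("eps(Float32):        ", eps(Float32))
println("eps(Float64):        ", eps(Float64))
```

The relative difference comes out on the order of 1e-11, which is well below Float32 roundoff but several orders of magnitude above Float64 roundoff, so the two runs agree to roughly ten significant digits rather than seven.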

I believe p would be Float32 from here,

so that optimization problem would be in Float32.

When I do:

julia> typeof(p.p)
ComponentVector{Float64, SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(layer_1 = ViewAxis(1:15, Axis(weight = ViewAxis(1:10, ShapedAxis((5, 2), NamedTuple())), bias = ViewAxis(11:15, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(16:45, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5), NamedTuple())), bias = ViewAxis(26:30, ShapedAxis((5, 1), NamedTuple())))), layer_3 = ViewAxis(46:75, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5), NamedTuple())), bias = ViewAxis(26:30, ShapedAxis((5, 1), NamedTuple())))), layer_4 = ViewAxis(76:87, Axis(weight = ViewAxis(1:10, ShapedAxis((2, 5), NamedTuple())), bias = ViewAxis(11:12, ShapedAxis((2, 1), NamedTuple())))))}}} (alias for ComponentArray{Float64, 1, SubArray{Float64, 1, Array{Float64, 1}, Tuple{UnitRange{Int64}}, true}, Tuple{Axis{(layer_1 = ViewAxis(1:15, Axis(weight = ViewAxis(1:10, ShapedAxis((5, 2), NamedTuple())), bias = ViewAxis(11:15, ShapedAxis((5, 1), NamedTuple())))), layer_2 = ViewAxis(16:45, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5), NamedTuple())), bias = ViewAxis(26:30, ShapedAxis((5, 1), NamedTuple())))), layer_3 = ViewAxis(46:75, Axis(weight = ViewAxis(1:25, ShapedAxis((5, 5), NamedTuple())), bias = ViewAxis(26:30, ShapedAxis((5, 1), NamedTuple())))), layer_4 = ViewAxis(76:87, Axis(weight = ViewAxis(1:10, ShapedAxis((2, 5), NamedTuple())), bias = ViewAxis(11:12, ShapedAxis((2, 1), NamedTuple())))))}}})

Everything appears to be Float64.

If I run optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(ps)) (as in the original version) instead of optprob = Optimization.OptimizationProblem(optf, p.p), the results are exactly the same.
Additionally, are ComponentArrays a good way to pass parameters that change in every predict call for a uODE? The function signature required by ODEProblem is strict, so I am passing everything through p as a ComponentArray, but this comes with allocation overhead.


ComponentArrays shouldn’t need to allocate any more. They are just a vector.
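
To make the "just a vector" point concrete, here is a minimal sketch, assuming ComponentArrays.jl is installed. The component names nn and known are hypothetical stand-ins for the network weights and the known parameters:

```julia
using ComponentArrays

# Hypothetical layout: three "network" weights plus the four known parameters.
p = ComponentArray(nn = zeros(3), known = [1.3, 0.9, 0.8, 1.8])

flat = getdata(p)        # the underlying flat storage
println(typeof(flat))    # a plain vector type
println(length(p))       # total number of entries across both components

# Named access returns a view into that same storage, not a copy.
p.known[1] = 2.0
println(flat[4])         # the change is visible in the flat storage
```

Note that building a fresh ComponentArray on every call, as predict does above, does create a new vector each time, which may account for some of the allocations observed.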