Hello,
I would like to pass additional (known) parameters to a UDE (universal differential equation) without capturing them in a closure.
Here is an example with the Lotka-Volterra system.
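For reference, the data-generating system is

$$
\begin{aligned}
\dot{u}_1 &= \alpha u_1 - \beta u_1 u_2,\\
\dot{u}_2 &= \gamma u_1 u_2 - \delta u_2,
\end{aligned}
$$

and the hybrid model keeps only the linear terms, with a neural network $\mathrm{NN}_\theta$ absorbing the interaction terms:

$$
\begin{aligned}
\dot{u}_1 &= \alpha u_1 + \mathrm{NN}_\theta(u)_1,\\
\dot{u}_2 &= -\delta u_2 + \mathrm{NN}_\theta(u)_2.
\end{aligned}
$$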
# SciML Tools
using OrdinaryDiffEq, ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity, DataDrivenSparse
using Optimization, OptimizationOptimisers, OptimizationOptimJL
# Standard Libraries
using LinearAlgebra, Statistics
# External Libraries
using ComponentArrays, Lux, Zygote, Plots, LaTeXStrings, StableRNGs
rng = StableRNG(1111)
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
end
# Define the experimental parameters
t_true = LinRange(0.0, 5.0, 300)
tspan = (0.0, 5.0)
u0 = 5.0f0 * rand(rng, 2)
p_ = [1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = t_true)
Xₙ = Array(solution)
t = solution.t
rbf(x) = exp.(-(x .^ 2)) # Gaussian radial basis activation
# Multilayer FeedForward
U = Lux.Chain(Lux.Dense(2, 5, rbf),
              Lux.Dense(5, 5, rbf),
              Lux.Dense(5, 5, rbf),
              Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)
# Define the hybrid model
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, st)[1] # Network prediction
    du[1] = p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
end
# Closure over the known parameters
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
# Define the problem
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)
function predict(θ, X = Xₙ[:, 1], T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6))
end
function loss(θ)
    X̂ = predict(θ)
    mean(abs2, Xₙ .- X̂)
end
losses = Float64[]
callback = function (p, l)
    push!(losses, l)
    if length(losses) % 10 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))
@time res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 300)
The output is:
Current loss after 10 iterations: 83557.60178180611
Current loss after 20 iterations: 74198.89503059449
Current loss after 30 iterations: 64595.88341112899
Current loss after 40 iterations: 54859.53190439518
Current loss after 50 iterations: 45346.21378820202
Current loss after 60 iterations: 36610.279113608194
Current loss after 70 iterations: 29419.779445334116
Current loss after 80 iterations: 23800.77339733237
Current loss after 90 iterations: 19367.222694468925
Current loss after 100 iterations: 15861.289470984184
Current loss after 110 iterations: 13100.148799774146
Current loss after 120 iterations: 10911.005882126306
Current loss after 130 iterations: 9166.457697107462
Current loss after 140 iterations: 7763.650392348302
Current loss after 150 iterations: 6622.354885298846
Current loss after 160 iterations: 5681.885681152856
Current loss after 170 iterations: 4897.13955236694
Current loss after 180 iterations: 4234.733638699881
Current loss after 190 iterations: 3670.1218825294704
Current loss after 200 iterations: 3185.1665116268246
Current loss after 210 iterations: 2766.1979936540934
Current loss after 220 iterations: 2402.6724654449445
Current loss after 230 iterations: 2086.2767977928993
Current loss after 240 iterations: 1810.3245428643088
Current loss after 250 iterations: 1569.3399237512847
Current loss after 260 iterations: 1358.7663259512062
Current loss after 270 iterations: 1174.7585727568623
Current loss after 280 iterations: 1014.0327395585466
Current loss after 290 iterations: 873.75608704358
Current loss after 300 iterations: 751.4648729247339
6.475962 seconds (33.13 M allocations: 5.752 GiB, 23.48% gc time, 10.93% compilation time) # Results from 2nd run
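As an aside, the trained parameters end up in `res1.u`, so a refinement stage with LBFGS (already loaded via OptimizationOptimJL) could follow, as in the SciML UDE tutorial. A sketch, not timed here:
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2, LBFGS(), callback = callback, maxiters = 100)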
Here is the same example with ComponentArrays, which avoids the closure.
rng = StableRNG(1111)
function lotka!(du, u, p, t)
    α, β, γ, δ = [1.3, 0.9, 0.8, 1.8] # true parameters hardcoded; p is ignored here
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
end
# Define the experimental parameters
t_true = LinRange(0.0, 5.0, 300)
tspan = (0.0, 5.0)
u0 = 5.0f0 * rand(rng, 2)
p_ = zeros(4) # placeholder; lotka! above ignores its p argument
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = t_true)
Xₙ = Array(solution);
t = solution.t;
rbf(x) = exp.(-(x .^ 2))
# Multilayer FeedForward
U = Lux.Chain(Lux.Dense(2, 5, rbf),
              Lux.Dense(5, 5, rbf),
              Lux.Dense(5, 5, rbf),
              Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
ps, st = Lux.setup(rng, U)
p = ComponentArray{Float64}(p = ps, p_true = p_) # pack network and known parameters into one ComponentArray
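# A ComponentArray is a flat vector with named views, so inside the ODE
# function both pieces are reachable without captured globals:
#   p.p      -> the network parameters (nested)
#   p.p_true -> the 4 known physical parameters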
# Define the hybrid model
function nn_dynamics!(du, u, p, t)
    û = U(u, p.p, st)[1] # Network prediction
    du[1] = p.p_true[1] * u[1] + û[1]
    du[2] = -p.p_true[4] * u[2] + û[2]
end
# Define the problem
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)
function predict(θ, X = Xₙ[:, 1], T = t)
    # p_true can be changed in `remake`
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]),
                   p = ComponentArray(p = θ, p_true = [1.3, 0.9, 0.8, 1.8]))
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6))
end
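# Hypothetical variant (not used below): since the parameter object is rebuilt
# on every call, the known parameters can just as well be passed in explicitly.
function predict_known(θ, p_known, X = Xₙ[:, 1], T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]),
                   p = ComponentArray(p = θ, p_true = p_known))
    Array(solve(_prob, Vern7(), saveat = T, abstol = 1e-6, reltol = 1e-6))
end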
function loss(θ)
    X̂ = predict(θ)
    mean(abs2, Xₙ .- X̂)
end
losses = Float64[]
callback = function (p, l)
    push!(losses, l)
    if length(losses) % 10 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p.p)
@time res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 300)
The results are pretty similar:
Current loss after 10 iterations: 83557.60178179643
Current loss after 20 iterations: 74198.89503059752
Current loss after 30 iterations: 64595.88341116582
Current loss after 40 iterations: 54859.531904407406
Current loss after 50 iterations: 45346.213788260924
Current loss after 60 iterations: 36610.27911370193
Current loss after 70 iterations: 29419.779445399872
Current loss after 80 iterations: 23800.773397475943
Current loss after 90 iterations: 19367.22269463254
Current loss after 100 iterations: 15861.289471134369
Current loss after 110 iterations: 13100.148799867218
Current loss after 120 iterations: 10911.00588226756
Current loss after 130 iterations: 9166.457697222619
Current loss after 140 iterations: 7763.650392506924
Current loss after 150 iterations: 6622.354885502134
Current loss after 160 iterations: 5681.885681344036
Current loss after 170 iterations: 4897.139552551027
Current loss after 180 iterations: 4234.733638870311
Current loss after 190 iterations: 3670.121882687468
Current loss after 200 iterations: 3185.1665117731727
Current loss after 210 iterations: 2766.197993783185
Current loss after 220 iterations: 2402.6724655506787
Current loss after 230 iterations: 2086.2767978719853
Current loss after 240 iterations: 1810.3245429143208
Current loss after 250 iterations: 1569.3399237740252
Current loss after 260 iterations: 1358.7663259609906
Current loss after 270 iterations: 1174.7585727624441
Current loss after 280 iterations: 1014.0327395591854
Current loss after 290 iterations: 873.7560870350225
Current loss after 300 iterations: 751.4648729027014
6.268178 seconds (31.68 M allocations: 5.991 GiB, 23.36% gc time) # Results from 2nd run
Maybe I am being too picky, but why are there numerical differences at all, given that both versions work in Float64 and are conceptually the same computation?
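I know Float64 arithmetic is not associative, so regrouped operations can differ at roundoff level and an adaptive solver can amplify that over the integration, for example:
(0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3) # false in Float64
But I would have expected the two versions to execute the same operations in the same order.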