Dear all,
I am trying to use a Dropout layer in the neural-network component of a UDE for parameter estimation. It works if I use a setup based on Flux, but if I switch to Lux I get an error message that is too long to be displayed fully in VS Code (a trimmed version of my working Flux setup is at the end of this post for reference). I based my code on this and this, so with Lux the implementation is as follows:
# SciML Tools
using OrdinaryDiffEq, ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity #, DataDrivenSparse
using Optimization, OptimizationOptimisers, OptimizationOptimJL
# Standard Libraries
using LinearAlgebra, Statistics
# External Libraries
using ComponentArrays, Lux #, Zygote, Plots
# Set a random seed for reproducible behaviour
using Random
rng = Random.default_rng()
Random.seed!(2345)
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[2] * u[1]
    du[2] = γ * u[1] * u[2] - δ * u[2]
end
# Define the experimental parameter
tspan = (0.0, 5.0)
u0 = 5.0f0 * rand(rng, 2)
p_ = [1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0, tspan, p_)
solution = solve(prob, Vern7(), abstol = 1e-12, reltol = 1e-12, saveat = 0.25)
# Add noise in terms of the mean
X = Array(solution)
t = solution.t
x̄ = mean(X, dims = 2)
noise_magnitude = 5e-3
Xₙ = X .+ (noise_magnitude * x̄) .* randn(rng, eltype(X), size(X))
rbf(x) = exp.(-(x .^ 2))
# Multilayer FeedForward
U = Lux.Chain(Lux.Dense(2, 5, rbf),
              Lux.Dense(5, 5, rbf),
              Lux.Dense(5, 5, rbf),
              Lux.Dropout(0.3),
              Lux.Dense(5, 2))
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)
# Define the hybrid model
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, st)[1] # Network prediction
    du[1] = p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
end
# Closure with the known parameter
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p_)
# Define the problem
prob_nn = ODEProblem(nn_dynamics!, Xₙ[:, 1], tspan, p)
function predict(θ, X = Xₙ[:, 1], T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol = 1e-6, reltol = 1e-6))
end
function loss(θ)
    X̂ = predict(θ)
    mean(abs2, Xₙ .- X̂)
end
losses = Float64[]
callback = function (p, l)
    push!(losses, l)
    if length(losses) % 50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))
res1 = Optimization.solve(optprob, ADAM(), callback = callback, maxiters = 10)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
If one deletes the Lux.Dropout(0.3) line, the code runs without any problems. If anyone has an idea how to fix this, I’d appreciate it!
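For reference, the Flux-based setup that works is roughly along these lines (a trimmed sketch, not my exact script; U_flux, p_flux and ude_dynamics_flux! are just placeholder names here, and the rest of the code stays the same except that the optimization runs over p_flux):

using Flux

# Same architecture as above, but built with Flux layers
U_flux = Flux.Chain(Flux.Dense(2 => 5, rbf),
                    Flux.Dense(5 => 5, rbf),
                    Flux.Dense(5 => 5, rbf),
                    Flux.Dropout(0.3),
                    Flux.Dense(5 => 2))

# Flatten the parameters into a vector and get a reconstructor
p_flux, re = Flux.destructure(U_flux)

function ude_dynamics_flux!(du, u, p, t, p_true)
    û = re(p)(u) # Network prediction from the reconstructed model
    du[1] = p_true[1] * u[1] + û[1]
    du[2] = -p_true[4] * u[2] + û[2]
end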