UDE: Estimating multiple parameters with NN and choosing best optimizers

Hello, I am currently trying to implement a UDE (universal differential equation) to help me find a formulation for two parameters in a lung model. The code is based on the tutorial https://docs.sciml.ai/Overview/dev/showcase/missing_physics/#Visualizing-the-Trained-UDE.

I was successful in implementing the UDE where only one parameter was approximated with a NN (1 input, 1 hidden layer, 1 output), and reached a loss on the order of 1e-5:

function lung_dynamics!(du, u, p, t, p_true)
    pressure = pressure_interp(t)  # interpolated pressure measurement at time t
    û = U(u, p, _st)[1]            # network prediction (elastance)
    # Single-compartment model: dV/dt = (P - E*V) / R, with R = p_true[2] known
    flow = (pressure - u[1] * û[1]) / p_true[2]
    du[1] = flow
end
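For reference, the single-output network in that first version looked roughly like this (a minimal sketch from a separate script; the hidden width of 10 is an assumption, not necessarily the exact value I used):

# Single-output network: 1 input (volume), 1 hidden layer, 1 output (elastance).
# The hidden width of 10 is an assumed value for illustration.
const U = Lux.Chain(
    Lux.Dense(1, 10, sigmoid),
    Lux.Dense(10, 1)
)
rng = StableRNG(1111)
p, st = Lux.setup(rng, U)
const _st = st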

But once I extended it to two parameters (replacing p_true[2] with a second network output), the loss stopped decreasing after reaching ~2766 with the Adam optimizer, which is far too large.

const U = Lux.Chain(
    Lux.Dense(1, 10, sigmoid),
    Lux.Dense(10, 10, sigmoid),  # maybe try ReLU
    Lux.Dense(10, 2)             # two outputs: elastance and resistance
)
rng = StableRNG(1111)
p, st = Lux.setup(rng, U)
const _st = st

function lung_dynamics!(du, u, p, t)
    pressure = pressure_interp(t)  # interpolated pressure measurement at time t
    parameters = U(u, p, _st)[1]   # network prediction: [elastance, resistance]
    flow = (pressure - u[1] * parameters[1]) / parameters[2]
    du[1] = flow
end
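# One untested variant worth noting: since parameters[2] (the resistance) sits in a
# denominator, the solve can blow up whenever the raw network output crosses zero.
# A common guard is to force both outputs positive with softplus (re-exported by
# Lux from NNlib). This is a hypothetical sketch, not what I actually ran:
function lung_dynamics_pos!(du, u, p, t)
    pressure = pressure_interp(t)
    parameters = softplus.(U(u, p, _st)[1])  # strictly positive elastance/resistance
    flow = (pressure - u[1] * parameters[1]) / parameters[2]
    du[1] = flow
end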

# Wrapper matching the standard ODEProblem signature
nn_dynamics!(du, u, p, t) = lung_dynamics!(du, u, p, t)
# Define the problem
prob_nn = ODEProblem(nn_dynamics!, initial_condition, tspan, p)

function predict(θ, X = initial_condition, T = tspan)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = adjusted_t_data,
                abstol = 1e-6, reltol = 1e-6,
                sensealg=QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))))
end

function loss(θ)
    X̂ = predict(θ)
    mean(abs2, measured_volume_data .- X̂[1,:])
end

losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 100 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

# Training

res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(1e-4), callback = callback, maxiters = 1000)
optprob2 = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(res1.u))
res2 = Optimization.solve(optprob2, LBFGS(linesearch = BackTracking()), callback = callback, maxiters = 1000)

# Rename the best candidate
p_trained = res2.u

LBFGS also fails to reduce the loss any further.

My questions at this point are:

Is there a problem with the way in which I have defined the 2 parameters that I want to approximate in the UDE?

Or is it more a problem of using the wrong optimizers? I have tried sigmoid, rbf, and relu as activation functions, as well as BFGS as the optimizer.
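(For reference, rbf here is the radial-basis activation from the linked tutorial:)

# Radial basis function activation, as defined in the missing-physics tutorial
rbf(x) = exp.(-(x .^ 2))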

These are the packages being used:

Status `C:\Users\minyo\.julia\environments\v1.10\Project.toml`
  [336ed68f] CSV v0.10.14
  [b0b7db55] ComponentArrays v0.15.13
  [2445eb08] DataDrivenDiffEq v1.4.1
  [5b588203] DataDrivenSparse v0.1.2
  [a93c6f00] DataFrames v1.6.1
⌃ [82cc6244] DataInterpolations v5.0.0
  [1130ab10] DiffEqParamEstim v2.2.0
  [0c46a032] DifferentialEquations v7.13.0
⌃ [31c24e10] Distributions v0.25.108
  [a98d9a8b] Interpolations v0.15.1
  [d3d80556] LineSearches v7.2.0
⌃ [b2108857] Lux v0.5.47
⌃ [961ee093] ModelingToolkit v9.15.0
⌃ [7f7a1694] Optimization v3.24.3
  [36348300] OptimizationOptimJL v0.3.2
  [42dfb2eb] OptimizationOptimisers v0.2.1
  [500b13db] OptimizationPolyalgorithms v0.2.1
⌃ [1dea7af3] OrdinaryDiffEq v6.74.1
  [91a5bcdd] Plots v1.40.4
⌃ [1ed8b502] SciMLSensitivity v7.56.2
  [860ef19b] StableRNGs v1.0.2
  [e88e6eb3] Zygote v0.6.70
  [37e2e46d] LinearAlgebra
  [10745b16] Statistics v1.10.0

I would appreciate any input, thank you for your time! 🙂

Did you try all of the other tricks, like multiple shooting, PEM (prediction error method), etc.? The loss function really matters for removing local minima.
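For example, multiple shooting replaces the single long solve with short segments that each restart from the measured data, which tends to remove local minima from the loss landscape. A rough sketch reusing the names from your post (group_size = 25 is an arbitrary choice):

function multiple_shoot_loss(θ; group_size = 25)
    n = length(adjusted_t_data)
    total = zero(eltype(θ))
    for start in 1:(group_size - 1):(n - 1)
        stop = min(start + group_size - 1, n)
        ts = adjusted_t_data[start:stop]
        # Restart each segment from the measured state so errors cannot accumulate
        # over the full trajectory.
        _prob = remake(prob_nn, u0 = [measured_volume_data[start]],
                       tspan = (ts[1], ts[end]), p = θ)
        X̂ = Array(solve(_prob, Vern7(), saveat = ts, abstol = 1e-6, reltol = 1e-6,
                        sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        total += sum(abs2, measured_volume_data[start:stop] .- X̂[1, :])
    end
    return total / n
end

DiffEqFlux.jl also ships a ready-made multiple_shoot helper if you prefer not to hand-roll the segmentation.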