UDE: Estimating multiple parameters with NN and choosing best optimizers

Hello, I am currently trying to implement a UDE to help me create a formulation for 2 parameters in a lung model. The code is based on the tutorial https://docs.sciml.ai/Overview/dev/showcase/missing_physics/#Visualizing-the-Trained-UDE.

I was successful in implementing the UDE where only 1 parameter was approximated by a NN (1 input, 1 hidden layer, 1 output) and reached a loss in the 1e-5 range:

function lung_dynamics!(du, u, p, t, p_true)
    pressure = pressure_interp(t)  # interpolated measured pressure

    û = U(u, p, _st)[1]  # Network prediction (elastance)

    flow = (pressure - u[1] * û[1]) / p_true[2]  # p_true[2] is the known resistance
    du[1] = flow
end

But once I extended it to 2 parameters (replacing p_true[2] as well), the loss stops decreasing once it reaches ~2766 with the Adam optimizer, which is far too large.

const U = Lux.Chain(
    Lux.Dense(1, 10, sigmoid),
    Lux.Dense(10, 10, sigmoid),  # maybe try ReLU
    Lux.Dense(10, 2)
)
rng = StableRNG(1111)
p, st = Lux.setup(rng, U)
const _st = st

function lung_dynamics!(du, u, p, t)
    pressure = pressure_interp(t)  # interpolated measured pressure
    parameters = U(u, p, _st)[1]  # Network prediction (elastance, resistance)
    flow = (pressure - u[1] * parameters[1]) / parameters[2]
    du[1] = flow
end

# Wrapper matching the ODEProblem signature (no known parameters remain)
nn_dynamics!(du, u, p, t) = lung_dynamics!(du, u, p, t)
# Define the problem
prob_nn = ODEProblem(nn_dynamics!, initial_condition, tspan, p)

function predict(θ, X = initial_condition, T = tspan)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = adjusted_t_data,
                abstol = 1e-6, reltol = 1e-6,
                sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
end

function loss(θ)
    X̂ = predict(θ)
    mean(abs2, measured_volume_data .- X̂[1,:])
end

losses = Float64[]

callback = function (p, l)
    push!(losses, l)
    if length(losses) % 100 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

# Training

res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(1e-4), callback = callback, maxiters = 1000)
optprob2 = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(res1.u))
res2 = Optimization.solve(optprob2, LBFGS(linesearch = BackTracking()), callback = callback, maxiters = 1000)

# Rename the best candidate
p_trained = res2.u

The subsequent LBFGS run also does not reduce the error any further.

Now my questions at this point are:

Is there a problem with the way in which I have defined the 2 parameters that I want to approximate in the UDE?

Or is it more a problem of using the wrong optimizers? I have tried sigmoid, rbf, and relu as activation functions, as well as BFGS as an optimizer.
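For anyone wanting to reproduce this: the rbf activation is not built into Lux; it is the Gaussian bump defined in the SciML missing-physics tutorial and can be dropped into the Lux.Chain in place of sigmoid.

```julia
# Gaussian radial basis activation, as defined in the SciML
# missing-physics tutorial; broadcasts over arrays.
rbf(x) = exp.(-(x .^ 2))
```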

These are the packages being used:

Status `C:\Users\minyo\.julia\environments\v1.10\Project.toml`
  [336ed68f] CSV v0.10.14
  [b0b7db55] ComponentArrays v0.15.13
  [2445eb08] DataDrivenDiffEq v1.4.1
  [5b588203] DataDrivenSparse v0.1.2
  [a93c6f00] DataFrames v1.6.1
⌃ [82cc6244] DataInterpolations v5.0.0
  [1130ab10] DiffEqParamEstim v2.2.0
  [0c46a032] DifferentialEquations v7.13.0
⌃ [31c24e10] Distributions v0.25.108
  [a98d9a8b] Interpolations v0.15.1
  [d3d80556] LineSearches v7.2.0
⌃ [b2108857] Lux v0.5.47
⌃ [961ee093] ModelingToolkit v9.15.0
⌃ [7f7a1694] Optimization v3.24.3
  [36348300] OptimizationOptimJL v0.3.2
  [42dfb2eb] OptimizationOptimisers v0.2.1
  [500b13db] OptimizationPolyalgorithms v0.2.1
⌃ [1dea7af3] OrdinaryDiffEq v6.74.1
  [91a5bcdd] Plots v1.40.4
⌃ [1ed8b502] SciMLSensitivity v7.56.2
  [860ef19b] StableRNGs v1.0.2
  [e88e6eb3] Zygote v0.6.70
  [37e2e46d] LinearAlgebra
  [10745b16] Statistics v1.10.0

I would appreciate any input, thank you for your time 🙂

Did you try all of the other tricks like multiple shooting, PEM, etc? The loss function really matters for removing local minima.
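For concreteness, a multiple-shooting loss can be sketched as below. The `simulate` argument is a hypothetical stand-in for the `Array(solve(remake(prob_nn, ...), ...))` call from the post; restarting each window from the measured data is what removes the long-horizon local minima. (DiffEqFlux.jl also ships a ready-made `multiple_shoot` helper.)

```julia
# Sketch only: a multiple-shooting style loss.
# `simulate(θ, u0, tspan, tsave)` is a hypothetical stand-in for a
# remake-and-solve call returning the trajectory at the times `tsave`.
# Each window is re-initialized from the measured data, so errors
# cannot accumulate over the full horizon.
function shooting_loss(simulate, θ, tdata, ydata; group_size = 20)
    total = 0.0
    n = length(tdata)
    for start in 1:(group_size - 1):(n - 1)
        stop = min(start + group_size - 1, n)
        ŷ = simulate(θ, [ydata[start]], (tdata[start], tdata[stop]), tdata[start:stop])
        total += sum(abs2, ydata[start:stop] .- ŷ)
    end
    return total
end
```

An optional continuity penalty between segment endpoints can be added on top of the per-window misfit.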

Thanks for the suggestions. I am currently implementing the other tricks, but I think it might also be more of a formulation problem.

In my case, I am trying to find dynamical formulations for all of the parameters in the equation, meaning that I no longer have any known parameters. As a test, I modified the Lotka-Volterra example in the tutorial "Automatically Discover Missing Physics by Embedding Machine Learning into Differential Equations" to also estimate all 4 parameters (α, β, γ, δ) instead of only the terms -βxy and δxy.

# Define the modified hybrid model
function ude_dynamics!(du, u, p, t, p_true)
    û = U(u, p, _st)[1] # Network prediction
    du[1] = û[1] * u[1] + û[1] * u[2] * u[2]
    du[2] = -û[3] * u[2] + û[4] * u[1]* u[2]
end
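For reference, these are the fully known Lotka-Volterra dynamics that the four network outputs are standing in for:

```julia
# Ground-truth Lotka-Volterra; û[1], û[2], û[3], û[4] above play the
# roles of α, β, γ, δ respectively.
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α * u[1] - β * u[1] * u[2]
    du[2] = -γ * u[2] + δ * u[1] * u[2]
end
```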

This gave me the following error:

TypeError: in cfunction, expected Union{}, got a value of type Nothing

Stacktrace (abbreviated; the very long ComponentVector type parameters are omitted for readability):

  [1] macro expansion @ FunctionWrappers\src\FunctionWrappers.jl:137 [inlined]
  [2] do_ccall @ FunctionWrappers\src\FunctionWrappers.jl:125 [inlined]
  [3] FunctionWrapper @ FunctionWrappers\src\FunctionWrappers.jl:144 [inlined]
  [4] _call @ FunctionWrappersWrappers\src\FunctionWrappersWrappers.jl:12 [inlined]
  [5] FunctionWrappersWrapper @ FunctionWrappersWrappers\src\FunctionWrappersWrappers.jl:10 [inlined]
  [6] Void @ SciMLBase\src\utils.jl:482 [inlined]
  [7] (::FunctionWrappers.CallWrapper{Nothing})(f::SciMLBase.Void{...}, ...) @ FunctionWrappers\src\FunctionWrappers.jl:65
  ⋮
 [13] ODEFunction @ SciMLBase\src\scimlfunctions.jl:2296 [inlined]
 [14] ode_determine_initdt(u0::Vector{Float64}, t::Float64, tdir::Float64, dtmax::Float64, ...) @ OrdinaryDiffEq\src\initdt.jl:53
 [15] auto_dt_reset! @ OrdinaryDiffEq\src\integrators\integrator_interface.jl:453 [inlined]
 [16] handle_dt!(integrator::OrdinaryDiffEq.ODEIntegrator{Vern7{...}, ...}) @ OrdinaryDiffEq\src\solve.jl:571
 [17] __init(prob::ODEProblem{...}, ...) @ OrdinaryDiffEq\src\solve.jl:533
  ⋮
 [21] solve_call(_prob::ODEProblem{...}, ...) @ DiffEqBase\src\solve.jl:612
  ⋮
 [26] _concrete_solve_adjoint(::ODEProblem{...}, ...) @ SciMLSensitivity\src\concrete_solve.jl:351
  ⋮ (Zygote pullback machinery)
 [46] predict @ jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_W4sZmlsZQ==.jl:3 [inlined]
  ⋮
 [50] loss @ jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_W4sZmlsZQ==.jl:8 [inlined]
  ⋮
 [75] __solve(cache::OptimizationCache{...}) @ OptimizationOptimisers\src\OptimizationOptimisers.jl:66
 [76] solve!(cache::OptimizationCache{...}) @ SciMLBase\src\solve.jl:188
 [77] solve(::OptimizationProblem{...}, ::Optimisers.Adam; kwargs::@Kwargs{callback, maxiters}) @ SciMLBase\src\solve.jl:96

Is the problem that by replacing all parameters with a NN, we create a situation where there are too many possible solutions and it therefore does not converge?

Yes. If you do that, you’d need to regularize the NN so that you prefer to have parameter values in the other parts, otherwise the NN may dominate.

Thanks for the suggestion. I implemented basic L2 regularization in the loss function, and the Lotka-Volterra example where all 4 parameters were unknown was able to successfully converge to the correct solution.
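A minimal sketch of such an L2-regularized loss (the penalty weight `λ_reg` is a hypothetical value that needs tuning per problem; `predict` and `measured_volume_data` are as in the earlier post):

```julia
using Statistics: mean

# Hypothetical penalty weight; needs tuning per problem.
const λ_reg = 1e-3

# Sum of squared network weights (works on the flat ComponentVector θ).
l2_penalty(θ) = sum(abs2, θ)

# Data misfit plus L2 penalty on all network weights; `predict` and
# `measured_volume_data` are assumed defined as in the earlier post.
loss_reg(θ) = mean(abs2, measured_volume_data .- predict(θ)[1, :]) + λ_reg * l2_penalty(θ)
```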

I have tried the same approach for my task, but despite the regularization, it fails to fit the curve.

The result that I get is dependent on the size of my NN. Through trial and error, the configuration of 2 hidden layers with 15 nodes each has delivered the best result (although still not very good). But considering the size and complexity of the dataset, I feel that this is already too big and that methods such as PEM and multiple shooting should not be necessary.

Do you have any suggestions on what the problem could be?

[image: original data]
[image: estimated output]

It could be a local minimum. Try running the Adam part for longer.

I have tried increasing the maxiters of Adam to 50000. It should maybe be noted that around iteration 19100 the loss starts increasing again, with the loss at iteration 50000 being 76 while the reported training loss after 50001 iterations was 8.4. Maybe this shows that a local minimum is indeed the problem?
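One way to guard against the loss climbing again late in training is to keep the best parameters seen so far in the callback, rather than whatever Adam ends on. A sketch (the `(p, l)` signature matches the snippet above; newer Optimization.jl versions instead pass a state object whose parameters are in `state.u`):

```julia
losses = Float64[]
best_loss = Inf
best_p = nothing

# Remember the lowest-loss parameters seen during training so a late
# divergence of Adam cannot throw away a good intermediate solution.
callback = function (p, l)
    push!(losses, l)
    if l < best_loss
        global best_loss = l
        global best_p = copy(p)
    end
    length(losses) % 100 == 0 && println("Current loss after $(length(losses)) iterations: $l")
    return false
end
```

After training, `best_p` (rather than `res.u`) would then be the candidate to pass on to LBFGS.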

The result after 50000 iterations was the same as after 5000 iterations.