Parameters of the neural network not updating after training in a Neural ODE problem

Hello,
I have been struggling with this problem for some time now. I have constructed an example neural ode problem related to my work using the following example. https://docs.sciml.ai/Overview/stable/showcase/missing_physics/

I monitor the loss during training, and it decreases. After training, however, the parameters do not seem to be saved properly. Let me explain the issue. The following is my code, which can be run on any system without needing external data sets.

using OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationOptimisers, OptimizationOptimJL, LineSearches
using Statistics
using StableRNGs, Lux, Zygote, Plots, ComponentArrays

rng = StableRNG(11)

# Generating training data
function actualODE!(du,u,p,t,T∞,I)
    
    Cbat  =  5*3600 
    du[1] = -I/Cbat

    C₁ = -0.00153 # Unit is s-1
    C₂ = 0.020306 # Unit is K/J

    R0 = 0.03 # Resistance set at 30 mΩ

    Qgen =(I^2)*R0

    du[2] = (C₁*(u[2]-T∞)) + (C₂*Qgen)

end

t1 = collect(0:1:3400)
T∞1,I1 = 298.15,5

t2 = collect(0:1:1800)
T∞2,I2 = 298.15,10

actualODE1!(du,u,p,t) = actualODE!(du,u,p,t,T∞1,I1)
actualODE2!(du,u,p,t) = actualODE!(du,u,p,t,T∞2,I2)


prob_act1 = ODEProblem(actualODE1!,[1.0,T∞1],(t1[1],t1[end]))
solution1 = solve(prob_act1,Tsit5(),saveat = t1)
X1 = Array(solution1)
T1 = X1[2,:]

prob_act2 = ODEProblem(actualODE2!,[1.0,T∞2],(t2[1],t2[end]))
solution2 = solve(prob_act2,Tsit5(),saveat = t2)
X2 = Array(solution2)
T2 = X2[2,:]

# Plotting the results
plot(solution1[2,:],color = :black,label = ["True Data 1C 25" nothing])
plot!(solution2[2,:],color = :red,label = ["True Data 2C 25" nothing])

# Defining the neural network
const U = Lux.Chain(Lux.Dense(3,20,tanh),Lux.Dense(20,20,tanh),Lux.Dense(20,1))
_para,st = Lux.setup(rng,U)
const _st = st

function NODE_model!(du,u,p,t,T∞,I)

    Cbat = 5*3600
    du[1] = -I/Cbat

    C₁ = -0.00153
    C₂ = 0.020306

    G = I*(U([u[1],u[2],I],p,_st)[1][1])

    du[2] = (C₁*(u[2]-T∞)) + (C₂*G)

end

NODE_model1!(du,u,p,t) = NODE_model!(du,u,p,t,T∞1,I1)
NODE_model2!(du,u,p,t) = NODE_model!(du,u,p,t,T∞2,I2)

para_init = _para
prob1 = ODEProblem(NODE_model1!,[1.0,T∞1],(t1[1],t1[end]),_para)
prob2 = ODEProblem(NODE_model2!,[1.0,T∞2],(t2[1],t2[end]),_para)

losses = Float64[]
function loss(θ)

    N_dataset =2
    α = length(losses)+1
    #α =1 # Change the α here after training to check the losses in each case

    if α%N_dataset ==0
      _prob = remake(prob1,p=θ)
      _sol = Array(solve(_prob,Tsit5(),saveat = 1,abstol = 1e-6, reltol = 1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
      loss1 = mean(abs2,T1.-_sol[2,:])
      return loss1

    elseif α%N_dataset == 1
      _prob = remake(prob2,p=θ)
      _sol = Array(solve(_prob,Tsit5(),saveat = 1,abstol = 1e-6, reltol = 1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
      loss2 = mean(abs2,T2.-_sol[2,:])
      return loss2

    end     
end


callback = function(state,l)
    push!(losses,l)
    println("MSE Loss at iteration $(length(losses)) is $(l))")
    
    return false

end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss(x),adtype)
optprob = Optimization.OptimizationProblem(optf,ComponentVector{Float64}(_para))

res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(),callback = callback,maxiters = 150 )

p_trained = res1.u
loss_final = loss(p_trained)

The training completes and the loss is monitored. This is what the loss looks like over the last few iterations:

MSE Loss at iteration 142 is 0.07695669049562295
MSE Loss at iteration 143 is 0.001179701467909206
MSE Loss at iteration 144 is 0.07591100672024427
MSE Loss at iteration 145 is 0.001091751465609696
MSE Loss at iteration 146 is 0.07488918202752927
MSE Loss at iteration 147 is 0.0010448951067463339
MSE Loss at iteration 148 is 0.07386806185144365
MSE Loss at iteration 149 is 0.0010384274304413134
MSE Loss at iteration 150 is 0.07283458750288192
MSE Loss at iteration 151 is 0.0007092846014555851

As you can see, the MSE losses of both datasets reach well below 0.1. After training, I checked the MSE loss again. One dataset keeps the final loss shown above, i.e. 0.0007092846014555725, while the MSE of the other dataset goes up to 0.10385923586943051, even though both losses reached well below 0.1 during training.

This is an example case of the issue. When I run my code with 6 actual datasets, the losses of all 6 datasets go below 0.5 during training, but after training only the last MSE loss is retained and all the other losses go up.
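
To check this explicitly, the per-dataset losses can be evaluated separately along the lines of the sketch below (the dataset_losses helper is only for illustration; it mirrors the solver settings in loss(θ), just without the α switching):

# Helper used only for checking the trained parameters; no sensitivity analysis needed here
function dataset_losses(θ)
    sol1 = Array(solve(remake(prob1, p = θ), Tsit5(), saveat = 1, abstol = 1e-6, reltol = 1e-6))
    sol2 = Array(solve(remake(prob2, p = θ), Tsit5(), saveat = 1, abstol = 1e-6, reltol = 1e-6))
    return (mean(abs2, T1 .- sol1[2,:]), mean(abs2, T2 .- sol2[2,:]))
end

dataset_losses(p_trained) # one of the two values stays small, the other goes back up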

One other odd thing I have noticed is this: the weights of the neural network are initialized as follows, using a random number generator:

u0: ComponentVector{Float64}(layer_1 = (weight = [0.5868697166442871 -0.15597978234291077 -0.9511871933937073; -0.9119327664375305 -0.6244392991065979 0.033625759184360504; … ; -0.43237248063087463 1.2632182836532593 -0.4303165078163147; -0.6320686936378479 0.43183207511901855 -1.2375504970550537], bias = [0.048540983349084854, 0.06472337990999222, 0.1911112517118454, 0.1919097602367401, -0.4226461946964264, 0.19537760317325592, 0.4472426474094391, -0.26705870032310486, -0.04794495552778244, 0.561735987663269, 0.5390963554382324, -0.5670568346977234, -0.12485334277153015, -0.351481169462204, -0.42042452096939087, 0.48453590273857117, 0.2708824872970581, 0.3420274257659912, 0.03861868381500244, 0.29712456464767456]), layer_2 = (weight = [0.15543659031391144 -0.298647940158844 … -0.4323674440383911 -0.3045310080051422; -0.14312994480133057 0.10419807583093643 … 0.6008198261260986 -0.06152293458580971; … ; -0.10994387418031693 -0.2899731397628784 … -0.2709127962589264 -0.08749561756849289; 0.23794974386692047 -0.5665941834449768 … -0.615195631980896 -0.6016205549240112], bias = [-0.011972705833613873, -0.18124336004257202, 0.14369475841522217, -0.1350693553686142, 0.17351830005645752, 0.019232628867030144, -0.11016832292079926, -0.032869044691324234, 0.06951241940259933, 0.039028871804475784, -0.036698393523693085, 0.09698633849620819, 0.10837501287460327, -0.14214731752872467, -0.018525179475545883, -0.14776986837387085, 0.12666049599647522, -0.047814905643463135, -0.12171175330877304, 0.20279602706432343]), layer_3 = (weight = [-0.05284392088651657 0.0632457360625267 … 0.00013056751049589366 0.006312578916549683], bias = [-0.09865286946296692]))

After training, the weights are the following:

ComponentVector{Float64}(layer_1 = (weight = [0.5868697166442871 -0.15597978234291077 -0.9511871933937073; -0.9119327664375305 -0.6244392991065979 0.033625759184360504; … ; -0.43237248063087463 1.2632182836532593 -0.4303165078163147; -0.6320686936378479 0.43183207511901855 -1.2375504970550537], bias = [0.048540983349084854, 0.06472337990999222, 0.1911112517118454, 0.1919097602367401, -0.4226461946964264, 0.19537760317325592, 0.4472426474094391, -0.26705870032310486, -0.04794495552778244, 0.561735987663269, 0.5390963554382324, -0.5670568346977234, -0.12485334277153015, -0.351481169462204, -0.42042452096939087, 0.4859393046753676, 0.2708824872970581, 0.3420274257659912, 0.03861868381500244, 0.29712456464767456]), layer_2 = (weight = [0.14483104762371674 -0.30925348284903875 … -0.42176190134819636 -0.29392546531494745; -0.15497718719912626 0.09235083343314082 … 0.6126670685238943 -0.049675692188014084; … ; -0.09865370778360218 -0.27868297336616343 … -0.2822029626556414 -0.09878578396520764; 0.23330477827296994 -0.571239149038927 … -0.6105506663869458 -0.596975589330061], bias = [-0.0013671631434192485, -0.16939611764477633, 0.14362018560145273, -0.1469548352774139, 0.1774907000806722, 0.01682939081719727, -0.10395751115977608, -0.03793285086052922, 0.06884556755576844, 0.03676230269826743, -0.03201640648936814, 0.08544798266208918, 0.1023016953617768, -0.133019767249293, -0.028572406439131247, -0.15287907036607892, 0.13084924844996712, -0.04111298450642831, -0.13300191970548778, 0.20744099265827395]), layer_3 = (weight = [-0.0495944483708906 0.05788266685919967 … 0.003488253920225086 0.0120929793162927], bias = [-0.10328794130455402]))

You can clearly see the similarity between the first entries of the two, even though some of the inner weights have changed. Whatever random initialization I use, the final weights stay very close to the initial weights.
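
To quantify this, the initial and trained parameters can be compared element-wise, for example along these lines (p_init is introduced here only for the comparison; it is the initial parameter set converted to a ComponentVector):

p_init = ComponentVector{Float64}(para_init) # same values the optimization started from
println("max  |Δp| = ", maximum(abs, p_trained .- p_init))
println("mean |Δp| = ", mean(abs, p_trained .- p_init))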

I get one warning message when I run the optimization. I don’t know whether it is related to this.

┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).
│ 
│ 1. If this was not the desired behavior overload the dispatch on `m`.
│ 
│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt C:\Users\Kalath_A\.julia\packages\LuxCore\8mVob\ext\LuxCoreArrayInterfaceReverseDiffExt.jl:10

Please help me. I have been trying to get this code up and running for quite some time now, but issues keep popping up. I am new to Julia, so I do not understand what the issue is here.

Those two losses aren't minimized simultaneously (each call to your loss function only evaluates one dataset), so the minimum is not guaranteed to satisfy both, just one.
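
If the goal is a single parameter set that fits both datasets, one option is to optimize the sum of the two losses in one objective instead of alternating between them. A minimal sketch, reusing the problems, solver settings, and sensealg from the post above (loss_combined and the *_c names are only illustrative):

function loss_combined(θ)
    sol1 = Array(solve(remake(prob1, p = θ), Tsit5(), saveat = 1, abstol = 1e-6, reltol = 1e-6,
                       sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
    sol2 = Array(solve(remake(prob2, p = θ), Tsit5(), saveat = 1, abstol = 1e-6, reltol = 1e-6,
                       sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
    # Sum the per-dataset errors so a single minimum has to fit both trajectories
    return mean(abs2, T1 .- sol1[2,:]) + mean(abs2, T2 .- sol2[2,:])
end

optf_c = Optimization.OptimizationFunction((x, p) -> loss_combined(x), Optimization.AutoZygote())
optprob_c = Optimization.OptimizationProblem(optf_c, ComponentVector{Float64}(_para))
res_c = Optimization.solve(optprob_c, OptimizationOptimisers.Adam(), maxiters = 150)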