Parameters of the neural network not updating after training in a Neural ODE problem

Hello,
I have been struggling with this problem for some time now. I have constructed an example neural ODE problem related to my work, based on the following tutorial: https://docs.sciml.ai/Overview/stable/showcase/missing_physics/

During training, the loss (which I am monitoring) decreases. After training, however, the parameters do not appear to be saved properly. Let me explain the issue. The following is my code, which can be run on any system without needing external data sets.

using OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationOptimisers, OptimizationOptimJL, LineSearches
using Statistics
using StableRNGs, Lux, Zygote, Plots, ComponentArrays

rng = StableRNG(11)

# Generating training data
function actualODE!(du,u,p,t,T∞,I)
    
    Cbat  =  5*3600 
    du[1] = -I/Cbat

    C₁ = -0.00153 # Unit is s⁻¹
    C₂ = 0.020306 # Unit is K/J

    R0 = 0.03 # Resistance set at 30 mΩ

    Qgen =(I^2)*R0

    du[2] = (C₁*(u[2]-T∞)) + (C₂*Qgen)

end

t1 = collect(0:1:3400)
T∞1,I1 = 298.15,5

t2 = collect(0:1:1800)
T∞2,I2 = 298.15,10

actualODE1!(du,u,p,t) = actualODE!(du,u,p,t,T∞1,I1)
actualODE2!(du,u,p,t) = actualODE!(du,u,p,t,T∞2,I2)


prob_act1 = ODEProblem(actualODE1!,[1.0,T∞1],(t1[1],t1[end]))
solution1 = solve(prob_act1,Tsit5(),saveat = t1)
X1 = Array(solution1)
T1 = X1[2,:]

prob_act2 = ODEProblem(actualODE2!,[1.0,T∞2],(t2[1],t2[end]))
solution2 = solve(prob_act2,Tsit5(),saveat = t2)
X2 = Array(solution2)
T2 = X2[2,:]

# Plotting the results
plot(solution1[2,:],color = :black,label = ["True Data 1C 25" nothing])
plot!(solution2[2,:],color = :red,label = ["True Data 2C 25" nothing])

# Defining the neural network
const U = Lux.Chain(Lux.Dense(3,20,tanh),Lux.Dense(20,20,tanh),Lux.Dense(20,1))
_para,st = Lux.setup(rng,U)
const _st = st

function NODE_model!(du,u,p,t,T∞,I)

    Cbat = 5*3600
    du[1] = -I/Cbat

    C₁ = -0.00153
    C₂ = 0.020306

    G = I*(U([u[1],u[2],I],p,_st)[1][1])

    du[2] = (C₁*(u[2]-T∞)) + (C₂*G)

end

NODE_model1!(du,u,p,t) = NODE_model!(du,u,p,t,T∞1,I1)
NODE_model2!(du,u,p,t) = NODE_model!(du,u,p,t,T∞2,I2)

para_init = _para
prob1 = ODEProblem(NODE_model1!,[1.0,T∞1],(t1[1],t1[end]),_para)
prob2 = ODEProblem(NODE_model2!,[1.0,T∞2],(t2[1],t2[end]),_para)

losses = Float64[]
function loss(θ)

    N_dataset =2
    α = length(losses)+1
    #α =1 # Change the α here after training to check the losses in each case

    if α%N_dataset ==0
      _prob = remake(prob1,p=θ)
      _sol = Array(solve(_prob,Tsit5(),saveat = 1,abstol = 1e-6, reltol = 1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
      loss1 = mean(abs2,T1.-_sol[2,:])
      return loss1

    elseif α%N_dataset == 1
      _prob = remake(prob2,p=θ)
      _sol = Array(solve(_prob,Tsit5(),saveat = 1,abstol = 1e-6, reltol = 1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
      loss2 = mean(abs2,T2.-_sol[2,:])
      return loss2

    end     
end


callback = function(state,l)
    push!(losses,l)
    println("MSE Loss at iteration $(length(losses)) is $(l))")
    
    return false

end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> loss(x),adtype)
optprob = Optimization.OptimizationProblem(optf,ComponentVector{Float64}(_para))

res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(),callback = callback,maxiters = 150 )

p_trained = res1.u
loss_final = loss(p_trained)

The training completes and the loss is monitored. This is how the loss looks over the final iterations:

MSE Loss at iteration 142 is 0.07695669049562295)
MSE Loss at iteration 143 is 0.001179701467909206)
MSE Loss at iteration 144 is 0.07591100672024427)
MSE Loss at iteration 145 is 0.001091751465609696)
MSE Loss at iteration 146 is 0.07488918202752927)
MSE Loss at iteration 147 is 0.0010448951067463339)
MSE Loss at iteration 148 is 0.07386806185144365)
MSE Loss at iteration 149 is 0.0010384274304413134)
MSE Loss at iteration 150 is 0.07283458750288192)
MSE Loss at iteration 151 is 0.0007092846014555851

As you can see, the MSE loss of both datasets reaches well below 0.1. After training I checked the MSE loss again. One data set retains the final MSE loss shown above, i.e. 0.0007092846014555725, while the MSE of the other data set goes up to 0.10385923586943051, even though during training both losses reached well below 0.1.

This is an example case of the issue. When I run my code with 6 actual datasets, even though the loss of all 6 datasets goes below 0.5 during training, after training only the last MSE loss is retained and all the other losses go up.

One other odd thing I have noticed is this. The weights of the neural network are initialized as follows using a random number generator:

u0: ComponentVector{Float64}(layer_1 = (weight = [0.5868697166442871 -0.15597978234291077 -0.9511871933937073; -0.9119327664375305 -0.6244392991065979 0.033625759184360504; … ; -0.43237248063087463 1.2632182836532593 -0.4303165078163147; -0.6320686936378479 0.43183207511901855 -1.2375504970550537], bias = [0.048540983349084854, 0.06472337990999222, 0.1911112517118454, 0.1919097602367401, -0.4226461946964264, 0.19537760317325592, 0.4472426474094391, -0.26705870032310486, -0.04794495552778244, 0.561735987663269, 0.5390963554382324, -0.5670568346977234, -0.12485334277153015, -0.351481169462204, -0.42042452096939087, 0.48453590273857117, 0.2708824872970581, 0.3420274257659912, 0.03861868381500244, 0.29712456464767456]), layer_2 = (weight = [0.15543659031391144 -0.298647940158844 … -0.4323674440383911 -0.3045310080051422; -0.14312994480133057 0.10419807583093643 … 0.6008198261260986 -0.06152293458580971; … ; -0.10994387418031693 -0.2899731397628784 … -0.2709127962589264 -0.08749561756849289; 0.23794974386692047 -0.5665941834449768 … -0.615195631980896 -0.6016205549240112], bias = [-0.011972705833613873, -0.18124336004257202, 0.14369475841522217, -0.1350693553686142, 0.17351830005645752, 0.019232628867030144, -0.11016832292079926, -0.032869044691324234, 0.06951241940259933, 0.039028871804475784, -0.036698393523693085, 0.09698633849620819, 0.10837501287460327, -0.14214731752872467, -0.018525179475545883, -0.14776986837387085, 0.12666049599647522, -0.047814905643463135, -0.12171175330877304, 0.20279602706432343]), layer_3 = (weight = [-0.05284392088651657 0.0632457360625267 … 0.00013056751049589366 0.006312578916549683], bias = [-0.09865286946296692]))

After training the weights are the following

ComponentVector{Float64}(layer_1 = (weight = [0.5868697166442871 -0.15597978234291077 -0.9511871933937073; -0.9119327664375305 -0.6244392991065979 0.033625759184360504; … ; -0.43237248063087463 1.2632182836532593 -0.4303165078163147; -0.6320686936378479 0.43183207511901855 -1.2375504970550537], bias = [0.048540983349084854, 0.06472337990999222, 0.1911112517118454, 0.1919097602367401, -0.4226461946964264, 0.19537760317325592, 0.4472426474094391, -0.26705870032310486, -0.04794495552778244, 0.561735987663269, 0.5390963554382324, -0.5670568346977234, -0.12485334277153015, -0.351481169462204, -0.42042452096939087, 0.4859393046753676, 0.2708824872970581, 0.3420274257659912, 0.03861868381500244, 0.29712456464767456]), layer_2 = (weight = [0.14483104762371674 -0.30925348284903875 … -0.42176190134819636 -0.29392546531494745; -0.15497718719912626 0.09235083343314082 … 0.6126670685238943 -0.049675692188014084; … ; -0.09865370778360218 -0.27868297336616343 … -0.2822029626556414 -0.09878578396520764; 0.23330477827296994 -0.571239149038927 … -0.6105506663869458 -0.596975589330061], bias = [-0.0013671631434192485, -0.16939611764477633, 0.14362018560145273, -0.1469548352774139, 0.1774907000806722, 0.01682939081719727, -0.10395751115977608, -0.03793285086052922, 0.06884556755576844, 0.03676230269826743, -0.03201640648936814, 0.08544798266208918, 0.1023016953617768, -0.133019767249293, -0.028572406439131247, -0.15287907036607892, 0.13084924844996712, -0.04111298450642831, -0.13300191970548778, 0.20744099265827395]), layer_3 = (weight = [-0.0495944483708906 0.05788266685919967 … 0.003488253920225086 0.0120929793162927], bias = [-0.10328794130455402]))

You can clearly see the similarity between the first entries of the two, even though the inner weights may have changed. Whatever random initialization I use, there is a huge similarity between the initial and the final weights.

I get one warning message when I run the optimization. I don’t know whether it is related to this.

┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).
│ 
│ 1. If this was not the desired behavior overload the dispatch on `m`.
│ 
│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt C:\Users\Kalath_A\.julia\packages\LuxCore\8mVob\ext\LuxCoreArrayInterfaceReverseDiffExt.jl:10

Please help me. I have been trying to get the code up and running for quite some time now, but issues keep popping up. I am new to Julia, so I don't understand what the issue is here.

Those two objectives aren't optimized simultaneously, so the minimum is not guaranteed to satisfy both, just one.

It looks like you are trying to fit two problems at once, and that this gives issues.

This is not the way to do it if you want to reuse the code to fit multiple problems of the same type. You should instead wrap the code in a function that takes an ODEProblem and a dataset as input and then returns a solution.

You probably want to do something like this instead.


function solve_my_problem(T_i, prob_i)
     # a lot of the logic goes here: build the loss from T_i and prob_i,
     # construct the OptimizationProblem (optprob), then solve it
     res_i = Optimization.solve(optprob, OptimizationOptimisers.Adam(), callback = callback, maxiters = 150)
     return res_i
end

res_1 = solve_my_problem(T_1, prob_1)
res_2 = solve_my_problem(T_2, prob_2)

Thank you for all your replies. I changed the loss function to the following

function totalloss_NODE(θ)
  total_error = 0
    
    data_points = [
        (prob1, t1, T1),
        (prob2, t2, T2)
    ]

    for (prob,t,T) in data_points
        _prob = remake(prob,p=θ)
        _sol = Array(solve(_prob,Tsit5(),saveat = t,abstol = 1e-6, reltol = 1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        error = mean(abs2,T.-_sol[2,:])
        total_error = total_error + error
    end

    return total_error
end

The total loss is decreasing in this case, but I am facing the weird issue I mentioned earlier.

The initial weights are the following

u0: ComponentVector{Float64}(layer_1 = (weight = [0.5868697166442871 -0.15597978234291077 -0.9511871933937073; -0.9119327664375305 -0.6244392991065979 0.033625759184360504; … ; -0.43237248063087463 1.2632182836532593 -0.4303165078163147; -0.6320686936378479 0.43183207511901855 -1.2375504970550537], bias = [0.048540983349084854, 0.06472337990999222, 0.1911112517118454, 0.1919097602367401, -0.4226461946964264, 0.19537760317325592, 0.4472426474094391, -0.26705870032310486, -0.04794495552778244, 0.561735987663269, 0.5390963554382324, -0.5670568346977234, -0.12485334277153015, -0.351481169462204, -0.42042452096939087, 0.48453590273857117, 0.2708824872970581, 0.3420274257659912, 0.03861868381500244, 0.29712456464767456]), layer_2 = (weight = [0.15543659031391144 -0.298647940158844 … -0.4323674440383911 -0.3045310080051422; -0.14312994480133057 0.10419807583093643 … 0.6008198261260986 -0.06152293458580971; … ; -0.10994387418031693 -0.2899731397628784 … -0.2709127962589264 -0.08749561756849289; 0.23794974386692047 -0.5665941834449768 … -0.615195631980896 -0.6016205549240112], bias = [-0.011972705833613873, -0.18124336004257202, 0.14369475841522217, -0.1350693553686142, 0.17351830005645752, 0.019232628867030144, -0.11016832292079926, -0.032869044691324234, 0.06951241940259933, 0.039028871804475784, -0.036698393523693085, 0.09698633849620819, 0.10837501287460327, -0.14214731752872467, -0.018525179475545883, -0.14776986837387085, 0.12666049599647522, -0.047814905643463135, -0.12171175330877304, 0.20279602706432343]), layer_3 = (weight = [-0.05284392088651657 0.0632457360625267 … 0.00013056751049589366 0.006312578916549683], bias = [-0.09865286946296692]))

The final weights are the following

u: ComponentVector{Float64}(layer_1 = (weight = [0.5868697166442871 -0.15597978234291077 -0.9511871933937073; -0.9119327664375305 -0.6244392991065979 0.033625759184360504; … ; -0.43237248063087463 1.2632182836532593 -0.4303165078163147; -0.6320686936378479 0.43183207511901855 -1.2375504970550537], bias = [0.048540983349084854, 0.06472337990999222, 0.1911112517118454, 0.1919097602367401, -0.4226461946964264, 0.19537760317325592, 0.4472426474094391, -0.26705870032310486, -0.04794495552778244, 0.561735987663269, 0.5390963554382324, -0.5670568346977234, -0.12485334277153015, -0.351481169462204, -0.42042452096939087, 0.4868589683576441, 0.2708824872970581, 0.3420274257659912, 0.03861868381500244, 0.29712456464767456]), layer_2 = (weight = [0.14498473755957877 -0.30909979291317674 … -0.42191559128405837 -0.29407915525080947; -0.1555585123938491 0.09176950823841797 … 0.6132483937186175 -0.04909436699329128; … ; -0.10222056418506252 -0.28224982976762386 … -0.27863610625418095 -0.0952189275637473; 0.2332991122470334 -0.5712448150648632 … -0.6105450003610096 -0.5969699233041248], bias = [-0.0015208530792811617, -0.16881479245005349, 0.1437300697154257, -0.14694053411618568, 0.17766376068233203, 0.017130657759010147, -0.10376101776030819, -0.03768100330106865, 0.0691661963397183, 0.03679317102018985, -0.03216338551281637, 0.08544417886605364, 0.10248560437264352, -0.1331008347908573, -0.028067916817913062, -0.15223780645703466, 0.13111423497127933, -0.04099009165763021, -0.12943506330402743, 0.2074466586842105]), layer_3 = (weight = [-0.049640455071107904 0.05800257958944006 … 0.0031882220278853274 0.011832405729039618], bias = [-0.10310971664566078]))

You can clearly see the similarity between the initial and final weights in layer 1.
I calculated the maximum difference between the layer-1 weights before and after training, and it is 0.0029162193720990714. 150 iterations were done and the final loss was 0.04836904740699785.
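
(A minimal sketch of one way to compute this per-layer comparison; it assumes p_trained = res1.u and that the initial parameters are converted to the same ComponentVector layout. The variable names are just illustrative.)

p0 = ComponentVector{Float64}(_para)                      # initial parameters, same layout as res1.u
Δ1 = abs.(p_trained.layer_1.weight .- p0.layer_1.weight)  # elementwise change in layer-1 weights
println("Max |Δ| in layer-1 weights: ", maximum(Δ1))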

Now I will initialize the weights of the neural network with another random seed.

So I am using rng = Random.default_rng()

So the initial weights are

u0: ComponentVector{Float64}(layer_1 = (weight = [0.47856250405311584 1.581030249595642 1.1188480854034424; -1.4497549533843994 0.037171244621276855 0.4280388355255127; … ; 0.7885910868644714 -0.09051362425088882 1.4285931587219238; -0.3990892469882965 -1.0080617666244507 1.107702374458313], bias = [0.06890665739774704, -0.25372210144996643, -0.5348736643791199, -0.10333938151597977, 0.5160987973213196, -0.41439175605773926, -0.07321493327617645, 0.020800722762942314, 0.3064466416835785, -0.06632357090711594, 0.015398676507174969, -0.4404720664024353, 0.41855013370513916, 0.5327157378196716, -0.4651794731616974, -0.04877987876534462, 0.24821606278419495, -0.2632736265659332, -0.538227379322052, -0.23011212050914764]), layer_2 = (weight = [0.33575379848480225 0.3013095259666443 … -0.3334592580795288 0.0693252831697464; -0.5196638703346252 0.366669237613678 … -0.06584017723798752 -0.17688481509685516; … ; 0.4689016342163086 0.5789622664451599 … 0.1348973661661148 0.5494667291641235; 0.0032734216656535864 0.04035850241780281 … -0.09371466189622879 0.1560594141483307], bias = [-0.18218952417373657, 0.12098746001720428, -0.14193132519721985, 0.18615198135375977, -0.1665787696838379, -0.21096982061862946, 0.124036505818367, -0.044865578413009644, -0.0767386183142662, -0.06662506610155106, -0.06431668251752853, -0.20610272884368896, -0.09578026086091995, -0.016254406422376633, 0.009613835252821445, 0.11823383718729019, 0.0040360125713050365, 0.11974681168794632, -0.13142141699790955, 0.13018284738063812]), layer_3 = (weight = [0.2988027334213257 0.34214767813682556 … 0.3172217905521393 0.31532689929008484], bias = [0.13682834804058075]))

The final weights are

ComponentVector{Float64}(layer_1 = (weight = [0.47856250405311584 1.581030249595642 1.1188480854034424; -1.4398098512135353 -0.006402671121003868 0.4038378678181668; … ; 0.8124260000878679 -0.061901950664751994 1.4565092746810522; -0.3990892469882965 -1.0080617666244507 1.107702374458313], bias = [0.06890665739774704, 
-0.2927415623305814, -0.5348736643791199, -0.10333938151597977, 0.5160987973213196, -0.41439175605773926, -0.07321493327617645, 0.020800722762942314, 0.3064466416835785, -0.06632357090711594, 0.015398676507174969, -0.4404720664024353, 0.41855013370513916, 0.5327157378196716, -0.4651794731616974, -0.04877987876534462, 0.24821606278419495, -0.2632736265659332, -0.5133064337034984, -0.23011212050914764]), layer_2 = (weight = [0.3372140073963102 0.30151649395592817 … -0.3350094291804675 0.06786507425823829; -0.5182569832841237 0.36710231681829597 … -0.06735650237785648 -0.17829170214735662; … ; 0.4879364086009823 0.5643851506838274 … 0.11575349102748059 0.5304319547794498; 0.008244876496636992 0.03964811365338423 … -0.09881003193514118 0.15108795931734734], bias = [-0.18072931526222827, 0.12239434706770586, -0.13616251096800125, 0.18303288451700964, -0.1633040437016082, -0.21572658531222033, 0.09314691233385652, -0.04193322317720209, -0.06087426865133882, -0.0643839743719584, -0.06196417352244271, -0.20343585404562792, -0.0971020653578692, -0.01905842820824457, 0.005175218665213127, 0.11064482273281677, 0.0012922768510402018, 0.1250351355753582, -0.11238664261323586, 0.13515430221162147]), layer_3 = (weight = [0.29472779396846455 0.33762715153862694 … 0.32071088341659904 0.3191521092563489], bias = [0.14066045149770134]))

You can see the similarities of the weights in the first layer of the neural network. The maximum difference between the first-layer weights before and after training is 0.04357391574228072. I did 150 iterations and the final loss value is 0.9274434319070786.

I find that this issue is hindering the algorithm's ability to reach a global minimum. In my code with the actual datasets this problem can be seen clearly. In my earlier code I used the OptimizationFlux package (which is now deprecated) and was able to reach a global minimum and get excellent results in Julia. But now this issue comes up and I am not able to reproduce those good results.

Can someone please help me with this issue? I only get one warning message when I run the code, and I don't know if it's related:

┌ Warning: Lux.apply(m::AbstractLuxLayer, x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).
│ 
│ 1. If this was not the desired behavior overload the dispatch on `m`.
│ 
│ 2. This might have performance implications. Check which layer was causing this problem using `Lux.Experimental.@debug_mode`.
└ @ LuxCoreArrayInterfaceReverseDiffExt C:\Users\Kalath_A\.julia\packages\LuxCore\8mVob\ext\LuxCoreArrayInterfaceReverseDiffExt.jl:10

Any help would be much appreciated

What code are you running? Give me something I can copy-paste that is the code you're actually running now.

Thank you for the reply. The following is the code, which can be copied and pasted to run.

using OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationOptimisers, OptimizationOptimJL, LineSearches
using Statistics
using Lux, Zygote, Plots, ComponentArrays, StableRNGs

rng = StableRNG(11)

# Generating training data
function actualODE!(du,u,p,t,T∞,I)
    
    Cbat  =  5*3600 
    du[1] = -I/Cbat

    C₁ = -0.00153 # Unit is s⁻¹
    C₂ = 0.020306 # Unit is K/J

    R0 = 0.03 # Resistance set at 30 mΩ

    Qgen =(I^2)*R0

    du[2] = (C₁*(u[2]-T∞)) + (C₂*Qgen)

end

t1 = collect(0:1:3400)
T∞1,I1 = 298.15,5

t2 = collect(0:1:1800)
T∞2,I2 = 298.15,10

actualODE1!(du,u,p,t) = actualODE!(du,u,p,t,T∞1,I1)
actualODE2!(du,u,p,t) = actualODE!(du,u,p,t,T∞2,I2)

prob_act1 = ODEProblem(actualODE1!,[1.0,T∞1],(t1[1],t1[end]))
solution1 = solve(prob_act1,Tsit5(),saveat = t1)
X1 = Array(solution1)
T1 = X1[2,:]

prob_act2 = ODEProblem(actualODE2!,[1.0,T∞2],(t2[1],t2[end]))
solution2 = solve(prob_act2,Tsit5(),saveat = t2)
X2 = Array(solution2)
T2 = X2[2,:]

# Plotting the results
plot(solution1[2,:],color = :black,label = ["True Data 1C 25" nothing])
plot!(solution2[2,:],color = :red,label = ["True Data 2C 25" nothing])

# Defining the neural network
const U = Lux.Chain(Lux.Dense(3,20,tanh),Lux.Dense(20,20,tanh),Lux.Dense(20,1))
_para,st = Lux.setup(rng,U)
const _st = st

function NODE_model!(du,u,p,t,T∞,I)

    Cbat = 5*3600
    du[1] = -I/Cbat

    C₁ = -0.00153
    C₂ = 0.020306

    G = I*(U([u[1],u[2],I],p,_st)[1][1])

    du[2] = (C₁*(u[2]-T∞)) + (C₂*G)

end

NODE_model1!(du,u,p,t) = NODE_model!(du,u,p,t,T∞1,I1)
NODE_model2!(du,u,p,t) = NODE_model!(du,u,p,t,T∞2,I2)

para_init = _para
prob1 = ODEProblem(NODE_model1!,[1.0,T∞1],(t1[1],t1[end]),_para)
prob2 = ODEProblem(NODE_model2!,[1.0,T∞2],(t2[1],t2[end]),_para)

losses = Float64[]

function totalloss_NODE(θ)
  total_error = 0
    
    data_points = [
        (prob1, t1, T1),
        (prob2, t2, T2)
    ]

    for (prob,t,T) in data_points
        _prob = remake(prob,p=θ)
        _sol = Array(solve(_prob,Tsit5(),saveat = t,abstol = 1e-6, reltol = 1e-6,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        error = mean(abs2,T.-_sol[2,:])
        total_error = total_error + error
    end

    return total_error
end

callback = function(state,l)
    push!(losses,l)
    println("MSE Loss at iteration $(length(losses)) is $(l))")
    
    return false

end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> totalloss_NODE(x),adtype)
optprob = Optimization.OptimizationProblem(optf,ComponentVector{Float64}(_para))

res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(),callback = callback,maxiters = 150 )

p_trained = res1.u

julia> ComponentVector{Float32}(_para)
ComponentVector{Float32}(layer_1 = (weight = Float32[0.5868697 -0.15597978 -0.9511872; -0.91193277 -0.6244393 0.03362576; … ; -0.43237248 1.2632183 -0.4303165; -0.6320687 0.43183208 -1.2375505], bias = Float32[0.048540983, 0.06472338, 0.19111125, 0.19190976, -0.4226462, 0.1953776, 0.44724265, -0.2670587, -0.047944956, 0.561736, 0.53909636, -0.56705683, -0.12485334, -0.35148117, -0.42042452, 0.4845359, 0.2708825, 0.34202743, 0.038618684, 0.29712456]), layer_2 = (weight = Float32[0.15543659 -0.29864794 … -0.43236744 -0.304531; -0.14312994 0.104198076 … 0.6008198 -0.061522935; … ; -0.109943874 -0.28997314 … -0.2709128 -0.08749562; 0.23794974 -0.5665942 … -0.61519563 -0.60162055], bias = Float32[-0.011972706, -0.18124336, 0.14369476, -0.13506936, 0.1735183, 0.019232629, -0.11016832, -0.032869045, 0.06951242, 0.03902887, -0.036698394, 0.09698634, 0.10837501, -0.14214732, -0.01852518, -0.14776987, 0.1266605, -0.047814906, -0.12171175, 0.20279603]), layer_3 = (weight = Float32[-0.05284392 0.063245736 … 0.00013056751 0.006312579], bias = Float32[-0.09865287]))

julia> p_trained = res1.u
ComponentVector{Float32}(layer_1 = (weight = Float32[0.5868697 -0.15597978 -0.9511872; -0.91193277 -0.6244393 0.03362576; … ; -0.43237248 1.2632183 -0.4303165; -0.6320687 0.43183208 -1.2375505], bias = Float32[0.048540983, 0.06472338, 0.19111125, 0.19190976, -0.4226462, 0.1953776, 0.44724265, -0.2670587, -0.047944956, 0.561736, 0.53909636, -0.56705683, -0.12485334, -0.35148117, -0.42042452, 0.48685873, 0.2708825, 0.34202743, 0.038618684, 0.29712456]), layer_2 = (weight = Float32[0.14498466 -0.3090997 … -0.42191568 -0.29407924; -0.15555853 0.09176947 … 0.61324835 -0.04909436; … ; -0.10222043 -0.28224963 … -0.2786363 -0.09521906; 0.23329921 -0.57124466 … -0.61054516 -0.5969701], bias = Float32[-0.0015208176, -0.16881478, 0.14373018, -0.14694063, 0.17766376, 0.017130658, -0.10376103, -0.037681032, 0.069166176, 0.036793146, -0.03216336, 0.08544414, 0.102485545, -0.1331008, -0.028067952, -0.15223785, 0.13111421, -0.040990077, -0.12943508, 0.20744656]), layer_3 = (weight = Float32[-0.049640417 0.058002524 … 0.003188238 0.011832428], bias = Float32[-0.10310978]))

It’s not that it’s not updating… it’s just the first layer. @avikpal this seems like a Lux or Zygote bug?

Guess it's just vanishing gradients with tanh when the weights are too large initially. At least starting with smaller weights has every layer updating nicely:

julia> optprob = Optimization.OptimizationProblem(optf, Float32(1e-2) .* ComponentVector(_para))
OptimizationProblem. In-place: true
u0: ComponentVector{Float32}(layer_1 = (weight = Float32[0.005868697 -0.0015597978 -0.009511871; -0.009119327 -0.006244393 0.0003362576; …

julia> res1 = Optimization.solve(optprob, OptimizationOptimisers.Adam(),callback = callback,maxiters = 75);

julia> res1.u
ComponentVector{Float32}(layer_1 = (weight = Float32[-0.010741851 -0.01817102 -0.026154043; -0.024623511 -0.021738427 -0.015188978; … 

Ahh yeah that could do it. So something like a gelu or softplus activation function could give better results.
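
For example, a sketch of that change applied to the chain defined earlier (gelu and softplus are both activations re-exported by Lux):

const U = Lux.Chain(Lux.Dense(3,20,gelu), Lux.Dense(20,20,gelu), Lux.Dense(20,1))  # or softplus in place of gelu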

Thank you for all your suggestions and help.

I tried the gelu activation function, but the same issue persists. So I again reduced the weights as mentioned, by multiplying the initial weights by 1e-2, and then the issue went away; the weights were updating nicely. So starting with smaller weights helped in the case of the gelu activation function too.
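
(As an aside, an alternative to scaling the flattened parameter vector is to shrink the weights at construction time. A sketch, assuming the init_weight keyword of Lux.Dense and the glorot_uniform initializer that Lux re-exports; the names are illustrative:)

small_init(rng, dims...) = 1f-2 .* Lux.glorot_uniform(rng, dims...)  # down-scaled weight initializer
U_small = Lux.Chain(Lux.Dense(3,20,gelu; init_weight = small_init),
                    Lux.Dense(20,20,gelu; init_weight = small_init),
                    Lux.Dense(20,1; init_weight = small_init))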

Is there any explanation as to why it occurs with the gelu activation function as well?

I encountered another issue too. I applied the code to the actual data sets, using the gelu activation function and initializing by multiplying the initial parameters by 1e-3, like the following:

optprob = Optimization.OptimizationProblem(optf,Float32(1e-3).*ComponentVector(_para))

I did optimization with ADAM for 700 iterations, and it was successful. Then I tried BFGS. The following is the loss function and the definition of the optimization problem for it:

# Defining loss function for BFGS
function total_loss(θ)

    total_error = 0
    
    data_points = [
        (prob1, t1, T1),
        (prob2, t2, T2),
        (prob3, t3, T3),
        (prob4, t4, T4),
        (prob5, t5, T5),
        (prob6, t6, T6)
    ]

    for (prob,t,T) in data_points
        _prob = remake(prob,p=θ)
        _sol = Array(solve(_prob,Tsit5(),saveat = t,sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
        # Print dimensions of T and _sol[2,:]
        println("Dimensions of T: ", size(T))
        println("Dimensions of _sol[2,:]: ", size(_sol[2,:]))
        
        error = mean(abs2,T.-_sol[2,:])
        error_norm = error/(T[end] - T[1])
        total_error = total_error + error_norm
        
    end

    return total_error
end
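
(Side note: one way to keep this loss from erroring when an integration aborts, which is what happens below, is to check the solver's return code before indexing the solution. A sketch under that assumption, with data_points passed in explicitly; successful_retcode lives in SciMLBase:)

using SciMLBase: successful_retcode

function total_loss_guarded(θ, data_points)   # data_points = [(prob1, t1, T1), ...]
    total_error = 0.0
    for (prob, t, T) in data_points
        _prob = remake(prob, p = θ)
        sol = solve(_prob, Tsit5(), saveat = t,
                    sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true)))
        # An aborted solve (dt forced below floating point epsilon) returns fewer points than T,
        # so bail out with a large loss instead of indexing a truncated array.
        successful_retcode(sol) || return Inf
        total_error += mean(abs2, T .- Array(sol)[2, :]) / (T[end] - T[1])
    end
    return total_error
end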

optimiser_ = BFGS()
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p) -> total_loss(x),adtype)
optprob2 = Optimization.OptimizationProblem(optf,p_adam_extra)
res2 = Optimization.solve(optprob2,optimiser_,callback=callback,maxiters=100)
p_bfgs = res2.u 

The dimensions are printed and the MSE loss comes first. Then comes a long series of warnings and an error message:

Dimensions of T: (3379,)
Dimensions of _sol[2,:]: (3379,)
Dimensions of T: (277,)
Dimensions of _sol[2,:]: (277,)
Dimensions of T: (1750,)
Dimensions of _sol[2,:]: (1750,)
Dimensions of T: (3575,)
Dimensions of _sol[2,:]: (3575,)
Dimensions of T: (253,)
Dimensions of _sol[2,:]: (253,)
Dimensions of T: (1812,)
Dimensions of _sol[2,:]: (1812,)
Dimensions of T: (3379,)
Dimensions of _sol[2,:]: (3379,)
Dimensions of T: (277,)
Dimensions of _sol[2,:]: (277,)
Dimensions of T: (1750,)
Dimensions of _sol[2,:]: (1750,)
Dimensions of T: (3575,)
Dimensions of _sol[2,:]: (3575,)
Dimensions of T: (253,)
Dimensions of _sol[2,:]: (253,)
Dimensions of T: (1812,)
Dimensions of _sol[2,:]: (1812,)
MSE Loss at iteration 704 is 0.6876666
┌ Warning: At t=0.052239357203810996, dt was forced below floating point epsilon 6.938893903907228e-18, and step error estimate = 52.62785553574862. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (3379,)
Dimensions of _sol[2,:]: (3379,)
┌ Warning: At t=2.2068473457389683, dt was forced below floating point epsilon 4.440892098500626e-16, and step error estimate = 52.62783313598562. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (277,)
Dimensions of _sol[2,:]: (277,)
┌ Warning: At t=0.27594268840767294, dt was forced below floating point epsilon 5.551115123125783e-17, and step error estimate = 52.62785514581653. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (1750,)
Dimensions of _sol[2,:]: (1750,)
┌ Warning: At t=0.049734966557859785, dt was forced below floating point epsilon 6.938893903907228e-18, and step error estimate = 52.62785087822632. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (3575,)
Dimensions of _sol[2,:]: (3575,)
┌ Warning: At t=0.05432313460574589, dt was forced below floating point epsilon 6.938893903907228e-18, and step error estimate = 52.62789286314034. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (253,)
Dimensions of _sol[2,:]: (253,)
┌ Warning: At t=0.054100158988645994, dt was forced below floating point epsilon 6.938893903907228e-18, and step error estimate = 52.62786883864311. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (1812,)
Dimensions of _sol[2,:]: (1812,)
┌ Warning: At t=1699.7619999999988, dt was forced below floating point epsilon -1.0240000000000001e-13, and step error estimate = 1.6785514216942121e6. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
┌ Warning: At t=6177.671000000002, dt was forced below floating point epsilon -5.12e-13, and step error estimate = 7.5892721393927215e6. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
┌ Warning: At t=3443.513000000001, dt was forced below floating point epsilon -1.0240000000000001e-13, and step error estimate = 1.7002288769444497e6. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
┌ Warning: At t=1641.9389999999985, dt was forced below floating point epsilon -1.0240000000000001e-13, and step error estimate = 208373.3217245348. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
┌ Warning: At t=6973.073999999999, dt was forced below floating point epsilon -5.12e-13, and step error estimate = 155487.34747782705. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
┌ Warning: At t=3251.811999999998, dt was forced below floating point epsilon -1.0240000000000001e-13, and step error estimate = 1.6710083063827937e6. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
┌ Warning: At t=0.052239357203810996, dt was forced below floating point epsilon 6.938893903907228e-18, and step error estimate = 52.62785553574862. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (3379,)
Dimensions of _sol[2,:]: (1,)
┌ Warning: At t=2.2068473457389683, dt was forced below floating point epsilon 4.440892098500626e-16, and step error estimate = 52.62783313598562. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (277,)
Dimensions of _sol[2,:]: (1,)
┌ Warning: At t=0.27594268840767294, dt was forced below floating point epsilon 5.551115123125783e-17, and step error estimate = 52.62785514581653. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (1750,)
Dimensions of _sol[2,:]: (1,)
┌ Warning: At t=0.049734966557859785, dt was forced below floating point epsilon 6.938893903907228e-18, and step error estimate = 52.62785087822632. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (3575,)
Dimensions of _sol[2,:]: (1,)
┌ Warning: At t=0.05432313460574589, dt was forced below floating point epsilon 6.938893903907228e-18, and step error estimate = 52.62789286314034. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (253,)
Dimensions of _sol[2,:]: (1,)
┌ Warning: At t=0.054100158988645994, dt was forced below floating point epsilon 6.938893903907228e-18, and step error estimate = 52.62786883864311. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (1812,)
Dimensions of _sol[2,:]: (1,)
┌ Warning: At t=0.6050178618171094, dt was forced below floating point epsilon 1.1102230246251565e-16, and step error estimate = 52.62783275643544. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (3379,)
Dimensions of _sol[2,:]: (3379,)
┌ Warning: At t=2.737000497297441, dt was forced below floating point epsilon 4.440892098500626e-16, and step error estimate = 52.62786067101712. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (277,)
Dimensions of _sol[2,:]: (277,)
┌ Warning: At t=0.8401324955473453, dt was forced below floating point epsilon 1.1102230246251565e-16, and step error estimate = 52.62788972449586. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (1750,)
Dimensions of _sol[2,:]: (1750,)
┌ Warning: At t=0.5757749197282299, dt was forced below floating point epsilon 1.1102230246251565e-16, and step error estimate = 52.62789933586094. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (3575,)
Dimensions of _sol[2,:]: (3575,)
┌ Warning: At t=0.6327587256368815, dt was forced below floating point epsilon 1.1102230246251565e-16, and step error estimate = 52.62784787958808. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (253,)
Dimensions of _sol[2,:]: (253,)
┌ Warning: At t=0.5991205280351264, dt was forced below floating point epsilon 1.1102230246251565e-16, and step error estimate = 52.62785054144224. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (1812,)
Dimensions of _sol[2,:]: (1812,)
┌ Warning: At t=1699.7619999999988, dt was forced below floating point epsilon -1.0240000000000001e-13, and step error estimate = 70379.64013275309. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
┌ Warning: At t=6177.671000000002, dt was forced below floating point epsilon -5.12e-13, and step error estimate = 640741.6811875995. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
┌ Warning: At t=3443.513000000001, dt was forced below floating point epsilon -1.0240000000000001e-13, and step error estimate = 74178.86896718017. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
┌ Warning: At t=1641.9389999999985, dt was forced below floating point epsilon -1.0240000000000001e-13, and step error estimate = 131071.10966930584. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
┌ Warning: At t=6973.073999999999, dt was forced below floating point epsilon -5.12e-13, and step error estimate = 97199.2499065072. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be 
represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
┌ Warning: At t=3251.811999999998, dt was forced below floating point epsilon -1.0240000000000001e-13, and step error estimate = 69478.22193216026. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
┌ Warning: At t=0.6050178618171094, dt was forced below floating point epsilon 1.1102230246251565e-16, and step error estimate = 52.62783275643544. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (3379,)
Dimensions of _sol[2,:]: (1,)
┌ Warning: At t=2.737000497297441, dt was forced below floating point epsilon 4.440892098500626e-16, and step error estimate = 52.62786067101712. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (277,)
Dimensions of _sol[2,:]: (1,)
┌ Warning: At t=0.8401324955473453, dt was forced below floating point epsilon 1.1102230246251565e-16, and step error estimate = 52.62788972449586. Aborting. There is either an error in your model specification or the true solution is unstable (or the true solution can not be represented in the precision of Float64).
└ @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\integrator_interface.jl:623
Dimensions of T: (1750,)
Dimensions of _sol[2,:]: (2,)
ERROR: DimensionMismatch: arrays could not be broadcast to a common size: a has axes Base.OneTo(1750) and b has axes Base.OneTo(2)
Stacktrace:
  [1] _bcs1
    @ .\broadcast.jl:528 [inlined]
  [2] _bcs
    @ .\broadcast.jl:522 [inlined]
  [3] broadcast_shape
    @ .\broadcast.jl:516 [inlined]
  [4] combine_axes
    @ .\broadcast.jl:497 [inlined]
  [5] instantiate
    @ .\broadcast.jl:307 [inlined]
  [6] materialize(bc::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{…}, Nothing, typeof(-), Tuple{…}})
    @ Base.Broadcast .\broadcast.jl:872
  [7] total_loss(θ::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
    @ Main e:\PhD Ashima\Neural ODE\Julia\IQgen_Tmixed\Updated_code_full_dataset.jl:238
  [8] (::var"#21#22")(x::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}}, p::SciMLBase.NullParameters)
    @ Main e:\PhD Ashima\Neural ODE\Julia\IQgen_Tmixed\Updated_code_full_dataset.jl:249
  [9] (::OptimizationOptimJL.var"#7#13"{OptimizationCache{…}})(θ::ComponentVector{Float32, Vector{…}, Tuple{…}})
    @ OptimizationOptimJL C:\Users\Kalath_A\.julia\packages\OptimizationOptimJL\e3bUa\src\OptimizationOptimJL.jl:158
 [10] (::OptimizationOptimJL.var"#8#14"{…})(G::ComponentVector{…}, θ::ComponentVector{…})
    @ OptimizationOptimJL C:\Users\Kalath_A\.julia\packages\OptimizationOptimJL\e3bUa\src\OptimizationOptimJL.jl:171
 [11] value_gradient!!(obj::TwiceDifferentiable{…}, x::ComponentVector{…})
    @ NLSolversBase C:\Users\Kalath_A\.julia\packages\NLSolversBase\kavn7\src\interface.jl:82
 [12] value_gradient!(obj::TwiceDifferentiable{…}, x::ComponentVector{…})
    @ NLSolversBase C:\Users\Kalath_A\.julia\packages\NLSolversBase\kavn7\src\interface.jl:69
 [13] value_gradient!(obj::Optim.ManifoldObjective{TwiceDifferentiable{…}}, x::ComponentVector{Float32, Vector{…}, Tuple{…}})
    @ Optim C:\Users\Kalath_A\.julia\packages\Optim\HvjCd\src\Manifolds.jl:50
 [14] (::LineSearches.var"#ϕdϕ#6"{…})(α::Float32)
    @ LineSearches C:\Users\Kalath_A\.julia\packages\LineSearches\jgnxK\src\LineSearches.jl:83
 [15] (::HagerZhang{…})(ϕ::Function, ϕdϕ::LineSearches.var"#ϕdϕ#6"{…}, c::Float32, phi_0::Float32, dphi_0::Float32)
    @ LineSearches C:\Users\Kalath_A\.julia\packages\LineSearches\jgnxK\src\hagerzhang.jl:305
 [16] HagerZhang
    @ C:\Users\Kalath_A\.julia\packages\LineSearches\jgnxK\src\hagerzhang.jl:102 [inlined]
 [17] perform_linesearch!(state::Optim.BFGSState{…}, method::BFGS{…}, d::Optim.ManifoldObjective{…})
    @ Optim C:\Users\Kalath_A\.julia\packages\Optim\HvjCd\src\utilities\perform_linesearch.jl:58
 [18] update_state!(d::TwiceDifferentiable{…}, state::Optim.BFGSState{…}, method::BFGS{…})
    @ Optim C:\Users\Kalath_A\.julia\packages\Optim\HvjCd\src\multivariate\solvers\first_order\bfgs.jl:139
 [19] optimize(d::TwiceDifferentiable{…}, initial_x::ComponentVector{…}, method::BFGS{…}, options::Optim.Options{…}, state::Optim.BFGSState{…})
    @ Optim C:\Users\Kalath_A\.julia\packages\Optim\HvjCd\src\multivariate\optimize\optimize.jl:54
 [20] optimize(d::TwiceDifferentiable{…}, initial_x::ComponentVector{…}, method::BFGS{…}, options::Optim.Options{…})
    @ Optim C:\Users\Kalath_A\.julia\packages\Optim\HvjCd\src\multivariate\optimize\optimize.jl:36
 [21] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimJL C:\Users\Kalath_A\.julia\packages\OptimizationOptimJL\e3bUa\src\OptimizationOptimJL.jl:218
 [22] solve!(cache::OptimizationCache{…})
    @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\solve.jl:187
 [23] solve(::OptimizationProblem{…}, ::BFGS{…}; kwargs::@Kwargs{…})
    @ SciMLBase C:\Users\Kalath_A\.julia\packages\SciMLBase\tWwhl\src\solve.jl:95
 [24] top-level scope
    @ e:\PhD Ashima\Neural ODE\Julia\IQgen_Tmixed\Updated_code_full_dataset.jl:251
Some type information was truncated. Use `show(err)` to see complete types.

I encountered a similar issue when using the relu and gelu activation functions without the multiplication factor. When I added the 1e-3 factor to the initial parameters, it went away. Then, when I tried BFGS after 700 iterations of ADAM, the issue popped up again.

Any idea why this happens? I think the problem is with the ODE solving, but I don't understand where it is going wrong.

Your starting neural network is not giving an ODE that is generally stable. Usually for UDEs you want to start with a neural network that starts out close to 0 or 1.
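
(For a freshly initialized network, one common way to get that is to zero-initialize only the output layer so the network's contribution starts at exactly 0. A sketch, assuming Lux's init_weight keyword and the zeros32 initializer it exports; this does not apply to warm-started parameters:)

U0 = Lux.Chain(Lux.Dense(3,20,gelu), Lux.Dense(20,20,gelu),
               Lux.Dense(20,1; init_weight = Lux.zeros32))  # output-layer weights start at 0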

Okay. But how do I ensure that? I am starting with the parameters that I got after optimization with ADAM. Wouldn't modifying them so that the neural network starts close to 0 or 1 nullify that optimization?

Oh this is what the initial_stepnorm stuff helps with.
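
For example (a sketch; initial_stepnorm is a keyword of Optim's BFGS that scales the initial inverse-Hessian guess so that the very first quasi-Newton step stays small):

# Sketch: keep the first BFGS step small so the warm-started ODE stays in a stable regime
res2 = Optimization.solve(optprob2, BFGS(initial_stepnorm = 0.01),
                          callback = callback, maxiters = 100)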