Facing high loss values when training the SciML model

Greetings, i am currently working on a problem to predict the state vectors of a satellite using a dataset which is generated by a propagator. I have made a Universal Differential Equation(UDE) to predict the state vectors ( Positions and velocities). The model has worked fine until i only kept two body forces, relativistic effects, J2 and atmospheric drag in the propagator. But after i added more forces like three body forces, Solar radiation pressure etc . I see the model has started to perform very poorly. The losses are in the range of 10^12 and not decreasing with iterations. Can someone here help to troubleshoot this problem. Is it because there are new parameters which are not accounted for in the model or the current loss function is not appropriate. Can someone suggest a better UDE and loss function if this is the case? Thanks a lot in advance for the immense help and support . I am attaching the code for the refernce below.

# SciML Tool Libraries 

using OrdinaryDiffEq, ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity, DataDrivenSparse
using Optimization, OptimizationOptimisers, OptimizationOptimJL

# Standard Julia Libraries 

using LinearAlgebra, Statistics 

# External Julia Libraries 

using ComponentArrays, Lux, Zygote, Plots, StableRNGs, LineSearches, DataFrames

# Creating a Set of Random seed for a reproducible behaviour 

rng = StableRNG(1111)

# Function defining the density at altitudes and extrapolating the values of desity at height h 

function Atm_density(z)

    # Geometric altitudes in km
    
    h = [ 0 25 30 40 50 60 70 80 90 100 110 120 130 140 150 180 200 250 300 350 400 450 500 600 700 800 900 1000];

    # Corresponding densities (kg/m^3) from the USSA76:

    r1 = [1.225 4.008e-2 1.841e-2 3.996e-3 1.027e-3 3.097e-4 8.283e-5 1.846e-5 3.416e-6 5.606e-7 9.708e-8 2.222e-8 8.152e-9 3.831e-9 2.076e-9 5.194e-10 2.541e-10 6.073e-11 1.916e-11 7.014e-12 2.803e-12 1.184e-12 5.215e-13 1.137e-13 3.070e-14 1.136e-14 5.759e-15 3.561e-15];
    
    # Scale heights in km 

    H =[ 7.310 6.427 6.546 7.360 8.342 7.583 6.661 5.927 5.533 5.703 6.782 9.973 13.243 16.322 21.652 27.974 34.934 43.342 49.755 54.513 58.019 60.980 65.654 76.377 100.587 147.203 208.020];
    
    # Handle the altitudes outseide of the range 

    if z>1000

        z = 1000;
    elseif z <0

        z = 0;
    end

    # Determine the interpolation interval

    for j in 1:27

        if z >= h[j] && z < h[j+1]

            i = j ;
        end
    end
     
    if z == 1000

        i = 27;
    end

    # Exponential interpolation :

    density =  r1[i]*exp(-(z - h[i])/H[i])

    return density
end

# Function to calculate the shadow function for the SLP

function JD_Calculator_Function(t_now)

    # Calculating the Julian Date 

    d = 12;
    m = 5;
    y = 2004;
    hh = 14;
    mm = 45;
    ss = 30;

    J0 = 367*y - trunc(7*(y+trunc((m+9)/12))/4) + trunc(275*m/9)+ d + 1721013.5;

    UT1 = hh + mm/60 + ss/3600;

    JD1 = J0 + UT1/24;

    JD = JD1 + t_now/86400

 return JD


end 

Juldate = JD_Calculator_Function(0)

# Creating a function to calualte the Sunvectors

function Sun_Vectors(Julian_Days)

    AU = 149597870.691

        n = Julian_Days - 2451545

        M = 357.529 + 0.98560023*n

        M = mod(M,360)

        L= 280.459 + 0.98564736*n

        L = mod(L,360)

        lamda = L + 1.915*sin(deg2rad(M)) + 0.0200*sin(deg2rad(2*M))
        lamda = mod(lamda,360)

        epsilon = 23.439 - 3.56*10^-7*n

        u = [cos(deg2rad(lamda)), sin(deg2rad(lamda))*cos(deg2rad(epsilon)),sin(deg2rad(lamda))*sin(deg2rad(epsilon))]
        
        r_sun_mag = (1.00014 - 0.01671*cos(deg2rad(M)) - 0.000140*cos(deg2rad(2*M)))*AU

        r_sun = r_sun_mag*u


    return r_sun,u

end

# Function to calculate the shadow function of the Solar Raditiob pressure 

function Shadow_Function(r,sun_r)

    RE = 6378

    rsat = norm(r)

    rsun = norm(sun_r)

    theta = rad2deg(acos(dot(r,sun_r)/(rsat*rsun)))

    theta_sat = rad2deg(acos(RE/rsat))

    theta_sun = rad2deg(acos(RE/rsun))

    if theta_sat + theta_sun <= theta

        light_switch = 0;

    else 

        light_switch = 1;

    end

    return light_switch


end

# Function to calculate the lunar position relative to earth 

function Lunar_Position_Calculator(Julian_Date)

    RE = 6378

    T = (Julian_Date - 2451545.0)/36525

    #Ecliptic longitude (deg):

    e_long = 218.32 + 481267.881*T + 6.29*sind(135.0 + 477198.87*T)- 1.27*sind(259.3- 413335.36*T)
    + 0.66*sind(235.7 + 890534.22*T) + 0.21*sind(269.9 + 954397.74*T)- 0.19*sind(357.5 +35999.05*T)- 0.11*sind(186.5 + 966404.03*T);

    e_long = mod(e_long,360);

    #Ecliptic latitude (deg):

    e_lat = 5.13*sind( 93.3 + 483202.02*T) + 0.28*sind(228.2 + 960400.89*T) - 0.28*sind(318.3 +6003.15*T)- 0.17*sind(217.6- 407332.21*T);

    e_lat = mod(e_lat,360);

    #Horizontal parallax (deg):

     h_par = 0.9508 + 0.0518*cosd(135.0 + 477198.87*T) + 0.0095*cosd(259.3- 413335.36*T) + 0.0078*cosd(235.7 + 890534.22*T) + 0.0028*cosd(269.9 + 954397.74*T);

     h_par = mod(h_par,360);

    #Angle between earth’s orbit and its equator (deg):

     obliquity = 23.439291- 0.0130042*T;

    # Direction cosines of the moon’s geocentric equatorial position vector:

     l = cosd(e_lat) * cosd(e_long);

     m = cosd(obliquity)*cosd(e_lat)*sind(e_long)- sind(obliquity)*sind(e_lat);

     n = sind(obliquity)*cosd(e_lat)*sind(e_long) + cosd(obliquity)*sind(e_lat);

    #Earth-moon distance (km):

     dist = RE/sind(h_par);

    #Moon’s geocentric equatorial position vector (km):

     r_moon = dist*[l,m,n];

    return r_moon

end

moon_pos = Lunar_Position_Calculator(Juldate)

# ODE model for Satellite Propagation around Earth with J2 and Relativistic Perturbations 


function Propagator_org_1(u, model_params, t)

    #= Defining the variables of the system to store as symbols in array u

       u[1] = x 
       u[2] = y
       u[3] = z
       u[4] = vx 
       u[5] = vy
       u[6] = vz 

       here μ, R , J2 , c are the constants 

    =#

    x, y, z, vx, vy, vz = u
    μ, R, J2, c, CD, A, m, S, CR, As,μmoon  = model_params

    r = sqrt(x^2 + y^2 + z^2)

    Ω = [ 0, 0, 7.2921159e-5]

    rvect = [x, y, z]

    v = [vx, vy, vz]

    vrel = v - cross(Ω,rvect)

    vrel_mag = norm(vrel)

    alt = r - R

    # Get the Julian Day for the corresponding time 

    Jul_D = JD_Calculator_Function(t)

    # Getting Sun vector and unit vector in sun direction

    sun_vectors = Sun_Vectors(Jul_D)[1]

    u_vectors = Sun_Vectors(Jul_D)[2]

    moon_vector = Lunar_Position_Calculator(Jul_D)

    moon_mag = norm(moon_vector)

    moon_rel = moon_vector - rvect

    moon_rel_mag = norm(moon_vector-rvect)

    # Function F 

    

    # Define the equations of motion

       xdot = vx
       ydot = vy
       zdot = vz

      vxdot = -μ * (x / r^3) + ((3/2) * J2 * (μ / r^2) * (R / r)^2 * (x / r)) * (5 * (z / r)^2 - 1) + 
                  (μ / (c^2 * r^3)) * ((4 * μ / r) * x - (vx^2 + vy^2 + vz^2) * x + 4 * (x * vx + y * vy + z * vz) * vx) + (1/2)*Atm_density(alt)*vrel_mag*(CD*A/m)*vrel[1] + Shadow_Function(rvect,sun_vectors)*(S/(m*c))*CR*As*u_vectors[1] +
                  μmoon*(moon_rel[1]/moon_rel_mag^3 - moon_vector[1]/moon_mag^3) 

      vydot = -μ * (y / r^3) + ((3/2) * J2 * (μ / r^2) * (R / r)^2 * (y / r)) * (5 * (z / r)^2 - 1) + 
                  (μ / (c^2 * r^3)) * ((4 * μ / r) * y - (vx^2 + vy^2 + vz^2) * y + 4 * (x * vx + y * vy + z * vz) * vy) + (1/2)*Atm_density(alt)*vrel_mag*(CD*A/m)*vrel[2] + Shadow_Function(rvect,sun_vectors)*(S/(m*c))*CR*As*u_vectors[2] +
                  μmoon*(moon_rel[2]/moon_rel_mag^3 - moon_vector[2]/moon_mag^3)


      vzdot = -μ * (z / r^3) + ((3/2) * J2 * (μ / r^2) * (R / r)^2 * (z / r)) * (5 * (z / r)^2 - 3) + 
                  (μ / (c^2 * r^3)) * ((4 * μ / r) * z - (vx^2 + vy^2 + vz^2) * z + 4 * (x * vx + y * vy + z * vz) * vz) + (1/2)*Atm_density(alt)*vrel_mag*(CD*A/m)*vrel[3] + Shadow_Function(rvect,sun_vectors)*(S/(m*c))*CR*As*u_vectors[3] +
                  μmoon*(moon_rel[3]/moon_rel_mag^3 - moon_vector[3]/moon_mag^3)


    return [ xdot, ydot, zdot, vxdot, vydot, vzdot]
end

# Defining thr time interval for the numerical solution 

tspan = (0.0, 5000.0)

ts = range(0, step=100.0, stop=5000.0)

# initial conditions 

u0 = [5873.40, -658.522, 3007.49, -2.8964, 4.94010, 6.14446]

# Model parameters/constants 

p1 = [398600, 6378, 1082.63e-6, 299800, 2.2, 3.1416e-6, 100, 1367, 1, 3.1416,4903]

# Building the ODE broblem with the given inital & parameter values 

problem = ODEProblem(Propagator_org_1,u0,tspan,p1)

# solving the created ODE problem numerically 

soln = solve(problem, Vern7(),abstol=1e-12, reltol=1e-12, saveat=ts)

t1 = soln.t

X = Array(soln)

# Plotting the results of the solutionin graphical format 

plot(soln, xlabel="Time (s)", ylabel="Position (km)", label=["x" "y" "z"])

# Creating the Neural Network for UDE


const U = Lux.Chain(
   Lux.Dense(4 => 32, Lux.tanh, use_bias = true),
   Lux.Dense(32 => 32, Lux.tanh, use_bias = true),
   Lux.Dense(32 => 32, Lux.tanh, use_bias = true),
   Lux.Dense(32 => 3, use_bias = true)
)
 
# Get the initial parameters and state variables of the model 

p, st = Lux.setup(rng,U)

const _st = st

# Defining the universal differential Equation 

function ude_dynamics!(du,u,p,t,p_true)

   x, y, z, vx, vy, vz = u

   # xdot, ydot, zdot, vxdot, vydot, vzdot = du

   μ, R, J2, c = p_true

   r = sqrt(x^2 + y^2 + z^2)

   alt = r - R

   ucap = U([x,y,z,Atm_density(alt)],p,_st)[1]


   du[1] = vx
   du[2] = vy
   du[3] = vz



   du[4] = -μ * (x / r^3) + ucap[1]

   du[5] = -μ * (y / r^3) + ucap[2]

   du[6] = -μ * (z / r^3) + ucap[3]
 

 #=
 
 du[4] = -μ *Atm_density(alt)*((x + ucap[1]) / r^3) 

 du[5] = -μ *Atm_density(alt)*((y + ucap[2]) / r^3)

 du[6] = -μ *Atm_density(alt)*((z + ucap[3]) / r^3)
 
 =#

end 

# Matching with the known parameters 

nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p1)

# Defining the neural ODE Problem 

prob_nn = ODEProblem(nn_dynamics!,X[:,1],tspan,p)

# Defining the predict function 

function predict(θ, X = X[:,1], T = t1)

   _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
   Array(solve(_prob, Vern7(), saveat = T,
       abstol = 1e-6, reltol = 1e-6,
       sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
end

# Defining the loss function

function loss(θ)
   X̂ = predict(θ)
   mean(abs2, X .- X̂)
end

# Defining the callback functions to call in the training loop

losses = Float64[]

callback = function (state, l)
    push!(losses, l)
    if length(losses) % 50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

# Optimizing the neural network to larn the model 

adtype = Optimization.AutoZygote()

optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))

# Training loop

res1 = Optimization.solve(
    optprob, OptimizationOptimisers.Adam(1e-4), callback = callback, maxiters = 20000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(
    optprob2, LBFGS(linesearch = BackTracking()), callback = callback, maxiters = 1000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Rename the best candidate
p_trained = res2.u

# Plotting the results of ude in graph

tspan1 = (0.0, 15000.0)

ts1 = range(0, step=100.0, stop=15000.0)

# initial conditions 

u0 = [5873.40, -658.522, 3007.49, -2.8964, 4.94010, 6.14446]

# Model parameters/constants 

p1 = [398600, 6378, 1082.63e-6, 299800, 2.2, 3.1416e-6, 100]

# Building the ODE broblem with the given inital & parameter values 

problem1 = ODEProblem(Propagator_org_1,u0,tspan1,p1)

# solving the created ODE problem numerically 

soln1 = solve(problem1, Vern7(), abstol=1e-12, reltol=1e-12, saveat=ts1)

t11 = soln1.t

Xcap = predict(p_trained,X[:,1],t11)

pl_trajectory = plot(t11,transpose(Xcap),color=:red,label = ["UDE Approximation" nothing])

X1 = Array(soln1)

scatter!(soln1.t, transpose(X1), color = :black, label = ["Measurements" nothing])