Greetings, i am currently working on a problem to predict the state vectors of a satellite using a dataset which is generated by a propagator. I have made a Universal Differential Equation(UDE) to predict the state vectors ( Positions and velocities). The model has worked fine until i only kept two body forces, relativistic effects, J2 and atmospheric drag in the propagator. But after i added more forces like three body forces, Solar radiation pressure etc . I see the model has started to perform very poorly. The losses are in the range of 10^12 and not decreasing with iterations. Can someone here help to troubleshoot this problem. Is it because there are new parameters which are not accounted for in the model or the current loss function is not appropriate. Can someone suggest a better UDE and loss function if this is the case? Thanks a lot in advance for the immense help and support . I am attaching the code for the refernce below.
# SciML Tool Libraries
using OrdinaryDiffEq, ModelingToolkit, DataDrivenDiffEq, SciMLSensitivity, DataDrivenSparse
using Optimization, OptimizationOptimisers, OptimizationOptimJL
# Standard Julia Libraries
using LinearAlgebra, Statistics
# External Julia Libraries
using ComponentArrays, Lux, Zygote, Plots, StableRNGs, LineSearches, DataFrames
# Creating a Set of Random seed for a reproducible behaviour
rng = StableRNG(1111)
# Function defining the density at altitudes and extrapolating the values of desity at height h
function Atm_density(z)
# Geometric altitudes in km
h = [ 0 25 30 40 50 60 70 80 90 100 110 120 130 140 150 180 200 250 300 350 400 450 500 600 700 800 900 1000];
# Corresponding densities (kg/m^3) from the USSA76:
r1 = [1.225 4.008e-2 1.841e-2 3.996e-3 1.027e-3 3.097e-4 8.283e-5 1.846e-5 3.416e-6 5.606e-7 9.708e-8 2.222e-8 8.152e-9 3.831e-9 2.076e-9 5.194e-10 2.541e-10 6.073e-11 1.916e-11 7.014e-12 2.803e-12 1.184e-12 5.215e-13 1.137e-13 3.070e-14 1.136e-14 5.759e-15 3.561e-15];
# Scale heights in km
H =[ 7.310 6.427 6.546 7.360 8.342 7.583 6.661 5.927 5.533 5.703 6.782 9.973 13.243 16.322 21.652 27.974 34.934 43.342 49.755 54.513 58.019 60.980 65.654 76.377 100.587 147.203 208.020];
# Handle the altitudes outseide of the range
if z>1000
z = 1000;
elseif z <0
z = 0;
end
# Determine the interpolation interval
for j in 1:27
if z >= h[j] && z < h[j+1]
i = j ;
end
end
if z == 1000
i = 27;
end
# Exponential interpolation :
density = r1[i]*exp(-(z - h[i])/H[i])
return density
end
# Function to calculate the shadow function for the SLP
function JD_Calculator_Function(t_now)
# Calculating the Julian Date
d = 12;
m = 5;
y = 2004;
hh = 14;
mm = 45;
ss = 30;
J0 = 367*y - trunc(7*(y+trunc((m+9)/12))/4) + trunc(275*m/9)+ d + 1721013.5;
UT1 = hh + mm/60 + ss/3600;
JD1 = J0 + UT1/24;
JD = JD1 + t_now/86400
return JD
end
Juldate = JD_Calculator_Function(0)
# Creating a function to calualte the Sunvectors
function Sun_Vectors(Julian_Days)
AU = 149597870.691
n = Julian_Days - 2451545
M = 357.529 + 0.98560023*n
M = mod(M,360)
L= 280.459 + 0.98564736*n
L = mod(L,360)
lamda = L + 1.915*sin(deg2rad(M)) + 0.0200*sin(deg2rad(2*M))
lamda = mod(lamda,360)
epsilon = 23.439 - 3.56*10^-7*n
u = [cos(deg2rad(lamda)), sin(deg2rad(lamda))*cos(deg2rad(epsilon)),sin(deg2rad(lamda))*sin(deg2rad(epsilon))]
r_sun_mag = (1.00014 - 0.01671*cos(deg2rad(M)) - 0.000140*cos(deg2rad(2*M)))*AU
r_sun = r_sun_mag*u
return r_sun,u
end
# Function to calculate the shadow function of the Solar Raditiob pressure
function Shadow_Function(r,sun_r)
RE = 6378
rsat = norm(r)
rsun = norm(sun_r)
theta = rad2deg(acos(dot(r,sun_r)/(rsat*rsun)))
theta_sat = rad2deg(acos(RE/rsat))
theta_sun = rad2deg(acos(RE/rsun))
if theta_sat + theta_sun <= theta
light_switch = 0;
else
light_switch = 1;
end
return light_switch
end
# Function to calculate the lunar position relative to earth
function Lunar_Position_Calculator(Julian_Date)
RE = 6378
T = (Julian_Date - 2451545.0)/36525
#Ecliptic longitude (deg):
e_long = 218.32 + 481267.881*T + 6.29*sind(135.0 + 477198.87*T)- 1.27*sind(259.3- 413335.36*T)
+ 0.66*sind(235.7 + 890534.22*T) + 0.21*sind(269.9 + 954397.74*T)- 0.19*sind(357.5 +35999.05*T)- 0.11*sind(186.5 + 966404.03*T);
e_long = mod(e_long,360);
#Ecliptic latitude (deg):
e_lat = 5.13*sind( 93.3 + 483202.02*T) + 0.28*sind(228.2 + 960400.89*T) - 0.28*sind(318.3 +6003.15*T)- 0.17*sind(217.6- 407332.21*T);
e_lat = mod(e_lat,360);
#Horizontal parallax (deg):
h_par = 0.9508 + 0.0518*cosd(135.0 + 477198.87*T) + 0.0095*cosd(259.3- 413335.36*T) + 0.0078*cosd(235.7 + 890534.22*T) + 0.0028*cosd(269.9 + 954397.74*T);
h_par = mod(h_par,360);
#Angle between earth’s orbit and its equator (deg):
obliquity = 23.439291- 0.0130042*T;
# Direction cosines of the moon’s geocentric equatorial position vector:
l = cosd(e_lat) * cosd(e_long);
m = cosd(obliquity)*cosd(e_lat)*sind(e_long)- sind(obliquity)*sind(e_lat);
n = sind(obliquity)*cosd(e_lat)*sind(e_long) + cosd(obliquity)*sind(e_lat);
#Earth-moon distance (km):
dist = RE/sind(h_par);
#Moon’s geocentric equatorial position vector (km):
r_moon = dist*[l,m,n];
return r_moon
end
moon_pos = Lunar_Position_Calculator(Juldate)
# ODE model for Satellite Propagation around Earth with J2 and Relativistic Perturbations
function Propagator_org_1(u, model_params, t)
#= Defining the variables of the system to store as symbols in array u
u[1] = x
u[2] = y
u[3] = z
u[4] = vx
u[5] = vy
u[6] = vz
here μ, R , J2 , c are the constants
=#
x, y, z, vx, vy, vz = u
μ, R, J2, c, CD, A, m, S, CR, As,μmoon = model_params
r = sqrt(x^2 + y^2 + z^2)
Ω = [ 0, 0, 7.2921159e-5]
rvect = [x, y, z]
v = [vx, vy, vz]
vrel = v - cross(Ω,rvect)
vrel_mag = norm(vrel)
alt = r - R
# Get the Julian Day for the corresponding time
Jul_D = JD_Calculator_Function(t)
# Getting Sun vector and unit vector in sun direction
sun_vectors = Sun_Vectors(Jul_D)[1]
u_vectors = Sun_Vectors(Jul_D)[2]
moon_vector = Lunar_Position_Calculator(Jul_D)
moon_mag = norm(moon_vector)
moon_rel = moon_vector - rvect
moon_rel_mag = norm(moon_vector-rvect)
# Function F
# Define the equations of motion
xdot = vx
ydot = vy
zdot = vz
vxdot = -μ * (x / r^3) + ((3/2) * J2 * (μ / r^2) * (R / r)^2 * (x / r)) * (5 * (z / r)^2 - 1) +
(μ / (c^2 * r^3)) * ((4 * μ / r) * x - (vx^2 + vy^2 + vz^2) * x + 4 * (x * vx + y * vy + z * vz) * vx) + (1/2)*Atm_density(alt)*vrel_mag*(CD*A/m)*vrel[1] + Shadow_Function(rvect,sun_vectors)*(S/(m*c))*CR*As*u_vectors[1] +
μmoon*(moon_rel[1]/moon_rel_mag^3 - moon_vector[1]/moon_mag^3)
vydot = -μ * (y / r^3) + ((3/2) * J2 * (μ / r^2) * (R / r)^2 * (y / r)) * (5 * (z / r)^2 - 1) +
(μ / (c^2 * r^3)) * ((4 * μ / r) * y - (vx^2 + vy^2 + vz^2) * y + 4 * (x * vx + y * vy + z * vz) * vy) + (1/2)*Atm_density(alt)*vrel_mag*(CD*A/m)*vrel[2] + Shadow_Function(rvect,sun_vectors)*(S/(m*c))*CR*As*u_vectors[2] +
μmoon*(moon_rel[2]/moon_rel_mag^3 - moon_vector[2]/moon_mag^3)
vzdot = -μ * (z / r^3) + ((3/2) * J2 * (μ / r^2) * (R / r)^2 * (z / r)) * (5 * (z / r)^2 - 3) +
(μ / (c^2 * r^3)) * ((4 * μ / r) * z - (vx^2 + vy^2 + vz^2) * z + 4 * (x * vx + y * vy + z * vz) * vz) + (1/2)*Atm_density(alt)*vrel_mag*(CD*A/m)*vrel[3] + Shadow_Function(rvect,sun_vectors)*(S/(m*c))*CR*As*u_vectors[3] +
μmoon*(moon_rel[3]/moon_rel_mag^3 - moon_vector[3]/moon_mag^3)
return [ xdot, ydot, zdot, vxdot, vydot, vzdot]
end
# Defining thr time interval for the numerical solution
tspan = (0.0, 5000.0)
ts = range(0, step=100.0, stop=5000.0)
# initial conditions
u0 = [5873.40, -658.522, 3007.49, -2.8964, 4.94010, 6.14446]
# Model parameters/constants
p1 = [398600, 6378, 1082.63e-6, 299800, 2.2, 3.1416e-6, 100, 1367, 1, 3.1416,4903]
# Building the ODE broblem with the given inital & parameter values
problem = ODEProblem(Propagator_org_1,u0,tspan,p1)
# solving the created ODE problem numerically
soln = solve(problem, Vern7(),abstol=1e-12, reltol=1e-12, saveat=ts)
t1 = soln.t
X = Array(soln)
# Plotting the results of the solutionin graphical format
plot(soln, xlabel="Time (s)", ylabel="Position (km)", label=["x" "y" "z"])
# Creating the Neural Network for UDE
const U = Lux.Chain(
Lux.Dense(4 => 32, Lux.tanh, use_bias = true),
Lux.Dense(32 => 32, Lux.tanh, use_bias = true),
Lux.Dense(32 => 32, Lux.tanh, use_bias = true),
Lux.Dense(32 => 3, use_bias = true)
)
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng,U)
const _st = st
# Defining the universal differential Equation
function ude_dynamics!(du,u,p,t,p_true)
x, y, z, vx, vy, vz = u
# xdot, ydot, zdot, vxdot, vydot, vzdot = du
μ, R, J2, c = p_true
r = sqrt(x^2 + y^2 + z^2)
alt = r - R
ucap = U([x,y,z,Atm_density(alt)],p,_st)[1]
du[1] = vx
du[2] = vy
du[3] = vz
du[4] = -μ * (x / r^3) + ucap[1]
du[5] = -μ * (y / r^3) + ucap[2]
du[6] = -μ * (z / r^3) + ucap[3]
#=
du[4] = -μ *Atm_density(alt)*((x + ucap[1]) / r^3)
du[5] = -μ *Atm_density(alt)*((y + ucap[2]) / r^3)
du[6] = -μ *Atm_density(alt)*((z + ucap[3]) / r^3)
=#
end
# Matching with the known parameters
nn_dynamics!(du, u, p, t) = ude_dynamics!(du, u, p, t, p1)
# Defining the neural ODE Problem
prob_nn = ODEProblem(nn_dynamics!,X[:,1],tspan,p)
# Defining the predict function
function predict(θ, X = X[:,1], T = t1)
_prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
Array(solve(_prob, Vern7(), saveat = T,
abstol = 1e-6, reltol = 1e-6,
sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true))))
end
# Defining the loss function
function loss(θ)
X̂ = predict(θ)
mean(abs2, X .- X̂)
end
# Defining the callback functions to call in the training loop
losses = Float64[]
callback = function (state, l)
push!(losses, l)
if length(losses) % 50 == 0
println("Current loss after $(length(losses)) iterations: $(losses[end])")
end
return false
end
# Optimizing the neural network to larn the model
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))
# Training loop
res1 = Optimization.solve(
optprob, OptimizationOptimisers.Adam(1e-4), callback = callback, maxiters = 20000)
println("Training loss after $(length(losses)) iterations: $(losses[end])")
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(
optprob2, LBFGS(linesearch = BackTracking()), callback = callback, maxiters = 1000)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")
# Rename the best candidate
p_trained = res2.u
# Plotting the results of ude in graph
tspan1 = (0.0, 15000.0)
ts1 = range(0, step=100.0, stop=15000.0)
# initial conditions
u0 = [5873.40, -658.522, 3007.49, -2.8964, 4.94010, 6.14446]
# Model parameters/constants
p1 = [398600, 6378, 1082.63e-6, 299800, 2.2, 3.1416e-6, 100]
# Building the ODE broblem with the given inital & parameter values
problem1 = ODEProblem(Propagator_org_1,u0,tspan1,p1)
# solving the created ODE problem numerically
soln1 = solve(problem1, Vern7(), abstol=1e-12, reltol=1e-12, saveat=ts1)
t11 = soln1.t
Xcap = predict(p_trained,X[:,1],t11)
pl_trajectory = plot(t11,transpose(Xcap),color=:red,label = ["UDE Approximation" nothing])
X1 = Array(soln1)
scatter!(soln1.t, transpose(X1), color = :black, label = ["Measurements" nothing])