Here I describe the code I use to train a neural network to solve a system of ODEs, but I run into a problem.
Here are the libraries I use:
```julia
using NeuralPDE
using DifferentialEquations
using Plots
using Lux, Random
```
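The function `system_of_de!` defines a system of four coupled complex-valued differential equations, known as the Bloch equations. They describe the dynamics of a physical system with four variables `u[1]`, `u[2]`, `u[3]`, and `u[4]`: `u[1]` and `u[2]` are the populations of the two levels, and `u[3]` and `u[4] = conj(u[3])` are the coherences between them.
The parameters of the system are Ω (Rabi frequency), Δ (detuning), and Γ (decay rate). They are defined outside the function and passed in through the parameter vector `p`; inside the function, γ = Γ/2 is the damping rate of the coherences.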
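For reference, here are the same equations in mathematical notation (with γ = Γ/2). Because u₄ = u₃*, the difference u₃ − u₄ equals 2i Im u₃, so the population equations are real. Adding the first two equations shows that u₁ + u₂ is conserved, which is a handy sanity check for any approximate solution:

$$
\begin{aligned}
\dot u_1 &= i\Omega\,(u_3 - u_4) + \Gamma u_2 = -2\Omega\,\operatorname{Im} u_3 + \Gamma u_2,\\
\dot u_2 &= -i\Omega\,(u_3 - u_4) - \Gamma u_2 = 2\Omega\,\operatorname{Im} u_3 - \Gamma u_2,\\
\dot u_3 &= -(\gamma + i\Delta)\,u_3 - i\Omega\,(u_2 - u_1), \qquad \gamma = \Gamma/2,\\
\dot u_4 &= \dot u_3^{*},\\
\dot u_1 + \dot u_2 &= 0 \quad\Longrightarrow\quad u_1(t) + u_2(t) = u_1(0) + u_2(0).
\end{aligned}
$$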
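```julia
function system_of_de!(du, u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2                                     # damping rate of the coherences
    # u[1], u[2]: populations; u[3]: coherence; u[4] = conj(u[3])
    du[1] = im * Ω * (u[3] - u[4]) + Γ * u[2]     # decay Γ*u[2] feeds u[1]
    du[2] = -im * Ω * (u[3] - u[4]) - Γ * u[2]
    du[3] = -(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1])
    du[4] = conj(du[3])                           # keeps u[4] == conj(u[3])
    return nothing
end
```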
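I have not verified whether NNODE handles complex-valued states (the network itself, with Float64 weights, outputs real numbers). One possible workaround is to rewrite the system in real variables. Since `u[1]` and `u[2]` are real and `u[4] = conj(u[3])`, four real unknowns suffice: v = [u₁, u₂, Re u₃, Im u₃]. The function name `bloch_real!` is a hypothetical choice of mine; the right-hand side follows from splitting `du[3]` into its real and imaginary parts:

```julia
# Hypothetical real-valued version of system_of_de!.
# State v = [u1, u2, x, y] with u3 = x + im*y and u4 = x - im*y.
function bloch_real!(dv, v, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2
    u1, u2, x, y = v
    dv[1] = -2Ω * y + Γ * u2                  # = im*Ω*(u3 - u4) + Γ*u2
    dv[2] = 2Ω * y - Γ * u2
    dv[3] = -γ * x + Δ * y                    # real part of du[3]
    dv[4] = -γ * y - Δ * x + Ω * (u1 - u2)    # imaginary part of du[3]
    return nothing
end

# Quick check at the initial state and parameters used in this post
dv = zeros(4)
bloch_real!(dv, [1.0, 0.0, 0.0, 0.0], [100.0, 0.0, 1.0], 0.0)
println(dv)   # [0.0, 0.0, 0.0, 100.0]
```

With this form, `u0` becomes a real 4-element vector and the problem no longer involves complex numbers.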
Initial Conditions, Time Span, and Parameters:
```julia
u0 = zeros(ComplexF64, 4)
u0[1] = 1                   # all population starts in u[1]
time_span = (0.0, 7.0)
# Ω, Δ, Γ
parameters = [100.0, 0.0, 1.0]
```
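If one switches to a real-valued formulation with state v = [u₁, u₂, Re u₃, Im u₃] (an assumption that relies on `u[4] = conj(u[3])`), two small helpers, hypothetically named `to_real` and `to_complex`, can convert the initial condition and map a real solution back for comparison with the complex one:

```julia
# Hypothetical helpers for a real-valued state v = [u1, u2, Re(u3), Im(u3)].
to_real(u) = [real(u[1]), real(u[2]), real(u[3]), imag(u[3])]
to_complex(v) = ComplexF64[v[1], v[2], v[3] + im * v[4], v[3] - im * v[4]]

u0 = zeros(ComplexF64, 4)
u0[1] = 1
v0 = to_real(u0)
println(v0)                      # [1.0, 0.0, 0.0, 0.0]
println(to_complex(v0) == u0)    # true
```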
Defining the ODE Problem:
```julia
problem = ODEProblem(system_of_de!, u0, time_span, parameters)
```
Output:
```
ODEProblem with uType Vector{ComplexF64} and tType Float64. In-place: true
timespan: (0.0, 7.0)
u0: 4-element Vector{ComplexF64}:
```
This part sets up and uses the Neural Network Ordinary Differential Equation (NNODE) solver:
- `rng` is the random number generator used for initialization.
- `chain` defines the architecture of the neural network used by NNODE.
- `ps` and `st` are the initial parameters and the state of the neural network.
- `opt` is the optimization algorithm used to train the neural network.
- `alg` is the NNODE solver object that combines the neural network and the optimizer.
```julia
rng = Random.default_rng()
Random.seed!(rng, 0)
# input: time t; output: one value per state variable (length(u0) == 4)
chain = Chain(Dense(1, 5, σ), Dense(5, 4))
ps, st = Lux.setup(rng, chain) |> Lux.f64
```
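As I understand the NNODE documentation, it follows the trial-solution approach of Lagaris et al.: the network is not the solution itself, the approximate solution is φ(t) = u0 + (t − t0)·NN(t), so φ(t0) = u0 holds by construction, and training roughly minimizes the squared residual dφ/dt − f(φ, p, t) at sampled times. The stdlib-only sketch below mimics that idea with a hand-written 1 → 5 → 4 network (the architecture of `chain`, with one output per state variable) and the real-valued form of the Bloch equations. All names are hypothetical, dφ/dt is approximated by a central finite difference (my choice, not necessarily what NNODE does), and the loss is only evaluated, not trained:

```julia
using Random

# Hypothetical tiny network: 1 input (t) -> 5 sigmoid units -> 4 outputs
sigmoid(x) = 1 / (1 + exp(-x))
struct TinyNet
    W1::Matrix{Float64}
    b1::Vector{Float64}
    W2::Matrix{Float64}
    b2::Vector{Float64}
end
TinyNet(rng) = TinyNet(randn(rng, 5, 1), zeros(5), randn(rng, 4, 5), zeros(4))
(net::TinyNet)(t) = net.W2 * sigmoid.(net.W1 * [t] .+ net.b1) .+ net.b2

# Real-valued Bloch right-hand side, state v = [u1, u2, Re(u3), Im(u3)]
function f(v, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2
    u1, u2, x, y = v
    return [-2Ω * y + Γ * u2, 2Ω * y - Γ * u2,
            -γ * x + Δ * y, -γ * y - Δ * x + Ω * (u1 - u2)]
end

# Trial solution: equals u0 at t = t0 for any network weights
φ(net, u0, t0, t) = u0 .+ (t - t0) .* net(t)

# Mean squared ODE residual over the collocation times ts
function residual_loss(net, u0, t0, p, ts; h = 1e-5)
    total = 0.0
    for t in ts
        dφ = (φ(net, u0, t0, t + h) .- φ(net, u0, t0, t - h)) ./ (2h)
        total += sum(abs2, dφ .- f(φ(net, u0, t0, t), p, t))
    end
    return total / length(ts)
end

rng = MersenneTwister(0)
net = TinyNet(rng)
u0 = [1.0, 0.0, 0.0, 0.0]
p = [100.0, 0.0, 1.0]
println(φ(net, u0, 0.0, 0.0) == u0)   # true: initial condition is built in
println(residual_loss(net, u0, 0.0, p, range(0.0, 7.0; length = 50)))
```

The sketch also makes explicit why the last layer needs as many outputs as there are state variables: `net(t)` is added component-wise to `u0`.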
Solving the ODE:
- `sol` is the solution obtained by solving `problem` with the NNODE solver.
- `maxiters` specifies the maximum number of training (optimizer) iterations.
- `saveat` sets the time points at which the solution is saved (here every 0.01 time units).
```julia
using OptimizationOptimisers    # Adam optimizer
opt = Adam(0.1)                 # learning rate 0.1
alg = NNODE(chain, opt, init_params = ps)
# verbose = true prints the loss during training
sol = solve(problem, alg, verbose = true, maxiters = 2000, saveat = 0.01)
```
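A rough estimate of how fast the solution oscillates may help when judging these settings. In the undamped, resonant limit (Γ = 0, Δ = 0, an approximation I am assuming here, since the post uses Γ = 1), the equations give d²(u₁ − u₂)/dt² = −4Ω²(u₁ − u₂), so the populations oscillate with angular frequency 2Ω and period π/Ω. Plugging in the values from this post:

```julia
Ω = 100.0
ω = 2Ω                             # angular frequency of the population oscillation (Γ = Δ = 0)
T = 2π / ω                         # period, equal to π/Ω
n_periods = 7.0 / T                # oscillations within time_span = (0.0, 7.0)
samples_per_period = T / 0.01      # saved points per period with saveat = 0.01
println((T = T, n_periods = n_periods, samples_per_period = samples_per_period))
```

Under this approximation T ≈ 0.0314, the interval (0, 7) contains about 223 oscillations, and `saveat = 0.01` samples each of them only about π ≈ 3 times. My untested guess is that a network with five hidden units will struggle to represent that many oscillations.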
Checking the Result:
- `ground_truth` is the solution obtained with a conventional numerical solver (`Tsit5()`).
- The code plots the real parts of both solutions for the first two variables (`u[1]` and `u[2]`) to compare them visually (blue: `Tsit5()`, red: NNODE).
```julia
ground_truth = solve(problem, Tsit5(), saveat = 0.01)
# reference solution (blue)
plot(ground_truth.t, real.(ground_truth[1, :]), linecolor = :blue, legend = false)
plot!(ground_truth.t, real.(ground_truth[2, :]), linecolor = :blue)
# NNODE solution (red)
plot!(sol.t, real.(sol[1, :]), linecolor = :red)
plot!(sol.t, real.(sol[2, :]), linecolor = :red)
```
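Beyond the visual comparison, two numbers could summarize the result: the largest deviation from the reference solution, and the largest violation of the conservation of u₁ + u₂ (the exact equations satisfy du[1] + du[2] = 0, but an NNODE approximation need not). Below is a stdlib-only sketch with a hypothetical function `compare_solutions`, assuming both solutions are available as matrices with one row per variable and one column per saved time (for the objects above, something like `Array(sol)` and `Array(ground_truth)` when both use the same `saveat`). The toy matrices only stand in for real results:

```julia
# Hypothetical helper: U_nn and U_ref are (n_vars × n_times) matrices sampled at the same times.
function compare_solutions(U_nn, U_ref)
    size(U_nn) == size(U_ref) ||
        throw(DimensionMismatch("solutions must be sampled at the same times"))
    max_error = maximum(abs.(U_nn .- U_ref))
    # u1 + u2 should stay equal to its initial value in the reference solution
    c0 = U_ref[1, 1] + U_ref[2, 1]
    max_drift = maximum(abs.(U_nn[1, :] .+ U_nn[2, :] .- c0))
    return (max_error = max_error, max_drift = max_drift)
end

# Toy data standing in for the two solutions (4 variables, 3 saved times)
U_ref = ComplexF64[1.0 0.5 0.2; 0.0 0.5 0.8; 0 0 0; 0 0 0]
U_nn  = ComplexF64[1.0 0.6 0.2; 0.0 0.5 0.7; 0 0 0; 0 0 0]
println(compare_solutions(U_nn, U_ref))
```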