Below is the code I use to train a neural network to solve a system of differential equations, but I run into a problem.
Here are the libraries I use:
```julia
using NeuralPDE
using DifferentialEquations
using Plots
using Lux, Random
```
The function `system_of_de!` defines a system of four coupled, complex-valued differential equations, known as the Bloch equations. They describe the dynamics of a physical system with four variables `u[1]`, `u[2]`, `u[3]`, and `u[4]`, where `u[4]` evolves as the complex conjugate of `u[3]`.
The parameters of the system are Ω, Δ, and Γ. They are defined outside the function and passed in through the parameter vector `p`; inside the function, γ = Γ/2.
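Written out as equations (with \(\dot u_k\) denoting the time derivative of `u[k]`), the right-hand side of `system_of_de!` is:

```latex
\begin{aligned}
\dot u_1 &= i\Omega\,(u_3 - u_4) + \Gamma u_2, \\
\dot u_2 &= -i\Omega\,(u_3 - u_4) - \Gamma u_2, \\
\dot u_3 &= -(\gamma + i\Delta)\,u_3 - i\Omega\,(u_2 - u_1), \\
\dot u_4 &= \dot u_3^{*}, \qquad \gamma = \Gamma/2.
\end{aligned}
```

Adding the first two lines gives \(\dot u_1 + \dot u_2 = 0\), so the sum \(u_1 + u_2\) is conserved by the exact dynamics, which is a convenient property to check in any approximate solution.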
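```julia
function system_of_de!(du, u, p, t)
    Ω, Δ, Γ = p      # parameters passed in from outside
    γ = Γ / 2        # damping rate of the coherence u[3]
    du[1] = im * Ω * (u[3] - u[4]) + Γ * u[2]
    du[2] = -im * Ω * (u[3] - u[4]) - Γ * u[2]
    du[3] = -(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1])
    du[4] = conj(du[3])  # u[4] evolves as the conjugate of u[3]
    return nothing
end
```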
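A quick way to check the right-hand side before training anything is to evaluate it at a single state. The sketch below copies `system_of_de!` so that it runs on its own; the test state (real `u[1]`, `u[2]` and `u[4] = conj(u[3])`) and the value Δ = 0.5 are arbitrary choices for illustration. It checks that `du[1] + du[2]` vanishes and that the derivatives of `u[1]` and `u[2]` are real for such a state.

```julia
# Copy of the Bloch right-hand side, so this check is self-contained.
function system_of_de!(du, u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2
    du[1] = im * Ω * (u[3] - u[4]) + Γ * u[2]
    du[2] = -im * Ω * (u[3] - u[4]) - Γ * u[2]
    du[3] = -(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1])
    du[4] = conj(du[3])
    return nothing
end

# Arbitrary test state: real u[1], u[2] and u[4] = conj(u[3]).
c = 0.2 + 0.3im
u = ComplexF64[0.7, 0.3, c, conj(c)]
du = similar(u)
system_of_de!(du, u, [100.0, 0.5, 1.0], 0.0)   # Ω = 100, Δ = 0.5 (arbitrary), Γ = 1

println("d/dt (u1 + u2) = ", du[1] + du[2])      # should vanish
println("imag parts of du1, du2: ", imag(du[1]), ", ", imag(du[2]))
```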
Initial Conditions, Time Span, and Parameters:
```julia
u0 = zeros(ComplexF64, 4)
u0[1] = 1
time_span = (0.0, 7.0)
# Ω, Δ, Γ
parameters = [100.0, 0.0, 1.0]
```
Defining the ODE Problem:
```julia
problem = ODEProblem(system_of_de!, u0, time_span, parameters)
```
```
ODEProblem with uType Vector{ComplexF64} and tType Float64. In-place: true
timespan: (0.0, 7.0)
u0: 4-element Vector{ComplexF64}:
```
This part sets up and uses the Neural Network Ordinary Differential Equation (NNODE) solver:
- `rng` is a random number generator used for initialization.
- `chain` defines the architecture of the neural network used by NNODE.
- `ps` and `st` are the initial parameters and state of the neural network.
- `opt` is the optimization algorithm used to train the neural network.
- `alg` is the NNODE solver object that combines the neural network and the optimizer.
```julia
rng = Random.default_rng()
Random.seed!(rng, 0)
# One input (time t) and one output per state variable: u has 4 components
chain = Chain(Dense(1, 5, σ), Dense(5, 4))
ps, st = Lux.setup(rng, chain) |> Lux.f64
```
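To see what the network has to represent, it helps to look at the trial solution. NNODE is assumed here to build a trial solution of the form u(t) = u0 + (t − t0)·NN(t), which satisfies the initial condition automatically. The sketch replaces the Lux chain with a hand-written 1 → 5 → 4 sigmoid network (`TinyNet`, `trial`, and `sigmoid` are hypothetical helpers, not NeuralPDE or Lux API) so that it runs with only the standard library. It illustrates two points: the network returns one value per state variable, so each component of `u0` gets its own correction, and, under this assumed form, a real-valued network never changes the imaginary parts of a complex `u0`.

```julia
using Random

# Hypothetical plain-Julia stand-in for a 1 -> 5 -> 4 network with sigmoid hidden units.
sigmoid(x) = 1 / (1 + exp(-x))

struct TinyNet
    W1::Matrix{Float64}
    b1::Vector{Float64}
    W2::Matrix{Float64}
    b2::Vector{Float64}
end

# Input: time t. Output: one real value per state variable.
(net::TinyNet)(t::Real) = net.W2 * sigmoid.(net.W1 * [t] .+ net.b1) .+ net.b2

# Assumed trial form: u(t0) == u0 holds by construction.
trial(net, u0, t0, t) = u0 .+ (t - t0) .* net(t)

rng = MersenneTwister(0)
net = TinyNet(randn(rng, 5, 1), randn(rng, 5), randn(rng, 4, 5), randn(rng, 4))

u0 = zeros(ComplexF64, 4)
u0[1] = 1

u_start = trial(net, u0, 0.0, 0.0)   # equals u0 exactly
u_later = trial(net, u0, 0.0, 1.5)   # real corrections added to a complex u0
println(u_later)                     # imaginary parts remain zero
```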
Solving the ODE:
- `sol` is the solution obtained by solving `problem` with the NNODE solver.
- `maxiters` specifies the maximum number of training iterations.
- `saveat` sets the time interval at which the solution is saved (here every 0.01 time units).
```julia
using OptimizationOptimisers
opt = Adam(0.1)
alg = NNODE(chain, opt, init_params = ps)
sol = solve(problem, alg, verbose = true, maxiters = 2000, saveat = 0.01)
```
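Conceptually, training here means adjusting the network parameters until the trial solution satisfies the ODE at sample times, i.e. until the residual du/dt − f(u, p, t) is close to zero. The sketch below computes such a residual loss for any function `u(t)`. It is not NeuralPDE's implementation: it uses a central finite difference for du/dt and a fixed time grid, both assumptions made for illustration, and `residual_loss` and `decay!` are hypothetical names. The toy check uses du/dt = −u, for which exp(−t) gives a near-zero loss and the wrong guess 1 − t does not.

```julia
# Sketch of a physics-informed residual loss: mean squared mismatch between
# du/dt (central finite difference, an assumption) and f!(du, u, p, t).
function residual_loss(u, f!, p, ts; h = 1e-5)
    total = 0.0
    for t in ts
        dudt = (u(t + h) .- u(t - h)) ./ (2h)   # numerical time derivative
        rhs = similar(dudt)
        f!(rhs, u(t), p, t)                      # model right-hand side at t
        total += sum(abs2, dudt .- rhs)          # abs2 also handles complex values
    end
    return total / length(ts)
end

# Toy ODE du/dt = -p*u, whose exact solution for u(0) = 1 is exp(-p*t).
decay!(du, u, p, t) = (du .= -p .* u; nothing)
ts = range(0.0, 2.0; length = 21)

loss_exact = residual_loss(t -> [exp(-t)], decay!, 1.0, ts)   # near zero
loss_wrong = residual_loss(t -> [1.0 - t], decay!, 1.0, ts)   # clearly positive
println((loss_exact, loss_wrong))
```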
Checking the Result:
- `ground_truth` is the solution obtained with a traditional numerical solver (`Tsit5()`).
- The code plots the real parts of the first two variables (`u[1]` and `u[2]`) from both solutions (blue: `Tsit5`, red: NNODE) to compare them visually.
```julia
ground_truth = solve(problem, Tsit5(), saveat = 0.01)
plot(ground_truth.t, real.(ground_truth[1, :]), linecolor=:blue, legend = false)
plot!(ground_truth.t, real.(ground_truth[2, :]), linecolor=:blue)
plot!(sol.t, real.(sol[1, :]), linecolor=:red)
plot!(sol.t, real.(sol[2, :]), linecolor=:red)
```
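Besides the visual comparison, a single number can summarize how far the NNODE result is from `ground_truth`. The hypothetical helper `max_deviation` below returns the largest absolute deviation between two trajectories sampled on the same grid. Because both solves above use `saveat = 0.01`, it could be applied as `max_deviation(ground_truth.t, ground_truth[1, :], sol.t, sol[1, :])`, assuming `sol` supports the same indexing as in the plotting code. The self-contained demo uses synthetic data instead.

```julia
# Largest absolute deviation between two trajectories on the same time grid.
# Works for real or complex values, since abs is applied elementwise.
function max_deviation(t_a, y_a, t_b, y_b)
    length(t_a) == length(t_b) && t_a ≈ t_b ||
        throw(ArgumentError("trajectories must share the same time grid"))
    return maximum(abs.(y_a .- y_b))
end

# Synthetic example: two curves that differ by 0.01*sin(t) on t in [0, 1].
t = 0.0:0.01:1.0
y_ref = exp.(-t)
y_approx = exp.(-t) .+ 0.01 .* sin.(t)
dev = max_deviation(t, y_ref, t, y_approx)
println(dev)
```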