How can I solve complex-valued ordinary differential equations (ODEs) using neural networks, given limitations with complex data types in libraries like Lux?

Below I describe the code where I want to train the neural network but run into a problem.

Here are the libraries I use.

using NeuralPDE
using DifferentialEquations
using Plots
using Lux, Random

The function system_of_de! defines a system of four coupled complex-valued differential equations. The equations describe the dynamics of a physical system with four variables u[1], u[2], u[3], and u[4]; they are known as the Bloch equations.

The parameters of the system are Ω, Δ, and Γ, which are defined outside the function.

function system_of_de!(du, u, p, t)
    Ω, Δ, Γ = p   # unpack the parameters (set below in parameters)
    γ = Γ / 2

    du[1] =  im * Ω * (u[3] - u[4]) + Γ * u[2]
    du[2] = -im * Ω * (u[3] - u[4]) - Γ * u[2]
    du[3] = -(γ + im * Δ) * u[3] - im * Ω * (u[2] - u[1])
    du[4] = conj(du[3])   # u[4] evolves as the complex conjugate of u[3]

    return nothing
end

Initial Conditions, Time Span, and Parameters:

u0 = zeros(ComplexF64, 4)
u0[1] = 1

time_span = (0.0, 7.0)

#             Ω,     Δ,   Γ         
parameters = [100.0, 0.0, 1.0]

Defining the ODE Problem:

problem = ODEProblem(system_of_de!, u0, time_span, parameters)
ODEProblem with uType Vector{ComplexF64} and tType Float64. In-place: true
timespan: (0.0, 7.0)
u0: 4-element Vector{ComplexF64}:
 1.0 + 0.0im
 0.0 + 0.0im
 0.0 + 0.0im
 0.0 + 0.0im

This part involves setting up and using the Neural Network Ordinary Differential Equation (NNODE) solver.

  • rng is a random number generator used for initialization.
  • chain defines the architecture of the neural network used by NNODE.
  • ps and st are the initial parameters of the neural network.
  • opt is an optimization algorithm used to train the neural network.
  • alg is an NNODE solver object that combines the neural network and optimizer.
rng = Random.default_rng()
Random.seed!(rng, 0)
chain = Chain(Dense(1, 5, σ), Dense(5, 4))   # output size must match the 4 state variables
ps, st = Lux.setup(rng, chain) |> Lux.f64

Solving the ODE:

  • sol is the solution obtained by solving the problem using the NNODE solver.
  • maxiters specifies the maximum number of iterations the solver allows.
  • saveat defines how often the solution is saved during the simulation.
using OptimizationOptimisers

opt = Adam(0.1)
alg = NNODE(chain, opt, init_params = ps)

sol = solve(problem, alg, verbose = true, maxiters = 2000, saveat = 0.01)

Checking the Result:

  • ground_truth is the solution obtained using a traditional numerical solver (Tsit5()).
  • The code plots both solutions for the first two variables (u[1] and u[2]) to compare them visually.
ground_truth = solve(problem, Tsit5(), saveat = 0.01)

plot(ground_truth.t, real.(ground_truth[1, :]), linecolor=:blue, legend = false)
plot!(ground_truth.t, real.(ground_truth[2, :]), linecolor=:blue)

plot!(sol.t, real.(sol[1, :]), linecolor=:red)
plot!(sol.t, real.(sol[2, :]), linecolor=:red)

Without running the code: is there a specific question here, or a particular error that occurs?

At the moment it looks like a tutorial on how to do it. Also, one trivial (but maybe unsatisfying) approach would be to convert the ODE to a real-valued one, as sketched below.
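A minimal sketch of that real-valued route (my own reformulation, not code from the original post): write each complex state u[k] = x[k] + i*y[k] and split the right-hand side into real and imaginary parts, turning the 4-dimensional complex system into an 8-dimensional real one that stays in Float64 end to end. The names real_system_of_de! and u0_real are illustrative; time_span and parameters are reused from above.

# Real-valued reformulation (sketch). State layout:
# u = [x1, y1, x2, y2, x3, y3, x4, y4], the real/imaginary parts of the complex u[1..4].
function real_system_of_de!(du, u, p, t)
    Ω, Δ, Γ = p
    γ = Γ / 2

    x1, y1, x2, y2, x3, y3, x4, y4 = u

    # complex du[1] = im*Ω*(u3 - u4) + Γ*u2, split into real and imaginary parts
    du[1] = -Ω * (y3 - y4) + Γ * x2
    du[2] =  Ω * (x3 - x4) + Γ * y2

    # complex du[2] = -im*Ω*(u3 - u4) - Γ*u2
    du[3] =  Ω * (y3 - y4) - Γ * x2
    du[4] = -Ω * (x3 - x4) - Γ * y2

    # complex du[3] = -(γ + im*Δ)*u3 - im*Ω*(u2 - u1)
    du[5] = -γ * x3 + Δ * y3 + Ω * (y2 - y1)
    du[6] = -Δ * x3 - γ * y3 - Ω * (x2 - x1)

    # complex du[4] = conj(du[3]): same real part, negated imaginary part
    du[7] =  du[5]
    du[8] = -du[6]

    return nothing
end

u0_real = zeros(Float64, 8)
u0_real[1] = 1.0   # u[1] = 1 + 0im, everything else zero

real_problem = ODEProblem(real_system_of_de!, u0_real, time_span, parameters)

With this formulation the neural network has to output 8 real values (one per component) instead of 4 complex ones, and Lux.setup, NNODE, and the optimizer can all stay in Float64.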

This is the issue: How can I solve complex-valued ordinary differential equations (ODEs) using neural networks, given limitations with complex data types in libraries like Lux? · Issue #818 · SciML/NeuralPDE.jl · GitHub. If interested, track it there.
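For completeness, a hedged sketch of how the real-valued problem could then be fed to NNODE and the complex solution reassembled afterwards. It assumes real_problem from the sketch above; the layer sizes are arbitrary illustrative choices, not a recommendation from the issue.

using NeuralPDE, OptimizationOptimisers, Lux, Random, Plots

rng = Random.default_rng()
Random.seed!(rng, 0)

# 8 outputs: the real and imaginary parts of the 4 complex states
chain_real = Chain(Dense(1, 16, σ), Dense(16, 8))
ps_real, st_real = Lux.setup(rng, chain_real) |> Lux.f64

alg_real = NNODE(chain_real, Adam(0.1), init_params = ps_real)
sol_real = solve(real_problem, alg_real, verbose = true, maxiters = 2000, saveat = 0.01)

# reassemble the complex states from the interleaved real/imaginary components
u_complex = [sol_real[2k - 1, :] .+ im .* sol_real[2k, :] for k in 1:4]

plot(sol_real.t, real.(u_complex[1]), linecolor = :red, legend = false)
plot!(sol_real.t, real.(u_complex[2]), linecolor = :red)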