First of all, thank you to the people who developed Julia and DifferentialEquations.jl these both are really great tools!
Now, cutting to the chase:
I need to solve a large system of non-linear complex-valued ODEs (N~1000). For now, I am testing my code on a small system, and currently am battling with the memory allocation issue, which I can not improve despite following (hopefully, correctly) performance improvement tutorials. MWE:
module Dif
using DifferentialEquations, JLD2, StaticArrays
const tspan=(0.0,100.0)
const t_name=tspan[2]
function Nodes(u,p,t) #equations to be solved. For large N, they get very long and ugly
A1=((9*((p[3])^2/2)-p[2]/2*abs2(p[1]))*u[1]-p[2]/2*(p[1])^2*conj(u[7]) - (p[2]/2)*(2*p[1]*u[1]*conj(u[4])+conj(p[1])*u[1]*u[4]+2*p[1]*u[2]*conj(u[5])+conj(p[1])*u[2]*u[3]+2*p[1]*u[3]*conj(u[6])+conj(p[1])*u[3]*u[2]+2*p[1]*u[4]*conj(u[7])+conj(p[1])*u[4]*u[1] + u[1]*conj(u[1])*u[1]+u[1]*conj(u[2])*u[2]+u[2]*conj(u[2])*u[1]+u[1]*conj(u[3])*u[3]+u[2]*conj(u[3])*u[2]+u[3]*conj(u[3])*u[1]+u[1]*conj(u[4])*u[4]+u[2]*conj(u[4])*u[3]+u[3]*conj(u[4])*u[2]+u[4]*conj(u[4])*u[1]+u[1]*conj(u[5])*u[5]+u[2]*conj(u[5])*u[4]+u[3]*conj(u[5])*u[3]+u[4]*conj(u[5])*u[2]+u[5]*conj(u[5])*u[1]+u[1]*conj(u[6])*u[6]+u[2]*conj(u[6])*u[5]+u[3]*conj(u[6])*u[4]+u[4]*conj(u[6])*u[3]+u[5]*conj(u[6])*u[2]+u[6]*conj(u[6])*u[1]+u[1]*conj(u[7])*u[7]+u[2]*conj(u[7])*u[6]+u[3]*conj(u[7])*u[5]+u[4]*conj(u[7])*u[4]+u[5]*conj(u[7])*u[3]+u[6]*conj(u[7])*u[2]+u[7]*conj(u[7])*u[1]))/im
A2=((4*((p[3])^2/2)-p[2]/2*abs2(p[1]))*u[2]-p[2]/2*(p[1])^2*conj(u[6]) - (p[2]/2)*(2*p[1]*u[1]*conj(u[3])+conj(p[1])*u[1]*u[5]+2*p[1]*u[2]*conj(u[4])+conj(p[1])*u[2]*u[4]+2*p[1]*u[3]*conj(u[5])+conj(p[1])*u[3]*u[3]+2*p[1]*u[4]*conj(u[6])+conj(p[1])*u[4]*u[2]+2*p[1]*u[5]*conj(u[7])+conj(p[1])*u[5]*u[1] + u[1]*conj(u[1])*u[2]+u[2]*conj(u[1])*u[1]+u[1]*conj(u[2])*u[3]+u[2]*conj(u[2])*u[2]+u[3]*conj(u[2])*u[1]+u[1]*conj(u[3])*u[4]+u[2]*conj(u[3])*u[3]+u[3]*conj(u[3])*u[2]+u[4]*conj(u[3])*u[1]+u[1]*conj(u[4])*u[5]+u[2]*conj(u[4])*u[4]+u[3]*conj(u[4])*u[3]+u[4]*conj(u[4])*u[2]+u[5]*conj(u[4])*u[1]+u[1]*conj(u[5])*u[6]+u[2]*conj(u[5])*u[5]+u[3]*conj(u[5])*u[4]+u[4]*conj(u[5])*u[3]+u[5]*conj(u[5])*u[2]+u[6]*conj(u[5])*u[1]+u[1]*conj(u[6])*u[7]+u[2]*conj(u[6])*u[6]+u[3]*conj(u[6])*u[5]+u[4]*conj(u[6])*u[4]+u[5]*conj(u[6])*u[3]+u[6]*conj(u[6])*u[2]+u[7]*conj(u[6])*u[1]+u[2]*conj(u[7])*u[7]+u[3]*conj(u[7])*u[6]+u[4]*conj(u[7])*u[5]+u[5]*conj(u[7])*u[4]+u[6]*conj(u[7])*u[3]+u[7]*conj(u[7])*u[2]))/im
A3=((1*((p[3])^2/2)-p[2]/2*abs2(p[1]))*u[3]-p[2]/2*(p[1])^2*conj(u[5]) - (p[2]/2)*(2*p[1]*u[1]*conj(u[2])+conj(p[1])*u[1]*u[6]+2*p[1]*u[2]*conj(u[3])+conj(p[1])*u[2]*u[5]+2*p[1]*u[3]*conj(u[4])+conj(p[1])*u[3]*u[4]+2*p[1]*u[4]*conj(u[5])+conj(p[1])*u[4]*u[3]+2*p[1]*u[5]*conj(u[6])+conj(p[1])*u[5]*u[2]+2*p[1]*u[6]*conj(u[7])+conj(p[1])*u[6]*u[1] + u[1]*conj(u[1])*u[3]+u[2]*conj(u[1])*u[2]+u[3]*conj(u[1])*u[1]+u[1]*conj(u[2])*u[4]+u[2]*conj(u[2])*u[3]+u[3]*conj(u[2])*u[2]+u[4]*conj(u[2])*u[1]+u[1]*conj(u[3])*u[5]+u[2]*conj(u[3])*u[4]+u[3]*conj(u[3])*u[3]+u[4]*conj(u[3])*u[2]+u[5]*conj(u[3])*u[1]+u[1]*conj(u[4])*u[6]+u[2]*conj(u[4])*u[5]+u[3]*conj(u[4])*u[4]+u[4]*conj(u[4])*u[3]+u[5]*conj(u[4])*u[2]+u[6]*conj(u[4])*u[1]+u[1]*conj(u[5])*u[7]+u[2]*conj(u[5])*u[6]+u[3]*conj(u[5])*u[5]+u[4]*conj(u[5])*u[4]+u[5]*conj(u[5])*u[3]+u[6]*conj(u[5])*u[2]+u[7]*conj(u[5])*u[1]+u[2]*conj(u[6])*u[7]+u[3]*conj(u[6])*u[6]+u[4]*conj(u[6])*u[5]+u[5]*conj(u[6])*u[4]+u[6]*conj(u[6])*u[3]+u[7]*conj(u[6])*u[2]+u[3]*conj(u[7])*u[7]+u[4]*conj(u[7])*u[6]+u[5]*conj(u[7])*u[5]+u[6]*conj(u[7])*u[4]+u[7]*conj(u[7])*u[3]))/im
A4=((-p[2]/2*abs2(p[1]))*u[4]-p[2]/2*(p[1])^2*conj(u[4]) - (p[2]/2)*(2*p[1]*u[1]*conj(u[1])+conj(p[1])*u[1]*u[7]+2*p[1]*u[2]*conj(u[2])+conj(p[1])*u[2]*u[6]+2*p[1]*u[3]*conj(u[3])+conj(p[1])*u[3]*u[5]+2*p[1]*u[4]*conj(u[4])+conj(p[1])*u[4]*u[4]+2*p[1]*u[5]*conj(u[5])+conj(p[1])*u[5]*u[3]+2*p[1]*u[6]*conj(u[6])+conj(p[1])*u[6]*u[2]+2*p[1]*u[7]*conj(u[7])+conj(p[1])*u[7]*u[1] + u[1]*conj(u[1])*u[4]+u[2]*conj(u[1])*u[3]+u[3]*conj(u[1])*u[2]+u[4]*conj(u[1])*u[1]+u[1]*conj(u[2])*u[5]+u[2]*conj(u[2])*u[4]+u[3]*conj(u[2])*u[3]+u[4]*conj(u[2])*u[2]+u[5]*conj(u[2])*u[1]+u[1]*conj(u[3])*u[6]+u[2]*conj(u[3])*u[5]+u[3]*conj(u[3])*u[4]+u[4]*conj(u[3])*u[3]+u[5]*conj(u[3])*u[2]+u[6]*conj(u[3])*u[1]+u[1]*conj(u[4])*u[7]+u[2]*conj(u[4])*u[6]+u[3]*conj(u[4])*u[5]+u[4]*conj(u[4])*u[4]+u[5]*conj(u[4])*u[3]+u[6]*conj(u[4])*u[2]+u[7]*conj(u[4])*u[1]+u[2]*conj(u[5])*u[7]+u[3]*conj(u[5])*u[6]+u[4]*conj(u[5])*u[5]+u[5]*conj(u[5])*u[4]+u[6]*conj(u[5])*u[3]+u[7]*conj(u[5])*u[2]+u[3]*conj(u[6])*u[7]+u[4]*conj(u[6])*u[6]+u[5]*conj(u[6])*u[5]+u[6]*conj(u[6])*u[4]+u[7]*conj(u[6])*u[3]+u[4]*conj(u[7])*u[7]+u[5]*conj(u[7])*u[6]+u[6]*conj(u[7])*u[5]+u[7]*conj(u[7])*u[4]))/im
A5=((1*((p[3])^2/2)-p[2]/2*abs2(p[1]))*u[5]-p[2]/2*(p[1])^2*conj(u[3]) - (p[2]/2)*(2*p[1]*u[2]*conj(u[1])+conj(p[1])*u[2]*u[7]+2*p[1]*u[3]*conj(u[2])+conj(p[1])*u[3]*u[6]+2*p[1]*u[4]*conj(u[3])+conj(p[1])*u[4]*u[5]+2*p[1]*u[5]*conj(u[4])+conj(p[1])*u[5]*u[4]+2*p[1]*u[6]*conj(u[5])+conj(p[1])*u[6]*u[3]+2*p[1]*u[7]*conj(u[6])+conj(p[1])*u[7]*u[2] + u[1]*conj(u[1])*u[5]+u[2]*conj(u[1])*u[4]+u[3]*conj(u[1])*u[3]+u[4]*conj(u[1])*u[2]+u[5]*conj(u[1])*u[1]+u[1]*conj(u[2])*u[6]+u[2]*conj(u[2])*u[5]+u[3]*conj(u[2])*u[4]+u[4]*conj(u[2])*u[3]+u[5]*conj(u[2])*u[2]+u[6]*conj(u[2])*u[1]+u[1]*conj(u[3])*u[7]+u[2]*conj(u[3])*u[6]+u[3]*conj(u[3])*u[5]+u[4]*conj(u[3])*u[4]+u[5]*conj(u[3])*u[3]+u[6]*conj(u[3])*u[2]+u[7]*conj(u[3])*u[1]+u[2]*conj(u[4])*u[7]+u[3]*conj(u[4])*u[6]+u[4]*conj(u[4])*u[5]+u[5]*conj(u[4])*u[4]+u[6]*conj(u[4])*u[3]+u[7]*conj(u[4])*u[2]+u[3]*conj(u[5])*u[7]+u[4]*conj(u[5])*u[6]+u[5]*conj(u[5])*u[5]+u[6]*conj(u[5])*u[4]+u[7]*conj(u[5])*u[3]+u[4]*conj(u[6])*u[7]+u[5]*conj(u[6])*u[6]+u[6]*conj(u[6])*u[5]+u[7]*conj(u[6])*u[4]+u[5]*conj(u[7])*u[7]+u[6]*conj(u[7])*u[6]+u[7]*conj(u[7])*u[5]))/im
A6=((4*((p[3])^2/2)-p[2]/2*abs2(p[1]))*u[6]-p[2]/2*(p[1])^2*conj(u[2]) - (p[2]/2)*(2*p[1]*u[3]*conj(u[1])+conj(p[1])*u[3]*u[7]+2*p[1]*u[4]*conj(u[2])+conj(p[1])*u[4]*u[6]+2*p[1]*u[5]*conj(u[3])+conj(p[1])*u[5]*u[5]+2*p[1]*u[6]*conj(u[4])+conj(p[1])*u[6]*u[4]+2*p[1]*u[7]*conj(u[5])+conj(p[1])*u[7]*u[3] + u[1]*conj(u[1])*u[6]+u[2]*conj(u[1])*u[5]+u[3]*conj(u[1])*u[4]+u[4]*conj(u[1])*u[3]+u[5]*conj(u[1])*u[2]+u[6]*conj(u[1])*u[1]+u[1]*conj(u[2])*u[7]+u[2]*conj(u[2])*u[6]+u[3]*conj(u[2])*u[5]+u[4]*conj(u[2])*u[4]+u[5]*conj(u[2])*u[3]+u[6]*conj(u[2])*u[2]+u[7]*conj(u[2])*u[1]+u[2]*conj(u[3])*u[7]+u[3]*conj(u[3])*u[6]+u[4]*conj(u[3])*u[5]+u[5]*conj(u[3])*u[4]+u[6]*conj(u[3])*u[3]+u[7]*conj(u[3])*u[2]+u[3]*conj(u[4])*u[7]+u[4]*conj(u[4])*u[6]+u[5]*conj(u[4])*u[5]+u[6]*conj(u[4])*u[4]+u[7]*conj(u[4])*u[3]+u[4]*conj(u[5])*u[7]+u[5]*conj(u[5])*u[6]+u[6]*conj(u[5])*u[5]+u[7]*conj(u[5])*u[4]+u[5]*conj(u[6])*u[7]+u[6]*conj(u[6])*u[6]+u[7]*conj(u[6])*u[5]+u[6]*conj(u[7])*u[7]+u[7]*conj(u[7])*u[6]))/im
A7=((9*((p[3])^2/2)-p[2]/2*abs2(p[1]))*u[7]-p[2]/2*(p[1])^2*conj(u[1]) - (p[2]/2)*(2*p[1]*u[4]*conj(u[1])+conj(p[1])*u[4]*u[7]+2*p[1]*u[5]*conj(u[2])+conj(p[1])*u[5]*u[6]+2*p[1]*u[6]*conj(u[3])+conj(p[1])*u[6]*u[5]+2*p[1]*u[7]*conj(u[4])+conj(p[1])*u[7]*u[4] + u[1]*conj(u[1])*u[7]+u[2]*conj(u[1])*u[6]+u[3]*conj(u[1])*u[5]+u[4]*conj(u[1])*u[4]+u[5]*conj(u[1])*u[3]+u[6]*conj(u[1])*u[2]+u[7]*conj(u[1])*u[1]+u[2]*conj(u[2])*u[7]+u[3]*conj(u[2])*u[6]+u[4]*conj(u[2])*u[5]+u[5]*conj(u[2])*u[4]+u[6]*conj(u[2])*u[3]+u[7]*conj(u[2])*u[2]+u[3]*conj(u[3])*u[7]+u[4]*conj(u[3])*u[6]+u[5]*conj(u[3])*u[5]+u[6]*conj(u[3])*u[4]+u[7]*conj(u[3])*u[3]+u[4]*conj(u[4])*u[7]+u[5]*conj(u[4])*u[6]+u[6]*conj(u[4])*u[5]+u[7]*conj(u[4])*u[4]+u[5]*conj(u[5])*u[7]+u[6]*conj(u[5])*u[6]+u[7]*conj(u[5])*u[5]+u[6]*conj(u[6])*u[7]+u[7]*conj(u[6])*u[6]+u[7]*conj(u[7])*u[7]))/im
@SVector [A1, A2, A3, A4, A5, A6, A7]
end
function Diff(N::Int64,psi::Float64,g::Float64)
println(tspan)
Analytic_N=Int64(floor((N-1)/2))
k=sqrt(g*psi) #defining parameters
p = [psi,g,k]
u1=[0.0+0.0*im for i in 1:N] # simple initial conditions
u1[Int64((N-1)/2)]=1.0+0.0*im
u1[Int64((N-1)/2)+1]=0.0+0.0*im
u1[Int64((N-1)/2)+2]=1.0+0.0*im
u0=SVector{N}(u1)
path=pwd()
path1="/path/JLD2 for g=$g, psi=$psi, N=$Analytic_N,t=$t_name"
rm(path1,force=true,recursive=true)
mkdir(path1)
cd(path1)
prob = ODEProblem(Nodes,u0,tspan,p)
sol = solve(prob,RK4(), adaptive=true,dense=false,calck=false)
@save "sol_tf=$(tspan[2]).jld2" sol
cd(path)
return "success"
end
end
Benchmarking this gives:
@benchmark Dif.Diff($7,$1.0,$16.0)
BenchmarkTools.Trial:
memory estimate: 2.53 GiB
allocs estimate: 58532496
--------------
minimum time: 1.857 s (11.35% GC)
median time: 1.867 s (11.93% GC)
mean time: 1.876 s (11.73% GC)
maximum time: 1.903 s (11.90% GC)
--------------
samples: 3
evals/sample: 1
Is there anything else I can do to improve the memory allocation? Perhaps I could use a better-suited solver? Write down equations in matrix representation (I would expect this would slow the program down)?
When I run the above for the system of 13 equations, memory allocation jump to 26.06GB. I have also tried using saveat=0.01
but found that it increases memory allocation for adaptive=true
, not by much, but somewhat. If I run this code for system of 41 equations, I get out of memory error.
My supervisor solves very similar equations using RK4 on C++, and is able to solve system of 1000 equations without any issues in around 30 minutes, which makes me conclude that I am doing something wrong.
Any advice is greatly appreciated!
Thank you.