How to speed up this ODEs solving code?

I am trying to solve a matrix ODE as:

using StaticArrays
using LinearOperators
using BenchmarkTools
using OrdinaryDiffEq
using LinearAlgebra
using Roots;

si=SMatrix{2, 2, Complex{Float64}}([1 0;0 1]);
sx=SMatrix{2, 2, Complex{Float64}}([0 1;1 0]);
sy=SMatrix{2, 2, Complex{Float64}}([0 -1im;1im 0]);
sz=SMatrix{2, 2, Complex{Float64}}([1 0;0 -1]);
sm=SMatrix{2, 2, Complex{Float64}}([0 1;0 0]);
sp=SMatrix{2, 2, Complex{Float64}}([0 0;1 0]);
a=SMatrix{4, 4, Complex{Float64}}([0 1 0 0
                                    0 0 sqrt(2.0) 0
                                    0 0 0 sqrt(3.0);
                                    0 0 0 0]);
ap=SMatrix{4, 4, Complex{Float64}}([0 0 0 0
                                    1 0 0 0
                                    0 sqrt(2.0) 0 0;
                                    0 0 sqrt(3.0) 0]);
ai=SMatrix{4, 4, Complex{Float64}}([1 0 0 0
                                    0 1 0 0
                                    0 0 1 0;
                                    0 0 0 1]);

kai=kron(ap*a,si,si,si,si);
ksτz2=kron(ai,sz,si,si,si);
ksσz2=kron(ai,si,sz,si,si);
ksτz1=kron(ai,si,si,sz,si);
ksσz1=kron(ai,si,si,si,sz);
khiτ2=kron(a+ap,sx,si,si,si);
khiτ1=kron(a+ap,si,si,sx,si);
khiσ2=kron(a+ap,sz,sx,si,si);
khiσ1=kron(a+ap,si,si,sz,sx);
khd2x=kron(ai,sx,sx,si,si);
khd2y=kron(ai,sy,sy,si,si);
khd1x=kron(ai,si,si,sx,sx);
khd1y=kron(ai,si,si,sy,sy);

Eσ(Bz::Float64,tc::Float64,Δ::Float64)=0.5*(Bz-2*tc)*sqrt(1+Δ^2/(Bz-2.0*tc)^2)+0.5*(Bz+2*tc)*sqrt(1+Δ^2/(Bz+2.0*tc)^2)
Eτ(Bz::Float64,tc::Float64,Δ::Float64)=-0.5*(Bz-2*tc)*sqrt(1+Δ^2/(Bz-2.0*tc)^2)+0.5*(Bz+2*tc)*sqrt(1+Δ^2/(Bz+2.0*tc)^2)
ϕp(Bz::Float64,tc::Float64,Δ::Float64)=atan(Δ/(2.0*tc+Bz))
ϕm(Bz::Float64,tc::Float64,Δ::Float64)=atan(Δ/(2.0*tc-Bz))
gσ(Bz::Float64,tc::Float64,Δ::Float64,gc::Float64)=gc*sin(0.5*(atan(Δ/(2.0*tc+Bz))+atan(Δ/(2.0*tc-Bz))))
gτ(Bz::Float64,tc::Float64,Δ::Float64,gc::Float64)=gc*cos(0.5*(atan(Δ/(2.0*tc+Bz))+atan(Δ/(2.0*tc-Bz))))
dσz(Bz::Float64,tc::Float64,Δ::Float64,dtc::Float64)=-0.5*dtc*(cos(ϕp(Bz,tc,Δ)-ϕm(Bz,tc,Δ)))
dτz(Bz::Float64,tc::Float64,Δ::Float64,dtc::Float64)=-0.5*dtc*(cos(ϕp(Bz,tc,Δ)+ϕm(Bz,tc,Δ)))
dx(Bz::Float64,tc::Float64,Δ::Float64,dtc::Float64)=0.5*dtc*(sin(ϕp(Bz,tc,Δ))+sin(ϕm(Bz,tc,Δ)))
dy(Bz::Float64,tc::Float64,Δ::Float64,dtc::Float64)=0.5*dtc*(sin(ϕm(Bz,tc,Δ))-sin(ϕp(Bz,tc,Δ)))
Ω(Bz::Float64,tc::Float64,Δ::Float64,dtc::Float64)=0.5*dtc*(sin(ϕm(Bz,tc,Δ))-sin(ϕp(Bz,tc,Δ)))+0.5*dtc*(sin(ϕp(Bz,tc,Δ))+sin(ϕm(Bz,tc,Δ)))
function findtc(Bz::Float64,Δ::Float64,ωc::Float64)
findtca= (x ->Eτ(Bz,x,Δ)-ωc);
  find_zero(findtca, 19.0)
end

function findωd2(Bz2::Float64,tc2::Float64,Δ2::Float64,dtc2::Float64,Bz1::Float64,tc1::Float64,Δ1::Float64,dtc1::Float64,gc::Float64,ωc::Float64,k::Float64,ωd2::Float64)
    ωc-Eσ(Bz2,tc2,Δ2)-sqrt(2)*gτ(Bz2,tc2,Δ2,gc)-ωd2-Ω(Bz1,tc1,Δ1,dtc1)*k/4.0
end

H(Eτ1::Float64, Eτ2::Float64,Eσ1::Float64,Eσ2::Float64,gτ1::Float64,gτ2::Float64,gσ1::Float64,gσ2::Float64, dσz1::Float64, dσz2::Float64, dτz1::Float64, dτz2::Float64, dx1::Float64, dx2::Float64, dy1::Float64, dy2::Float64,ωd1::Float64,ωd2::Float64,ωc::Float64,t::Float64)= -0.5*Eτ2*ksτz2-0.5*Eτ1*ksτz1-0.5*Eσ2*ksσz2-0.5*Eσ1*ksσz1+ωc*kai-gτ2*khiτ2-gτ1*khiτ1-gσ2*khiσ2-gσ1*khiσ1+ dσz2*cos(ωd2*t)*ksσz2+ dσz1*cos(ωd1*t)*ksσz1+ dτz2*cos(ωd2*t)*ksτz2+ dτz1*cos(ωd1*t)*ksτz1+dx2*cos(ωd2*t)*khd2x+dy2*cos(ωd2*t)*khd2y+dx1*cos(ωd1*t)*khd1x+dy1*cos(ωd1*t)*khd1y


function f(du,u,args,t)
    Eτ1, Eτ2,Eσ1,Eσ2,gτ1,gτ2,gσ1,gσ2, dσz1, dσz2, dτz1, dτz2, dx1, dx2, dy1, dy2,ωd1,ωd2,ωc=args;
   du.=-1im*H(Eτ1, Eτ2,Eσ1,Eσ2,gτ1,gτ2,gσ1,gσ2, dσz1, dσz2, dτz1, dτz2, dx1, dx2, dy1, dy2,ωd1,ωd2,ωc,t)*u
return du
end
function u(Bz2::Float64,Δ2::Float64,dtc2::Float64,Bz1::Float64,Δ1::Float64,dtc1::Float64,gc::Float64,ωc::Float64,k::Float64,T::Float64)
    tc1=findtc(Bz1,Δ1,ωc);
    tc2=findtc(Bz2,Δ2,ωc);
    Eτ1=Eτ(Bz1,tc1,Δ1);
    Eτ2=Eτ(Bz2,tc2,Δ1);
    Eσ1=Eτ(Bz1,tc1,Δ1);
    Eσ2=Eτ(Bz2,tc2,Δ1);
    gτ1=gτ(Bz1,tc1,Δ1,gc);
    gτ2=gτ(Bz2,tc2,Δ2,gc);
    gσ1=gτ(Bz1,tc1,Δ1,gc);
    gσ2=gτ(Bz2,tc2,Δ2,gc);
    dσz2=dσz(Bz2,tc2,Δ2,dtc2);
    dσz1=dσz(Bz1,tc1,Δ1,dtc1);
    dτz2=dτz(Bz2,tc2,Δ2,dtc2);
    dτz1=dτz(Bz1,tc1,Δ1,dtc1);
    dx2=dx(Bz2,tc2,Δ2,dtc2);
    dx1=dx(Bz1,tc1,Δ1,dtc1);
   dy2= dy(Bz2,tc2,Δ2,dtc2);
   dy1= dy(Bz1,tc1,Δ1,dtc1);
    ωd1=ωc-Eσ1-sqrt(2)* gτ1-Ω(Bz1,tc1,Δ1,dtc1)*k/4.0;
    
findωd2a= (x ->findωd2(Bz2,tc2,Δ2,dtc2,Bz1,tc1,Δ1,dtc1,gc,ωc,k,x));
ωd2=find_zero(findωd2a, 19.0);
    
u0=Matrix{Complex{Float64}}(I,64,64);
tspan=(0.0,T);
  
args=( Eτ1, Eτ2,Eσ1,Eσ2,gτ1,gτ2,gσ1,gσ2, dσz1, dσz2, dτz1, dτz2, dx1, dx2, dy1, dy2,ωd1,ωd2,ωc);

  
 prob=ODEProblem(f!,u0,tspan,args);

 sol=solve(prob,Tsit5(),saveat=[T]);
return sol
end


but it takes a long time.

Please share a flame graph.

This is going to allocate a lot.

Those type annotations don’t do anything for performance.

Have you profiled your code? It looks like there’s just a lot of random shooting in the dark here and not a lot of productive optimizing. I recommend watching this video:

And if you want the nitty gritty details:

https://book.sciml.ai/notes/02-Optimizing_Serial_Code/

2 Likes

It takes a long time for just one calculation and can’t finish benchmarking. Thanks for your recommendations! I’ll learn them first.

Profile single calls to your f and optimize that first.

If your maxiters is that high, did you try a different solver? FBDF(autodiff=false)? Maxiters should rarely ever be increased like that, and normally it means you’re using a method that isn’t stable enough for your ODE

1 Like

Sorry for this useless and misleading setting. :joy:
It is just a 64x64 matrix differential equation.

I’m not sure how that’s related other than the fact that you may want to specialize on sparsity like is shown in the tutorial:

https://docs.sciml.ai/DiffEqDocs/stable/tutorials/advanced_ode_example/