Hi! I’m writing some particle-tracing code with large ensemble sizes, and it’s super slow. My simulations need really small timesteps, so the solution is huge, and getting what I want takes a very, very long time (a couple hundred hours; so far I’ve mostly been testing on small sample sizes to optimize performance). GC time is at ~70%, with a massive number of allocations, and I’m just stuck. I did some profiling to figure out what was going on, and this is what I see:
My equations-of-motion function (`eom!`) is what I’m trying to optimize, and there seem to be four “pillars” of GC fun in the profile (pillars 1, 2, 4, and 5). I’m a novice at Julia and don’t really know how the DifferentialEquations.jl internals work (I did try to look up `cmd_gen` and `arg_gen` in the source code, but it was beyond me), so please bear with me as I try my best to explain what’s going on. My takeaways from looking at this are:
- Are the `append!` and `getindex` calls in pillars 1, 4, and 5 for appending to the solution array or something? Would they still be necessary if I preallocated the solution vector somehow? I couldn’t figure out how to do that. I do use the `saveat` parameter, and I know how long I’m running the integration for, so I should know exactly what size solution array to preallocate, right?
- Pillar 2 is intriguing to me. It seems some string is being generated, and I’m guessing it’s type-unstable because of the `unsafe_string` call? I have no idea where it’s coming from or how it gets called. Perhaps it’s something inside the DifferentialEquations.jl solver?
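In plain Julia, the general cost difference behind the preallocation question can be measured with `@allocated`: grow a vector with `append!`, or fill a buffer that was allocated once up front. This is a generic sketch with hypothetical function names (`record_growing`, `record_prealloc!`, `compare_allocations`); it does not show where the `append!` calls in the profile actually come from inside DifferentialEquations.jl.

```julia
# Generic illustration (not DifferentialEquations.jl internals): store `nsave`
# snapshots of a 3-component state, once by growing a vector and once by
# writing into a buffer allocated up front.

# Grows `out` with append!; the vector reallocates whenever its capacity runs out.
function record_growing(nsave)
    out = Float64[]
    for k in 1:nsave
        append!(out, (k, 2k, 3k))
    end
    return out
end

# Writes into a caller-supplied buffer of length 3*nsave; no new memory is needed.
function record_prealloc!(out, nsave)
    length(out) == 3nsave || throw(DimensionMismatch("buffer must have length 3*nsave"))
    for k in 1:nsave
        out[3k-2] = k
        out[3k-1] = 2k
        out[3k]   = 3k
    end
    return out
end

# Measures bytes allocated by each variant after a warm-up call.
function compare_allocations(nsave)
    buf = zeros(3nsave)
    record_growing(nsave); record_prealloc!(buf, nsave)   # warm-up
    grow_bytes     = @allocated record_growing(nsave)
    prealloc_bytes = @allocated record_prealloc!(buf, nsave)
    return grow_bytes, prealloc_bytes
end

grow_bytes, prealloc_bytes = compare_allocations(10_000)
println("append!: $grow_bytes bytes, preallocated: $prealloc_bytes bytes")
```

If the final size is known but values still arrive one at a time, `sizehint!(out, n)` reserves the capacity up front while keeping `push!`/`append!` usable.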
I feel like if I can get rid of pillars 1, 2, 4, and 5, I could get a speedup of several orders of magnitude and very low GC time (turning my hundreds-of-hours sim into, hopefully, minutes). At the same time, I have no idea what they do, so they’re probably important. So my question is: are any of my ideas above valid, and if so, how would I implement them?
I’m pasting the relevant code for reference (not all of it, because there’s a lot… you can find it all here).
How I call `solve`:
```julia
prob_func = (prob,i,repeat) -> remake(prob, u0 = H0, p = @SVector [eta, epsilon, Omegape, omegam])
ensemble_prob = EnsembleProblem(prob::ODEProblem, prob_func=prob_func)
sol = solve(ensemble_prob, Tsit5(), EnsembleThreads(), save_everystep=false;
            callback=CallbackSet(cb1, cb2), trajectories=nPerBatch,
            dtmax=resolution, linear_solver=:LapackDense, maxiters=1e8,
            saveat = saveDecimation*resolution)
```
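Because `saveat` fixes the save times, the memory taken by the saved states can be estimated before running anything. A back-of-the-envelope sketch, assuming `t0 = 0`, save times forming the range `0:Δ:tf`, a 5-component `Float64` state plus one `Float64` time stamp per save point, no early termination by the callbacks, and ignoring per-array overhead; `saved_bytes` is a hypothetical helper, not part of DifferentialEquations.jl, and the numbers in the call are made up.

```julia
# Hypothetical helper: rough size of the saved output of an ensemble run.
# Assumes save times 0, Δ, 2Δ, ..., ≤ tf, `nstate` Float64 state components
# plus one Float64 time per save point, and that no trajectory stops early.
function saved_bytes(tf, Δ, ntraj; nstate = 5)
    nsave = length(0.0:Δ:tf)                  # number of save points
    return ntraj * nsave * (nstate + 1) * sizeof(Float64)
end

# Made-up example: tf = 1000, Δ = 0.01, 10_000 trajectories.
bytes = saved_bytes(1000.0, 0.01, 10_000)
println(round(bytes / 2^30; digits = 2), " GiB")
```

The estimate scales linearly in both the number of save points and the number of trajectories, which is why the save spacing matters far more for memory than the integration timestep does.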
Equations of motion:
```julia
function eom!(dH,H,p::SVector{4, Float64},t::Float64)
    @inbounds begin
        # z, pz, zeta, mu, lambda = H
        # eta, epsilon, Omegape, omegam = p
        sinlambda = @fastmath @views sin(H[5]);
        coslambda = @fastmath @views cos(H[5]);
        sinzeta = @fastmath @views sin(H[3]);
        coszeta = @fastmath @views cos(H[3]);
        u = @fastmath @views .5*(tanh(H[5]/deg2rad(1))+1);
        b = @fastmath @views sqrt(1+3*sinlambda^2)/(coslambda^6);
        db = @fastmath @views (3*(27*sinlambda-5*sin(3*H[5])))/(coslambda^8*(4+12*sinlambda^2));
        gamma = @fastmath @views sqrt(1 + H[2]^2 + 2*H[4]*b);
        K = @fastmath @views (p[3] * (coslambda^(-5/2)))/sqrt(b/p[4] - 1);
        psi = @fastmath @views p[1]*p[2]*u*sqrt(2*H[4]*b)/gamma;
        dH1 = @fastmath @views H[2]/gamma;
        dH2 = @fastmath @views -(H[4]*db)/gamma - (psi*coszeta);
        dH3 = @fastmath @views p[1]*(K*dH1 - p[4] + b/gamma) + (psi*sinzeta)/(2*H[4]*K);
        dH4 = @fastmath @views -(psi*coszeta)/K;
        dH5 = @fastmath @views H[2]/(gamma*coslambda*sqrt(1+3*sinlambda^2));
        dH .= SizedVector{5}([ dH1, dH2, dH3, dH4, dH5 ]);
    end
end
```
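The last line of `eom!` builds a new `Vector` with `[ dH1, ..., dH5 ]` (then wraps it in a `SizedVector`) on every call before copying it into `dH`, so each right-hand-side evaluation allocates. A minimal Base-Julia sketch of the difference, using a made-up two-component right-hand side instead of the real equations (`rhs_temp!`, `rhs_inplace!` and `compare_rhs` are hypothetical names): writing each component directly into `dH` needs no temporary array.

```julia
# Toy right-hand side with the same structure as eom!: compute scalars, then
# store them in dH. The "physics" is invented; only the storage pattern matters.

# Allocating version: builds a temporary Vector on each call, then copies it into dH.
function rhs_temp!(dH, H, t)
    a = H[2]
    b = -sin(H[1])
    dH .= [a, b]
    return nothing
end

# Non-allocating version: writes each component straight into dH.
function rhs_inplace!(dH, H, t)
    a = H[2]
    b = -sin(H[1])
    dH[1] = a
    dH[2] = b
    return nothing
end

# Runs both versions once to compile them, then measures allocated bytes.
function compare_rhs()
    H  = [0.3, 1.2]
    d1 = zeros(2); d2 = zeros(2)
    rhs_temp!(d1, H, 0.0); rhs_inplace!(d2, H, 0.0)   # warm-up
    temp_bytes    = @allocated rhs_temp!(d1, H, 0.0)
    inplace_bytes = @allocated rhs_inplace!(d2, H, 0.0)
    return d1, d2, temp_bytes, inplace_bytes
end

d1, d2, temp_bytes, inplace_bytes = compare_rhs()
println("temporary array: $temp_bytes bytes, in place: $inplace_bytes bytes")
```

In `eom!` this would mean replacing the `dH .= SizedVector{5}(...)` line with `dH[1] = dH1` through `dH[5] = dH5`. Another pattern worth trying, if `H0` is a static vector, is an out-of-place `eom(H, p, t)` that returns `SVector(dH1, dH2, dH3, dH4, dH5)`; that changes how the `ODEProblem` is set up, so it is a separate experiment rather than a drop-in edit.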
Callbacks:
```julia
function palostcondition(H,t,integrator)
    # condition: if particle enters loss cone
    @inbounds begin
        b = @fastmath @views sqrt(1+3*sin(H[5])^2)/(cos(H[5])^6);
        gamma = @fastmath @views sqrt(1 + H[2]^2 + 2*H[4]*b);
        return @fastmath @views (rad2deg(asin(sqrt((2*H[4])/(gamma^2 -1))))) < (lossConeAngle)
    end
end
function ixlostcondition(H,t,integrator)
    # condition: if I_x approaches 0
    @inbounds begin
        b = @fastmath @views sqrt(1+3*sin(H[5])^2)/(cos(H[5])^6);
        return @fastmath @views 2*H[4]*b < (1/saveDecimation)
    end
end
affect!(integrator) = terminate!(integrator); # terminate if condition reached
cb1 = DiscreteCallback(palostcondition,affect!);
cb2 = DiscreteCallback(ixlostcondition,affect!);
```
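Both conditions read `lossConeAngle` and `saveDecimation`, which are not arguments, so presumably they are globals. If those globals are not declared `const`, Julia cannot infer their types inside the functions that use them, a common source of hidden allocations listed in the Julia performance tips. A minimal sketch of two workarounds, assuming that is the situation here, with hypothetical names (`LOSS_CONE_ANGLE`, `make_ix_condition`) and a made-up state vector: a `const` global, or a closure that captures the threshold as a concrete `Float64`. The closure keeps the `(H, t, integrator)` argument order the post's conditions already use.

```julia
# Non-const global: its type could change at any time, so functions that
# read it cannot be specialized on it (@code_warntype shows it as ::Any).
lossConeAngle = 5.0

# Option 1 (hypothetical name): a const global has a fixed type.
const LOSS_CONE_ANGLE = 5.0

inside_nonconst(angle) = angle < lossConeAngle
inside_const(angle)    = angle < LOSS_CONE_ANGLE

# Option 2: capture the threshold in a closure. `threshold` is never reassigned,
# so the closure stores it as a Float64 field and reads it type-stably.
function make_ix_condition(saveDecimation::Float64)
    threshold = 1 / saveDecimation
    return function (H, t, integrator)
        b = sqrt(1 + 3 * sin(H[5])^2) / cos(H[5])^6
        return 2 * H[4] * b < threshold
    end
end

ixlost = make_ix_condition(10.0)
H = [0.0, 0.1, 0.0, 1e-4, 0.2]       # made-up state: z, pz, zeta, mu, lambda
println(ixlost(H, 0.0, nothing))      # `nothing` stands in for the integrator here
```

With the real code this would look like `cb2 = DiscreteCallback(make_ix_condition(saveDecimation), affect!)`, so the threshold's type is fixed when the callback is built rather than looked up on every step.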
Really appreciate any support/help! Thank you!