Need help speeding up an ODE ensemble problem: preallocation and type instability?

Hi! I’m writing some particle-tracing code with large ensemble sizes, and it’s super slow. My simulations need really small timesteps, so the solution is huge, and getting what I want takes a very, very long time (a couple hundred hours :sweat_smile:). I’ve mostly been testing on small sample sizes while optimizing performance. However, GC time is at ~70% with a massive number of allocations, and I’m stuck. I did some investigative profiling to figure out what was going on, and this is what I see:

My equations-of-motion function (eom!) is what I’m trying to optimize, and there seem to be four “pillars” of GC fun (pillars 1, 2, 4, and 5). I’m a novice at Julia and don’t really know how the DifferentialEquations.jl internals work (I did try to look up cmd_gen and arg_gen in the source code, but it was beyond me), so please bear with me as I try my best to explain what’s going on. My takeaways from looking at this are:

  1. Are the append! and getindex calls in pillars 1, 4, and 5 for appending to the solution array or something? Would they go away if I preallocated the solution vector somehow? I couldn’t figure out how to do that. I do use the saveat parameter, and I know how long I’m running the integration for, so I should know exactly what size of solution array to preallocate, right?

  2. Pillar 2 is intriguing to me. It seems some string is being generated, but I’m guessing it’s type-unstable because of the unsafe_string call? I have no idea where it’s coming from or how it gets called. Perhaps it’s something inside the DifferentialEquations.jl solver?
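As a rough sanity check on the preallocation idea in point 1: with a numeric saveat and a fixed time span, the number of saved points is known in advance, so you can estimate the size of the saved output directly. All the run parameters below (time span, resolution, saveDecimation, nPerBatch) are hypothetical placeholders. The sketch also assumes the solver saves at every multiple of saveat from the start of the span to the end.

```julia
# Hypothetical run parameters; substitute the real values from the simulation.
tspan = (0.0, 15.0)        # integration interval
resolution = 1e-3          # used as dtmax
saveDecimation = 100       # save every 100th resolution step
nPerBatch = 1000           # trajectories per ensemble solve
nstates = 5                # z, pz, zeta, mu, lambda

saveat = saveDecimation * resolution
# Assumed save times: tspan[1], tspan[1] + saveat, ..., tspan[2]
save_times = range(tspan[1], tspan[2]; step = saveat)
npoints = length(save_times)

# Each saved point stores nstates Float64 values plus its time, 8 bytes each
bytes_per_traj = npoints * (nstates + 1) * sizeof(Float64)
total_MiB = nPerBatch * bytes_per_traj / 2^20

println("saved points per trajectory: ", npoints)
println("approximate saved data per batch: ", round(total_MiB; digits = 2), " MiB")
```

With these placeholder numbers, the saved data comes to only a few MiB. If the real values are in the same ballpark, growing the output arrays cannot account for tens of GiB of allocations. That points to work done at every solver step instead.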

I feel like if I can get rid of pillars 1, 2, 4, and 5, I could get a speedup of several orders of magnitude and very low GC time (hopefully turning my hundreds-of-hours sim into minutes). At the same time, I have no idea what they do, so they’re probably important. I guess my question is: are any of my ideas above valid, and if so, how would I implement them?

I’m pasting the relevant code for reference (not all of it, because there’s a lot; you can find it all here).

How I call solve:

```julia
prob_func = (prob,i,repeat) -> remake(prob, u0 = H0, p = @SVector [eta, epsilon, Omegape, omegam])
ensemble_prob = EnsembleProblem(prob::ODEProblem, prob_func=prob_func)
sol = solve(ensemble_prob, Tsit5(), EnsembleThreads(), save_everystep=false;
                            callback=CallbackSet(cb1, cb2), trajectories=nPerBatch,
                            dtmax=resolution, linear_solver=:LapackDense, maxiters=1e8,
                            saveat = saveDecimation*resolution)
```

Equations of motion:

```julia
function eom!(dH,H,p::SVector{4, Float64},t::Float64)
    @inbounds begin
        # z, pz, zeta, mu, lambda = H
        # eta, epsilon, Omegape, omegam = p
        sinlambda = @fastmath @views sin(H[5]);
        coslambda = @fastmath @views cos(H[5]);
        sinzeta = @fastmath @views sin(H[3]);
        coszeta = @fastmath @views cos(H[3]);

        u = @fastmath @views .5*(tanh(H[5]/deg2rad(1))+1);
        b = @fastmath @views sqrt(1+3*sinlambda^2)/(coslambda^6);
        db = @fastmath @views (3*(27*sinlambda-5*sin(3*H[5])))/(coslambda^8*(4+12*sinlambda^2));
        gamma = @fastmath @views sqrt(1 + H[2]^2 + 2*H[4]*b);
        K = @fastmath @views (p[3] * (coslambda^(-5/2)))/sqrt(b/p[4] - 1);
        psi = @fastmath @views p[1]*p[2]*u*sqrt(2*H[4]*b)/gamma;

        dH1 = @fastmath @views H[2]/gamma;
        dH2 = @fastmath @views -(H[4]*db)/gamma - (psi*coszeta);
        dH3 = @fastmath @views p[1]*(K*dH1 - p[4] + b/gamma) + (psi*sinzeta)/(2*H[4]*K);
        dH4 = @fastmath @views -(psi*coszeta)/K;
        dH5 = @fastmath @views H[2]/(gamma*coslambda*sqrt(1+3*sinlambda^2));

        dH .= SizedVector{5}([ dH1, dH2, dH3, dH4, dH5 ]);
    end
end
```

Callbacks:

```julia
function palostcondition(H,t,integrator)
    # condition: if particle enters loss cone
    @inbounds begin
        b = @fastmath @views sqrt(1+3*sin(H[5])^2)/(cos(H[5])^6);
        gamma = @fastmath @views sqrt(1 + H[2]^2 + 2*H[4]*b);
        return  @fastmath @views (rad2deg(asin(sqrt((2*H[4])/(gamma^2 -1))))) < (lossConeAngle)
    end
end

function ixlostcondition(H,t,integrator)
    # condition: if I_x approaches 0
    @inbounds begin
        b = @fastmath @views sqrt(1+3*sin(H[5])^2)/(cos(H[5])^6);
        return @fastmath @views 2*H[4]*b < (1/saveDecimation)
    end
end

affect!(integrator) = terminate!(integrator); # terminate if condition reached
cb1 = DiscreteCallback(palostcondition,affect!);
cb2 = DiscreteCallback(ixlostcondition,affect!);
```

Really appreciate any support/help! Thank you!
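The title also asks about type instability, so here is one more thing worth checking. It assumes lossConeAngle and saveDecimation, which the callbacks read, are ordinary non-const globals defined elsewhere in the script. A function that reads a non-const global cannot know that global's type at compile time, which is a classic source of type instability in Julia. Declaring the value const, or passing it in through the parameters, lets the compiler infer a concrete type. The names below are hypothetical stand-ins.

```julia
# Hypothetical stand-ins for the globals the callback conditions read.
lossConeAngle = 5.0            # non-const global: its type could change at any time
const LOSS_CONE_ANGLE = 5.0    # const global: type is fixed and known to the compiler

# The same comparison, written against each kind of global
inside_cone_global(angle) = angle < lossConeAngle
inside_cone_const(angle)  = angle < LOSS_CONE_ANGLE

# Ask the compiler what return type it can infer for a Float64 argument
println(Base.return_types(inside_cone_global, (Float64,)))
println(Base.return_types(inside_cone_const, (Float64,)))
```

A DiscreteCallback condition runs once per accepted step, so this is probably a smaller cost than the right-hand side, which runs several times per step. It is still cheap to fix.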

In the last line of eom! (dH .= SizedVector{5}([ dH1, dH2, dH3, dH4, dH5 ])), you probably want something like:

```julia
dH .= @SVector [ dH1, dH2, dH3, dH4, dH5 ]
```

Also, I suggest removing all those @fastmath and @inbounds annotations, and most of the @views ones (@views only makes sense for slices of arrays). Those macros should be a last resort.
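To illustrate the point about @views: it only affects slicing expressions like x[1:3], turning a copy into a view. Scalar indexing like H[5], which is all eom! does, returns a plain Float64, so @views has nothing to do there. The small sketch below uses hypothetical helper names and measures allocations with @allocated after a warm-up call, so compilation is not counted.

```julia
# A copying slice vs. a view, and scalar indexing (which needs neither).
function sum_copy(x)
    return sum(x[1:3])          # x[1:3] allocates a new 3-element Vector
end

function sum_view(x)
    return sum(@view x[1:3])    # @view wraps the same memory, no copy
end

function fifth_sin(x)
    return sin(x[5])            # scalar indexing: @views would change nothing here
end

# Call once to compile, then measure a second call
function allocs(f, x)
    f(x)
    return @allocated f(x)
end

H = [0.0, 0.1, 0.2, 0.5, 0.3]
println("copy:   ", allocs(sum_copy, H), " bytes")
println("view:   ", allocs(sum_view, H), " bytes")
println("scalar: ", allocs(fifth_sin, H), " bytes")
```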


OK, I just did that, and benchmarks show virtually no change in time. It is more readable, though, so thanks! By the way, here are the latest profile results (the one in the OP is a little outdated).

For short test runs, I’m seeing:
```
57.224161 seconds (1.77 G allocations: 106.200 GiB, 69.77% gc time)
```

If I implement this correctly, shouldn’t I be aiming for, like… zero allocations?
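Here is a quick back-of-the-envelope check on those @time numbers. Dividing the bytes allocated by the number of allocations gives the average allocation size. About 64 bytes per allocation suggests an enormous number of tiny objects, such as small temporaries created in a hot loop, rather than a few large solution arrays.

```julia
# Figures copied from the @time output above
n_allocs = 1.77e9               # 1.77 G allocations
bytes = 106.200 * 2^30          # 106.200 GiB expressed in bytes

bytes_per_alloc = bytes / n_allocs
println(round(bytes_per_alloc; digits = 1), " bytes per allocation on average")
```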

Yes, you are very likely allocating something in some inner loop. The line I mentioned above is one possible culprit: it allocates an intermediate 5-element vector. Assign those 5 elements of dH explicitly and see what happens.
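Below is a minimal sketch of that suggestion. It reuses the equations from eom! above but writes every component straight into dH, so no intermediate vector is built. A few assumptions keep it dependency-free:

- p is a plain 4-tuple (eta, epsilon, Omegape, omegam). An SVector{4,Float64} from StaticArrays.jl is indexed the same way.
- The @fastmath, @inbounds, and @views annotations are dropped, as suggested earlier.
- The state and parameter values are arbitrary test inputs, and count_allocs is a hypothetical helper.

```julia
function eom!(dH, H, p, t)
    # Unpack state and parameters by name; destructuring does not allocate
    _, pz, zeta, mu, lambda = H          # z = H[1] is not used on the right-hand side
    eta, epsilon, Omegape, omegam = p

    sinlambda, coslambda = sincos(lambda)
    sinzeta, coszeta = sincos(zeta)

    u = 0.5 * (tanh(lambda / deg2rad(1)) + 1)
    b = sqrt(1 + 3 * sinlambda^2) / coslambda^6
    db = 3 * (27 * sinlambda - 5 * sin(3 * lambda)) / (coslambda^8 * (4 + 12 * sinlambda^2))
    gamma = sqrt(1 + pz^2 + 2 * mu * b)
    K = Omegape * coslambda^(-5 / 2) / sqrt(b / omegam - 1)
    psi = eta * epsilon * u * sqrt(2 * mu * b) / gamma

    # Write each derivative in place instead of building a temporary vector
    dH[1] = pz / gamma
    dH[2] = -(mu * db) / gamma - psi * coszeta
    dH[3] = eta * (K * dH[1] - omegam + b / gamma) + psi * sinzeta / (2 * mu * K)
    dH[4] = -(psi * coszeta) / K
    dH[5] = pz / (gamma * coslambda * sqrt(1 + 3 * sinlambda^2))
    return nothing
end

# Hypothetical helper: warm up, then measure bytes allocated by one call
function count_allocs(dH, H, p)
    eom!(dH, H, p, 0.0)
    return @allocated eom!(dH, H, p, 0.0)
end

H = [0.0, 0.1, 0.2, 0.5, 0.3]     # arbitrary state (z, pz, zeta, mu, lambda)
p = (0.1, 0.01, 5.0, 0.5)         # arbitrary (eta, epsilon, Omegape, omegam)
dH = similar(H)
println("bytes allocated per call: ", count_allocs(dH, H, p))
```

With the `.=` broadcast of a freshly built vector gone, each right-hand-side call should allocate nothing. That matters because the solver calls this function several times per step, for every trajectory.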

If you can provide a runnable example, people will be able to help you more.

Am I reading it correctly that 75% of the time in that graph is spent in cmd_gen, called from helperFunctions.jl in the ethantsai/nlwhistlers repository on GitHub (commit 3fcebc3f74c850436f75d47d44d3fca9f908643c)? What if you just comment that out? I assume it was meant as a comment anyway…


Oh my god, I thought that was just a comment!!! Big facepalm moment, oops. That just sped it up by more than an order of magnitude, thank you!!! I’m used to writing docstrings like that in Python.
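For anyone else tripping over this: in Julia, text between backticks is not a comment. It is a command literal, which the parser turns into a call to the @cmd macro. At run time that call builds a Cmd object, and Base's cmd_gen and arg_gen are the functions doing that work. When the literal sits inside a right-hand-side function, the Cmd is rebuilt on every call. Documentation goes in a string placed directly before the definition, and comments use # or #= ... =#. The sketch below uses hypothetical function names.

```julia
"""
    rhs_documented!(dH, H)

Documentation belongs in a string placed *before* the definition.
Julia attaches it to the function when it is defined; calls do not re-evaluate it.
"""
function rhs_documented!(dH, H)
    # Ordinary comments start with a hash
    #= Multi-line comments go between these markers =#
    dH[1] = H[2]
    return nothing
end

function rhs_with_cmd!(dH, H)
    `This is NOT a comment: it builds a Cmd object every time the function runs`
    dH[1] = H[2]
    return nothing
end

# A backtick literal evaluates to a Cmd (nothing is executed unless you run it)
literal = `This is NOT a comment`
println(typeof(literal))
```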
