Large ODE Solver for Metal.jl

Is there any implementation of ODE solving using Metal.jl backends? I know there is DiffEqGPU, but its Metal support is only for ensemble problems. I cannot, for instance, do

using OrdinaryDiffEq, Metal
A = MtlArray(rand(Float32, 1024, 1024))  # example system matrix on the GPU
u0 = MtlArray(rand(Float32, 1024))       # example initial condition
f(u, p, t) = A * u
tspan = (0.0f0, 1.0f0)
prob = ODEProblem(f, u0, tspan)
solve(prob, Tsit5())

or

solve(prob, GPUTsit5())

where A and u0 are Metal arrays. Tsit5() copies back and forth to the CPU, and GPUTsit5 is designed for the ensemble solver. Am I just going about this the wrong way, or would I need to write my own solver algorithm to do this?
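For contrast, the ensemble path that DiffEqGPU does support on Metal looks roughly like this; the Lorenz system and exact keyword names here are illustrative, based on the DiffEqGPU docs, and may differ by version:

using DiffEqGPU, OrdinaryDiffEq, StaticArrays, Metal

# Out-of-place, StaticArray-based RHS, as the GPU ensemble kernels require
function lorenz(u, p, t)
    du1 = p[1] * (u[2] - u[1])
    du2 = u[1] * (p[2] - u[3]) - u[2]
    du3 = u[1] * u[2] - p[3] * u[3]
    return SVector{3}(du1, du2, du3)
end

u0 = @SVector [1.0f0, 0.0f0, 0.0f0]
p = @SVector [10.0f0, 28.0f0, 8.0f0 / 3.0f0]
prob_ens = ODEProblem{false}(lorenz, u0, (0.0f0, 10.0f0), p)
monteprob = EnsembleProblem(prob_ens, safetycopy = false)

# Solves many small, independent copies of the ODE on the Apple GPU
sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(Metal.MetalBackend());
            trajectories = 10_000)

What I want is the opposite: one large ODE whose state lives on the GPU, not many small ones.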


It does not; Tsit5 doesn't copy back and forth to the CPU. Where did you get that idea?

If I set up the same problem, with f(u,p,t) = A*u as the function, the number of allocations when A and u are MtlArrays is almost 50x higher than when both are plain Float32 arrays (and the MtlArray version also takes significantly longer). My assumption is that there are steps in the Tsit5 algorithm that aren’t handled properly in Metal, and maybe the data is getting stored in non-Metal arrays and copied back and forth? I’m not sure what other reason there would be for such high memory allocations. It is not just Tsit5 either; the Vern solvers have the same problem, and I’m guessing all the other solvers will as well.

I could be totally wrong and the issue is elsewhere. I just don’t have any idea of what it would be.
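For reference, this is roughly how I’ve been comparing the two; the size is just one of the ones I tested, and the setup is illustrative:

using OrdinaryDiffEq, Metal, BenchmarkTools

n = 1024                                  # illustrative size
A_cpu = rand(Float32, n, n); u0_cpu = rand(Float32, n)
A_gpu = MtlArray(A_cpu);     u0_gpu = MtlArray(u0_cpu)

f(u, p, t) = p * u                        # pass the matrix in through p
prob_cpu = ODEProblem(f, u0_cpu, (0.0f0, 1.0f0), A_cpu)
prob_gpu = ODEProblem(f, u0_gpu, (0.0f0, 1.0f0), A_gpu)

@benchmark solve($prob_cpu, Tsit5())      # compare times and allocation counts
@benchmark solve($prob_gpu, Tsit5())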

What about f(du,u,p,t) = mul!(du,A,u)? You’re using the allocating path instead of the non-allocating path.
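A minimal sketch of that in-place formulation (array sizes are illustrative):

using OrdinaryDiffEq, Metal, LinearAlgebra

A = MtlArray(rand(Float32, 1024, 1024))
u0 = MtlArray(rand(Float32, 1024))

# In-place RHS: mul! writes A*u into the preallocated du buffer
f!(du, u, p, t) = mul!(du, A, u)

prob = ODEProblem(f!, u0, (0.0f0, 1.0f0))
solve(prob, Tsit5())

With the in-place form, the solver can reuse its internal caches instead of allocating a fresh array for every RHS evaluation.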

MtlArray operations are generally pretty slow until things get large. Are you testing 50,000x50,000 matrices?

I just gave that a try, and it still has very large allocations and is orders of magnitude slower.
Also no, I am testing with smaller matrices; the largest I tried were a couple around 4000x4000 (at which point performance was becoming marginally equivalent). Is there a way to speed up MtlArray operations, or is this an optimization that needs to happen in Metal.jl or on Apple’s end before it is fast for intermediate-scale problems?

Though even if the operations themselves are slower, I’m still a bit confused by the memory allocations. I feel like they should be comparable.

Metal operations are much slower at that size. See for example:
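You can also measure the crossover directly; a rough sketch with illustrative sizes, assuming Metal.@sync behaves like its CUDA counterpart:

using Metal, BenchmarkTools, LinearAlgebra

for n in (256, 1024, 4096, 8192)
    A, u, du = rand(Float32, n, n), rand(Float32, n), zeros(Float32, n)
    Ag, ug, dug = MtlArray(A), MtlArray(u), MtlArray(du)
    t_cpu = @belapsed mul!($du, $A, $u)
    t_gpu = @belapsed Metal.@sync mul!($dug, $Ag, $ug)  # sync so the kernel is timed
    println("n = $n: CPU $(round(t_cpu; sigdigits=3)) s, GPU $(round(t_gpu; sigdigits=3)) s")
end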

It should be comparable, and it is worth looking into. But first make it in-place like I showed; then it should not be allocating in the steps.

Ah, interesting. I have seen that before, but I wasn’t sure how much it would translate to ODE solving.

I did try the in-place version; the number of allocations went down, but nowhere near as much as it did for the CPU version. In fact, the gap in allocations actually got worse: ~120x more for the MtlArray.

What does the allocation profiler say the source is?

Oh, uh, I have just been using @benchmark. I’m not very familiar with the memory-allocation profiler; it might take some time for me to figure out how to interpret it.


Is this the right thing to look at? (following [Profiling · The Julia Language] using PProf.)
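In case it helps, this is roughly the sequence I ran, following the manual (a sketch; the sample_rate and problem setup are illustrative):

using OrdinaryDiffEq, Metal, LinearAlgebra, Profile, PProf

A = MtlArray(rand(Float32, 1024, 1024))
u0 = MtlArray(rand(Float32, 1024))
f!(du, u, p, t) = mul!(du, A, u)
prob = ODEProblem(f!, u0, (0.0f0, 1.0f0))
solve(prob, Tsit5())                       # warm up so compilation isn't profiled

Profile.Allocs.clear()
Profile.Allocs.@profile sample_rate = 1 solve(prob, Tsit5())
PProf.Allocs.pprof(from_c = false)         # opens the pprof web UI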

Use the VS Code profiler to get it into a flamegraph? Usually I find that makes it easier to find the lines of code the allocations can be attributed to.
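With the Julia VS Code extension, that is one macro in the extension’s REPL; a minimal sketch, reusing prob from above:

# Requires the Julia VS Code extension's integrated REPL
@profview_allocs solve(prob, Tsit5()) sample_rate = 1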