Is there any implementation for solving ODEs with a Metal.jl backend? I know there is DiffEqGPU, but its Metal support is only for ensemble problems. I cannot, for instance, do something like
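(a minimal sketch of the setup I mean; the size, `u0`, and `tspan` here are arbitrary)

```julia
using Metal, OrdinaryDiffEq

A  = MtlArray(rand(Float32, 100, 100))  # arbitrary size, just for illustration
u0 = MtlArray(rand(Float32, 100))

f(u, p, t) = A * u  # out-of-place RHS, closes over A

prob = ODEProblem(f, u0, (0f0, 1f0))
sol  = solve(prob, Tsit5())
```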
where A and u are Metal arrays. Tsit5() copies back and forth to the CPU, and GPUTsit5 is designed for the ensemble solver. Am I just going about this the wrong way, or would I need to write my own solver algorithm to do this?
If I set up the same problem with f(u,p,t) = A*u as the function, the number of allocations when A and u are MtlArrays is almost 50x higher than when both are plain Float32 arrays (and the MtlArray version also takes significantly longer). My assumption is that there are steps in the Tsit5 algorithm that aren't handled properly for Metal, and maybe some data is being stored in non-Metal arrays and copied back and forth? I'm not sure what other reason there would be for such high allocation counts. It is not just Tsit5 either; the Vern solvers have the same problem, and I'm guessing all the other solvers will as well.
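For reference, this is roughly how I'm comparing the two (a sketch; the size and timespan are just placeholders):

```julia
using Metal, OrdinaryDiffEq, BenchmarkTools

n = 1000
A_cpu, u_cpu = rand(Float32, n, n), rand(Float32, n)
A_gpu, u_gpu = MtlArray(A_cpu), MtlArray(u_cpu)

f(u, p, t) = p * u  # A is passed through p so the same RHS serves both cases

prob_cpu = ODEProblem(f, u_cpu, (0f0, 1f0), A_cpu)
prob_gpu = ODEProblem(f, u_gpu, (0f0, 1f0), A_gpu)

@benchmark solve($prob_cpu, Tsit5())  # allocations here come out ~50x lower
@benchmark solve($prob_gpu, Tsit5())
```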
I could be totally wrong and the issue is elsewhere; I just don't have any idea of what it would be.
I just gave that a try and it still has very large allocations and is orders of magnitude slower.
Also no, I am testing with smaller matrices; the largest I tried was a couple of ~4000x4000 ones (at that size the two were becoming roughly comparable). Is there a way to speed up the MtlArray operations, or is this an optimization that needs to happen in Metal.jl, or on Apple's end, before it is fast for intermediate-scale problems like these?
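One way to check whether it's the raw operation rather than the solver (a sketch; I'm assuming Metal.@sync blocks until the kernel finishes, like CUDA.@sync does):

```julia
using Metal, BenchmarkTools

n = 4000
A, u   = rand(Float32, n, n), rand(Float32, n)
Ag, ug = MtlArray(A), MtlArray(u)

@btime $A * $u                # CPU matvec
@btime Metal.@sync $Ag * $ug  # GPU matvec; @sync waits for completion
```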
Though even if the operations themselves are slower, I'm still a bit confused by the memory allocations. I'd expect the allocation counts to be comparable.
Ah, interesting. I have seen that before, but I wasn't sure how much it would translate into ODE solving.
I did try the in-place form; the number of allocations went down, but nowhere near as much as it did for the CPU version. In fact, the gap in allocations actually got worse: ~120x more for the MtlArray.
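For concreteness, this is the kind of in-place setup I mean (a sketch; mul! should dispatch to the MPS matmul for MtlArrays, if I understand correctly):

```julia
using Metal, OrdinaryDiffEq, LinearAlgebra

A  = MtlArray(rand(Float32, 1000, 1000))
u0 = MtlArray(rand(Float32, 1000))

f!(du, u, p, t) = mul!(du, p, u)  # writes A*u into du, no temporary array

prob = ODEProblem(f!, u0, (0f0, 1f0), A)
solve(prob, Tsit5(), save_everystep = false)  # skip saving to cut allocations further
```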
Oh, uh, I have just been using @benchmark from BenchmarkTools. I'm not very familiar with the memory allocation profiler; it might take me some time to figure out how to interpret it.
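From a quick look at the docs, the workflow seems to be something like this (a sketch; it assumes Julia ≥ 1.8 for Profile.Allocs, uses PProf.jl for viewing, and reuses the prob from above):

```julia
using Profile, PProf  # PProf renders the results in a browser

Profile.Allocs.clear()
Profile.Allocs.@profile sample_rate=1 solve(prob, Tsit5())  # record every allocation
PProf.Allocs.pprof()  # inspect which calls allocate, and how much
```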