Is there an implementation of ODE solving that uses the Metal.jl backend? I know there is DiffEqGPU, but its Metal support is only for ensemble problems. For instance, I cannot do
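using OrdinaryDiffEq, Metal
A = MtlArray(rand(Float32, 4, 4))  # example data; the size is arbitrary
u0 = MtlArray(rand(Float32, 4))
f(u,p,t) = A*u
tspan = (0.0f0, 1.0f0)
prob = ODEProblem(f, u0, tspan)
solve(prob, Tsit5())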
or
solve(prob, GPUTsit5())  # GPUTsit5 comes from DiffEqGPU
where A and u0 are Metal arrays. Tsit5() copies data back and forth between the GPU and the CPU, and GPUTsit5 is designed for the ensemble solver. Am I going about this the wrong way, or would I need to write my own solver to do this?
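To make the question concrete, this is roughly the kind of hand-written solver I would otherwise fall back on. It is only a sketch under assumptions: `rk4_solve` is a hypothetical helper name (not part of DiffEqGPU or OrdinaryDiffEq), it uses a fixed-step classical RK4 rather than Tsit5, and I am assuming that broadcasting and `A*u` on Metal arrays stay on the GPU. The example runs on plain CPU arrays as stand-ins.

```julia
# Fixed-step classical RK4 written only with broadcasts and calls to f,
# so it is generic over the array type (assumption: this also holds for MtlArray).
# `rk4_solve` is a hypothetical name used for illustration only.
function rk4_solve(f, u0, p, tspan, n)
    t0, t1 = tspan
    dt = (t1 - t0) / n          # stays Float32 when tspan is Float32
    u = copy(u0)
    t = t0
    for _ in 1:n
        k1 = f(u, p, t)
        k2 = f(u .+ (dt / 2) .* k1, p, t + dt / 2)
        k3 = f(u .+ (dt / 2) .* k2, p, t + dt / 2)
        k4 = f(u .+ dt .* k3, p, t + dt)
        u = u .+ (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        t += dt
    end
    return u                    # state at t1 only; no adaptive stepping, no saved trajectory
end

# CPU stand-ins; with Metal.jl these would presumably be wrapped in MtlArray(...)
A = Float32[-1 0; 0 -2]
u0 = Float32[1, 1]
f(u, p, t) = A * u
u1 = rk4_solve(f, u0, nothing, (0.0f0, 1.0f0), 100)
```

This gives up adaptive time stepping, dense output, and everything else the `solve` interface provides, which is why I would prefer an existing solver if one works with Metal arrays.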