Forcing CuArray for tType and IType in ODESolution

I’m trying to use OrdinaryDiffEq, Zygote, and CUDA to reverse-mode differentiate an ODE running on GPU. My code works on CPU, but hits this error on GPU:

```
GPU compilation of kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceVector{Float32, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}}, typeof(identity), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}}}, Int64) failed
KernelError: passing and using non-bitstype argument
Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}}, typeof(identity), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}} which is not isbits.
    .1 is of type Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}} which is not isbits.
      .x is of type Vector{Float32} which is not isbits.
```
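For context, here is a minimal sketch of the kind of pipeline I mean - the right-hand side, sizes, and save grid below are hypothetical stand-ins, not my actual code: an in-place ODE with CuArray state and parameters, solved with Tsit5, with saveat also passed as a CuArray (as in the kwargs of the ODEProblem type shown below).

```julia
# Minimal sketch only - hypothetical RHS, sizes, and save points.
using OrdinaryDiffEq, CUDA, LinearAlgebra

# In-place right-hand side; p is a parameter matrix living on the GPU.
function f!(du, u, p, t)
    mul!(du, p, u)      # du = p * u, computed on the GPU via CUBLAS
    return nothing
end

u0    = CUDA.rand(Float32, 8)           # state as CuArray{Float32, 1}
p     = CUDA.rand(Float32, 8, 8)        # parameters as CuArray{Float32, 2}
tspan = (0.0f0, 1.0f0)
ts    = cu(collect(0.0f0:0.1f0:1.0f0))  # saveat passed as a CuArray

prob = ODEProblem(f!, u0, tspan, p)
sol  = solve(prob, Tsit5(); saveat = ts)
```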

The problem seems to be that the tType and IType in ODESolution are CPU arrays (Vector{Float32}) rather than CUDA ones. This is typeof(solve(...)) - note the presence of Vector{Float32} in the tType slot and inside the interpolation:

```
ODESolution{Float32, 2, Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Nothing, Nothing, Vector{Float32}, Nothing, ODEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, true, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ODEFunction{true, typeof(ċȧ_primal!), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Symbol, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, SciMLBase.StandardODEProblem}, Tsit5, SciMLBase.SensitivityInterpolation{Vector{Float32}, Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, DiffEqBase.DEStats}
```
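Concretely, the mismatch shows up in the solution's fields like this (continuing the hypothetical sketch above, where sol is the returned ODESolution):

```julia
# Continuing the sketch above: where the CPU arrays sit in the solution object.
typeof(sol.u)       # Vector{CuArray{Float32, 1, ...}}  - the states are on the GPU
typeof(sol.t)       # Vector{Float32}                   - tType is a plain CPU Vector
typeof(sol.interp)  # SciMLBase.SensitivityInterpolation{Vector{Float32}, ...} - IType carries a CPU Vector too
```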

How do I make those CuArrays? All the arrays passed to ODEProblem have been converted with gpu. I don’t need the interpolation sol(t), and I’m using BacksolveAdjoint to differentiate.
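The differentiation call is along these lines - again only a sketch, with a hypothetical scalar loss, reusing prob and ts from the sketch above, and assuming the adjoint definitions come from DiffEqSensitivity (now SciMLSensitivity):

```julia
# Sketch of the reverse-mode call described above; the loss is a hypothetical stand-in.
using Zygote, DiffEqSensitivity

function loss(p)
    sol = solve(prob, Tsit5(); p = p, saveat = ts,
                sensealg = BacksolveAdjoint())
    sum(sol)   # hypothetical loss: sum over all saved states
end

dp, = Zygote.gradient(loss, p)
```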

Share the code.

Will do - mind if we discuss the specific ODE offline first?