I’m trying to use OrdinaryDiffEq, Zygote, and CUDA to reverse-mode differentiate an ODE running on GPU. My code works on CPU, but hits this error on GPU:
GPU compilation of kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceVector{Float32, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}}, typeof(identity), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}}}, Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}}, typeof(identity), Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}}} which is not isbits.
    .1 is of type Base.Broadcast.Extruded{Vector{Float32}, Tuple{Bool}, Tuple{Int64}} which is not isbits.
      .x is of type Vector{Float32} which is not isbits.
The problem seems to be that the tType and IType in ODESolution are plain CPU Vectors rather than CuArrays. This is typeof(solve(...)); note the Vector{Float32} entries:
ODESolution{Float32, 2, Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Nothing, Nothing, Vector{Float32}, Nothing, ODEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, true, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ODEFunction{true, typeof(ċȧ_primal!), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Symbol, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Symbol}, NamedTuple{(:saveat,), Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, SciMLBase.StandardODEProblem}, Tsit5, SciMLBase.SensitivityInterpolation{Vector{Float32}, Vector{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, DiffEqBase.DEStats}
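The error message spells out why the launch fails: a CUDA.jl kernel argument must be an isbits value, and a host Vector{Float32} (like the time vector in the type above) is not. The following CPU-only sketch illustrates that check without a GPU; the name `times` is a hypothetical stand-in for the solution's time vector, not taken from my actual code.

```julia
# CPU-only illustration: runs without a GPU or CUDA.jl.
# Kernel arguments must be isbits; host arrays are heap-allocated and mutable.
times = Float32[0.0, 0.5, 1.0]   # hypothetical stand-in for the Vector{Float32} time vector

@assert isbitstype(Float32)            # plain scalars can be passed to a kernel
@assert !isbitstype(Vector{Float32})   # a host Vector cannot

# A broadcast such as `dest .= times` first builds a lazy Broadcasted wrapper
# around `times`; because it holds the Vector, the wrapper is not isbits either.
bc = Base.Broadcast.broadcasted(identity, times)
@assert !isbits(bc)

println("Broadcasted wrapper is isbits: ", isbits(bc))  # prints "Broadcasted wrapper is isbits: false"
```

This only mirrors the `.x is of type Vector{Float32} which is not isbits` part of the error; it does not show where OrdinaryDiffEq or the adjoint creates that broadcast.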
How do I make those CuArrays as well? All the arrays passed to ODEProblem have already been converted with gpu. I don’t need the interpolation sol(t), and I’m using BacksolveAdjoint to differentiate.