GPU problems with RecursiveArrayTools

Hi, I have a system of ODEs whose state is an array of arrays, and I want to solve it on the GPU. I am using an ArrayPartition in which each component is a CuArray, but I get the scalar indexing error. The code only runs if I wrap the solve call in CUDA.@allowscalar (shown after the MWE below), which does not work for me: I want to use AD later, and Zygote fails on @allowscalar since try/catch is not supported.
Would appreciate your help!

An MWE:

using Random, DifferentialEquations, SciMLSensitivity, RecursiveArrayTools, Lux, ComponentArrays, LuxCUDA, CUDA
CUDA.allowscalar(false)
rng = Random.default_rng()
dev = gpu_device()  # Lux device for moving parameters/state to the GPU
A = cu(rand(rng, Float32, 2, 2)); B = cu(rand(rng, Float32, 2, 2))
u0 = ArrayPartition((A, B))
node_vf = Dense(2 => 2, tanh)
p, st = Lux.setup(rng, node_vf)
p = p |> ComponentArray |> dev
st = st |> dev
function dudt(u, p, t)
    A, B = u.x
    da = node_vf(A, p, st)[1]
    db = node_vf(B, p, st)[1]
    ArrayPartition((da, db))
end
prob = ODEProblem(dudt, u0, (0.0f0, 1.0f0), p)
sol = solve(prob, Tsit5(), sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))
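
For completeness, the @allowscalar workaround mentioned above is the following; it runs forward, but Zygote later fails on it because the macro body involves a try/catch:

# runs, but is not differentiable with Zygote
sol = CUDA.@allowscalar solve(prob, Tsit5())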

Running the MWE as written (without @allowscalar), I get

Scalar indexing is disallowed.
Invocation of setindex! resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations do not execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:103
  [3] setindex!(A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, v::Float32, I::Int64)
    @ GPUArrays ~/.julia/packages/GPUArrays/dAUOE/src/host/indexing.jl:56
  [4] setindex!
    @ ~/.julia/packages/RecursiveArrayTools/7kB7n/src/array_partition.jl:246 [inlined]
  [5] fill!(A::ArrayPartition{Float32, Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, x::Bool)
    @ Base ./multidimensional.jl:1114
  [6] recursivefill!
    @ ~/.julia/packages/RecursiveArrayTools/7kB7n/src/utils.jl:138 [inlined]
  [7] alg_cache(alg::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, u::ArrayPartition{Float32, Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, rate_prototype::ArrayPartition{Float32, Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, ::Type{Float32}, ::Type{Float32}, ::Type{Float32}, uprev::ArrayPartition{Float32, Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, uprev2::ArrayPartition{Float32, Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, f::ODEFunction{true, SciMLBase.AutoSpecialize, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, t::Float32, dt::Float32, reltol::Float32, p::ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))}}}, calck::Bool, ::Val{true})
    @ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/pWCEq/src/caches/low_order_rk_caches.jl:535
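
Judging from frames [4] and [5], the failure is fill! on the ArrayPartition falling back to Base's element-wise loop, i.e. scalar setindex!. A minimal sketch that isolates just that call (assuming only CUDA and RecursiveArrayTools are installed):

using CUDA, RecursiveArrayTools
CUDA.allowscalar(false)
ap = ArrayPartition((CUDA.zeros(Float32, 2, 2), CUDA.zeros(Float32, 2, 2)))
fill!(ap, false)  # hits the generic fill! in Base, which sets elements one by one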

I think I just responded to this somewhere? Try using Adapt.jl’s adapt, or open an issue as this may need some special handling.
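
In case RecursiveArrayTools does not already ship an Adapt rule for ArrayPartition, an untested sketch of what I mean would be defining adapt_structure to recurse over the components:

import Adapt
using RecursiveArrayTools
# Hypothetical rule: adapt each component array, so Adapt.adapt(CuArray, ap)
# moves every leaf array to the GPU in one call.
Adapt.adapt_structure(to, ap::ArrayPartition) =
    ArrayPartition(map(x -> Adapt.adapt(to, x), ap.x))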