DEDataArrays on GPU

Following this discussion, I have a question about the possibility of using DEDataArrays on the GPU. Currently, I cannot convert a DEDataArray A simply with CuArray(A), because CuArray drops every field of A except A.x. Defining A as a DEDataArray whose A.x is a CuArray does not help either, since every broadcast operation involving A falls back to slow scalar indexing. The following code illustrates both points:

using CUDA
using DiffEqBase


mutable struct MyArray{T, N} <: DEDataArray{T, N}
    x :: Array{T, N}
    y :: Array{T, 1}
    z :: T
end

a = MyArray(ones((2,2)), ones(2), 1.0)
b = MyArray(zeros((2,2)), zeros(2), 0.0)
@show a.x, a.y, a.z   # (a.x, a.y, a.z) = ([1.0 1.0; 1.0 1.0], [1.0, 1.0], 1.0)
@show b.x, b.y, b.z   # (b.x, b.y, b.z) = ([0.0 0.0; 0.0 0.0], [0.0, 0.0], 0.0)

# Broadcasting copies only the x component of DEDataArrays:
@. b = a
@show b.x, b.y, b.z   # (b.x, b.y, b.z) = ([1.0 1.0; 1.0 1.0], [0.0, 0.0], 0.0)

# Copy the rest of the DEDataArrays components:
DiffEqBase.copy_fields!(b, a)
@show b.x, b.y, b.z   # (b.x, b.y, b.z) = ([1.0 1.0; 1.0 1.0], [1.0, 1.0], 1.0)


# ------------------------------------------------------------------------------
# CuArray omits the extra components of DEDataArrays:
agpu = CuArray(a)
@show typeof(agpu)   # typeof(agpu) = CuArray{Float64,2}


# ------------------------------------------------------------------------------
mutable struct MyCuArray{T, N} <: DEDataArray{T, N}
    x :: CuArray{T, N}
    y :: CuArray{T, 1}
    z :: T
end

a = MyCuArray(CUDA.ones(2), CUDA.ones(2), 1.0f0)
b = MyCuArray(CUDA.zeros(2), CUDA.zeros(2), 0.0f0)

@. b = a   # broadcasting results in slow scalar operations
DiffEqBase.copy_fields!(b, a)   # copy_fields! is still fast

Is there a simple workaround for this issue? Can it be solved by redefining the CuArray converter or the broadcast style for DEDataArrays? Or is the only solution to define a new type, say DEDataCuArray, that subtypes AbstractGPUArray?

Yes, it can be solved by defining a better broadcast style.

It turned out that redefining the copyto! function makes the broadcasting from my previous example work:

using CUDA
using DiffEqBase: DEDataArray, DEDataArrayStyle, unpack

CUDA.allowscalar(false)

Base.display(A::DEDataArray) = display(A.x)   # lets the REPL show a DEDataArray that wraps a CuArray

Base.copyto!(dest::DEDataArray, bc::Broadcast.Broadcasted{DEDataArrayStyle}) = copyto!(dest.x, unpack(bc))
Base.copyto!(dest::AbstractArray, bc::Broadcast.Broadcasted{DEDataArrayStyle}) = copyto!(dest, unpack(bc))


mutable struct MyCuArray{T, N} <: DEDataArray{T, N}
    x :: CuArray{T, N}
    y :: CuArray{T, 1}
    z :: T
end

A = MyCuArray(CUDA.ones((2,2)), CUDA.ones(2), 1f0)
B = MyCuArray(CUDA.zeros((2,2)), CUDA.zeros(2), 0f0)
x = CUDA.zeros((2,2))


@. B = 2 * A + A^2 + 1

C = @. 2 * A + A^2 + 1

@. x = 2 * A + A^2 + 1

# @. B = A + x   # scalar getindex is disallowed

# C = @. A + x   # scalar getindex is disallowed

# @. x = A + x   # passing and using non-bitstype argument

However, the solution breaks down whenever an ordinary CuArray appears on the right-hand side of the expression. Can you give me a hint on where to dig? My only idea is that here the unpack function makes an explicit call to args[1]. In that case, would it be a good idea to check whether the argument is a CuArray before unpacking it, or is there a better solution?

I think it would be good to try to remove the args[1]. Please open an issue in DiffEqBase about this.
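
For anyone who wants to experiment with the idea of unpacking every argument instead of only args[1], here is a minimal CPU-only sketch. It is not DiffEqBase's implementation: Wrapped, unwrap, and the ArrayStyle{Wrapped} style are hypothetical stand-ins for DEDataArray, unpack, and DEDataArrayStyle, and only Base.Broadcast is used, so it runs without a GPU.

```julia
using Base.Broadcast: Broadcasted, ArrayStyle

# Hypothetical stand-in for a DEDataArray: a main array `x` plus one extra field.
struct Wrapped{T, N, A <: AbstractArray{T, N}} <: AbstractArray{T, N}
    x :: A
    z :: T
end
Base.size(w::Wrapped) = size(w.x)
Base.getindex(w::Wrapped, i::Int...) = w.x[i...]
Base.setindex!(w::Wrapped, v, i::Int...) = (w.x[i...] = v)

# Any broadcast that involves a Wrapped gets this style.
Base.BroadcastStyle(::Type{<:Wrapped}) = ArrayStyle{Wrapped}()

# Strip the wrapper from every argument, recursing into nested broadcasts.
# Arguments that are not wrappers (plain arrays, scalars) pass through unchanged.
unwrap(w::Wrapped) = w.x
unwrap(bc::Broadcasted) = Broadcast.broadcasted(bc.f, map(unwrap, bc.args)...)
unwrap(other) = other

# In-place broadcast: rebuild the expression on the plain arrays and write into dest.x.
function Base.copyto!(dest::Wrapped, bc::Broadcasted{ArrayStyle{Wrapped}})
    Broadcast.materialize!(dest.x, unwrap(bc))
    return dest
end

A = Wrapped(ones(2, 2), 1.0)
B = Wrapped(zeros(2, 2), 0.0)
x = fill(3.0, 2, 2)

@. B = 2 * A + A * A + x   # mixes the wrapper with an ordinary array
```

Because unwrap leaves non-wrapper arguments alone, an ordinary array on the right-hand side no longer needs special handling, and only B.x is written while B.z keeps its old value. Whether the same approach avoids scalar indexing with CuArray fields is an assumption that would need to be checked on a GPU.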