I’m trying to make my own array (wrapper) type that can be used on both CPUs and GPUs. Here’s a minimal example:
```julia
struct MyWrapper{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
end

# Basic AbstractArray methods:
for f ∈ (:size, :getindex, :setindex!)
    @eval Base.$f(x::MyWrapper, args...) = $f(, args...)
end

# Broadcasting:
# (based on the ArrayAndChar example at )
import Base.Broadcast: ArrayStyle, Broadcasted
Base.BroadcastStyle(::Type{<:MyWrapper}) = ArrayStyle{MyWrapper}()
function Base.similar(bc::Broadcasted{ArrayStyle{MyWrapper}}, ::Type{T}) where {T}
    x = find_mw(bc)
    MyWrapper(similar(, T, axes(bc)))
end
find_mw(bc::Broadcasted) = find_mw(bc.args)
find_mw(args::Tuple) = find_mw(find_mw(args[1]), Base.tail(args))
find_mw(x) = x
find_mw(a::MyWrapper, rest) = a
find_mw(::Any, rest) = find_mw(rest)
```
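For reference, here is a self-contained, CPU-only version of the snippet above that runs as-is. It assumes the wrapped array is stored in a field named `data` (a hypothetical name, used wherever the methods need the wrapped array). It also adds a `find_mw(::Tuple{}) = nothing` method: without it, a nested broadcast whose inner expression contains no wrapper, such as `(rand(5) .* 2) .+ y`, makes `find_mw` call `args[1]` on an empty tuple.

```julia
import Base.Broadcast: ArrayStyle, Broadcasted

struct MyWrapper{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A  # hypothetical field holding the wrapped array
end

# Forward the basic AbstractArray methods to the wrapped array.
for f ∈ (:size, :getindex, :setindex!)
    @eval Base.$f(x::MyWrapper, args...) = $f(x.data, args...)
end

# Custom broadcast style so that broadcast results are re-wrapped.
Base.BroadcastStyle(::Type{<:MyWrapper}) = ArrayStyle{MyWrapper}()
function Base.similar(bc::Broadcasted{ArrayStyle{MyWrapper}}, ::Type{T}) where {T}
    x = find_mw(bc)
    MyWrapper(similar(x.data, T, axes(bc)))
end

# Return the first MyWrapper among the (possibly nested) broadcast arguments.
find_mw(bc::Broadcasted) = find_mw(bc.args)
find_mw(args::Tuple) = find_mw(find_mw(args[1]), Base.tail(args))
find_mw(x) = x
find_mw(::Tuple{}) = nothing  # no wrapper found in this (sub)expression
find_mw(a::MyWrapper, rest) = a
find_mw(::Any, rest) = find_mw(rest)

y = MyWrapper(rand(5))
z = 1 .* y                 # result is a MyWrapper around a Vector{Float64}
w = (rand(5) .* 2) .+ y    # inner broadcast contains no wrapper
```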
The issue
Now broadcasting of `MyWrapper` seems to work on my CPU (although with one additional allocation, and somewhat slower than the “unwrapped” array):
```julia
using BenchmarkTools
x = rand(1000000)
y = MyWrapper(x)
@btime 1 .* $x # 712.590 μs (2 allocations: 7.63 MiB)
@btime 1 .* $y # 780.263 μs (3 allocations: 7.63 MiB)
```
However, if I wrap a `CuArray` and do the same on my GPU, I get:
```julia
using CuArrays
x = cu(rand(1000000))
y = MyWrapper(x)
@btime 1 .* $x # 5.639 μs (65 allocations: 2.27 KiB)
@btime 1 .* $y # 7.387 s (2000009 allocations: 183.11 MiB)
```
which doesn’t look good for `MyWrapper`: about 7.4 s and over two million allocations, versus about 5.6 μs for the bare `CuArray`.
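One way to check where the time goes, sketched here on the CPU only: count how often the wrapper’s `getindex` is called during a broadcast. `CountingWrapper`, its `data` field and `GETINDEX_CALLS` are hypothetical names for this illustration. If the counter ends up at one call per element or more, the broadcast is running the generic element-by-element loop through the wrapper’s `getindex`/`setindex!`. For a wrapped `CuArray`, each of those scalar accesses would presumably be a separate device transfer, which would be consistent with the roughly two million allocations in the GPU timing above.

```julia
import Base.Broadcast: ArrayStyle, Broadcasted

const GETINDEX_CALLS = Ref(0)  # hypothetical counter of scalar reads

struct CountingWrapper{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A  # hypothetical field holding the wrapped array
end

Base.size(x::CountingWrapper) = size(x.data)
function Base.getindex(x::CountingWrapper, I...)
    GETINDEX_CALLS[] += 1  # one increment per scalar access
    return getindex(x.data, I...)
end
Base.setindex!(x::CountingWrapper, v, I...) = setindex!(x.data, v, I...)

# Same broadcast-style approach as in the post, with no custom copyto!.
Base.BroadcastStyle(::Type{<:CountingWrapper}) = ArrayStyle{CountingWrapper}()
Base.similar(bc::Broadcasted{ArrayStyle{CountingWrapper}}, ::Type{T}) where {T} =
    CountingWrapper(similar(Array{T}, axes(bc)))

y = CountingWrapper(rand(1000))
GETINDEX_CALLS[] = 0
z = 1 .* y
calls = GETINDEX_CALLS[]
println("elements: ", length(y), ", getindex calls: ", calls)
```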
I suspect I’m missing something “trivial” (like overloading `copyto!` for `MyWrapper`?), but I haven’t been able to figure out exactly what. Suggestions?
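Following up on the `copyto!` idea, here is one possible shape for it, offered as a CPU-checked sketch rather than a known fix. The Julia manual’s section on customizing broadcasting recommends the signature `copyto!(dest::DestType, bc::Broadcasted{Nothing})` for specializing on the destination type. The sketch uses that signature to rebuild the broadcast expression with every `MyWrapper` replaced by its wrapped array (the `data` field and the `unwrap` helper are hypothetical names) and passes the result to `Base.Broadcast.materialize!` on the wrapped array. The intent is that a wrapped `CuArray` would then see an ordinary `CuArray` broadcast and use CuArrays’ own kernel, but that GPU behavior is an assumption and has only been reasoned about, not measured.

```julia
import Base.Broadcast: ArrayStyle, Broadcasted, broadcasted, materialize!

struct MyWrapper{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A  # hypothetical field holding the wrapped array
end

Base.size(x::MyWrapper) = size(x.data)
Base.getindex(x::MyWrapper, I...) = getindex(x.data, I...)
Base.setindex!(x::MyWrapper, v, I...) = setindex!(x.data, v, I...)

Base.BroadcastStyle(::Type{<:MyWrapper}) = ArrayStyle{MyWrapper}()
function Base.similar(bc::Broadcasted{ArrayStyle{MyWrapper}}, ::Type{T}) where {T}
    MyWrapper(similar(find_mw(bc).data, T, axes(bc)))
end

find_mw(bc::Broadcasted) = find_mw(bc.args)
find_mw(args::Tuple) = find_mw(find_mw(args[1]), Base.tail(args))
find_mw(x) = x
find_mw(::Tuple{}) = nothing
find_mw(a::MyWrapper, rest) = a
find_mw(::Any, rest) = find_mw(rest)

# Hypothetical helper: replace every MyWrapper in a (possibly nested) lazy
# broadcast expression by the array it wraps, rebuilding it with `broadcasted`.
unwrap(x) = x
unwrap(x::MyWrapper) = x.data
unwrap(bc::Broadcasted) = broadcasted(bc.f, map(unwrap, bc.args)...)

# Destination-type specialization: hand the whole fused expression to the
# wrapped array so that its own broadcast implementation does the loop.
function Base.copyto!(dest::MyWrapper, bc::Broadcasted{Nothing})
    materialize!(dest.data, unwrap(bc))
    return dest
end

y = MyWrapper(collect(1.0:5.0))
z = 2 .* y .+ 1  # out-of-place: result is still a MyWrapper
y .= y .* 10     # in-place: writes through to y.data
```

The out-of-place path still goes through `similar` and `find_mw`, so results stay wrapped; only the element loop is delegated to the wrapped array.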
Additional information (although I doubt it matters):
```
julia> versioninfo()
Julia Version 1.3.1
Commit 2d5741174c (2019-12-30 21:36 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz
  LIBM: libopenlibm
  LLVM: libLLVM-6.0.1 (ORCJIT, skylake)
```
```
(v1.3) pkg> st CuArrays
Status `~/.julia/environments/v1.3/Project.toml`
[79e6a3ab] Adapt v1.0.0
[fa961155] CEnum v0.2.0
[3895d2a7] CUDAapi v2.1.0
[c5f51814] CUDAdrv v5.0.1
[be33ccc6] CUDAnative v2.7.0
[3a865a2d] CuArrays v1.6.0
[864edb3b] DataStructures v0.17.6
[872c559c] NNlib v0.6.2
```