I'll add an `inherits_from_parent` trait to ArrayInterface that will let you define three simple methods:

- `ArrayInterface.inherits_from_parent`
- `Base.parent`
- `ArrayInterface.parent_type`

Defining these should be enough to get LoopVectorization support.
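To illustrate the idea, here is a minimal, self-contained sketch of how such a trait could forward layout queries to the parent type. The module `TraitSketch` and its `contiguous_axis` are hypothetical stand-ins, not ArrayInterface's actual implementation (which, for example, returns static integers rather than plain `Int`s):

```julia
# Hypothetical sketch of a parent-forwarding trait; NOT ArrayInterface's actual code.
module TraitSketch

# Trait: does this wrapper inherit memory-layout traits from its parent?
# Defaults to `false`; wrapper types opt in by overloading it.
inherits_from_parent(::Type) = false

# The wrapped (parent) array type; wrapper types overload this.
function parent_type end

# Example layout trait: which axis is contiguous in memory (1 for column-major `Array`).
contiguous_axis(::Type{<:Array}) = 1
function contiguous_axis(::Type{A}) where {A}
    inherits_from_parent(A) || return nothing  # layout unknown: type did not opt in
    return contiguous_axis(parent_type(A))      # forward the query to the parent type
end

end # module

struct ArrayWrapper{T,N,A <: AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A
end

# The three methods a wrapper would define under this scheme:
TraitSketch.inherits_from_parent(::Type{<:ArrayWrapper}) = true
Base.parent(x::ArrayWrapper) = x.data
TraitSketch.parent_type(::Type{ArrayWrapper{T,N,A}}) where {T,N,A} = A

TraitSketch.contiguous_axis(ArrayWrapper{Float64,2,Matrix{Float64}})  # returns 1, forwarded from Matrix{Float64}
TraitSketch.contiguous_axis(UnitRange{Int})                          # returns nothing (no opt-in)
```

The `nothing` fallback means a type that has not opted in reports an unknown layout instead of throwing, so a consumer can fall back to a slower generic path.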
Currently, a few more methods are needed:
```julia
using ArrayInterface
# A thin wrapper around any AbstractArray
struct ArrayWrapper{T,N,A <: AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A
end
# What is being wrapped (the parent type and the parent itself)
ArrayInterface.parent_type(::Type{ArrayWrapper{T,N,A}}) where {T,N,A} = A
Base.parent(x::ArrayWrapper) = x.data
# Memory access: forward the pointer, size, and strides to the parent
Base.unsafe_convert(::Type{Ptr{T}}, x::ArrayWrapper{T}) where {T} = Base.unsafe_convert(Ptr{T}, parent(x))
Base.size(x::ArrayWrapper) = size(parent(x))
Base.strides(x::ArrayWrapper) = strides(parent(x))
# Memory-layout traits: reuse the parent type's answers
ArrayInterface.contiguous_axis(::Type{A}) where {A <: ArrayWrapper} = ArrayInterface.contiguous_axis(ArrayInterface.parent_type(A))
ArrayInterface.contiguous_batch_size(::Type{A}) where {A <: ArrayWrapper} = ArrayInterface.contiguous_batch_size(ArrayInterface.parent_type(A))
ArrayInterface.stride_rank(::Type{A}) where {A <: ArrayWrapper} = ArrayInterface.stride_rank(ArrayInterface.parent_type(A))
```
Of course, you still need to implement the rest of the usual array interface, such as `getindex`. But even with just the methods above:
```julia
julia> A = ArrayWrapper(rand(8,10));

julia> using LoopVectorization

julia> stridedpointer(A) # works
VectorizationBase.StridedPointer{Float64, 2, 1, 0, (1, 2), Tuple{ArrayInterface.StaticInt{8}, Int64}, Tuple{ArrayInterface.StaticInt{1}, ArrayInterface.StaticInt{1}}}(Ptr{Float64} @0x00007f0854fd8610, (Static(8), 64), (Static(1), Static(1)))

julia> LoopVectorization.check_args(A)
true
```
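The rest of the usual array interface mentioned above can be forwarded to the parent the same way. Below is a minimal sketch using only Base; it repeats the struct and the Base forwarding methods so it runs on its own. Forwarding `IndexStyle` and marking the accessors `@propagate_inbounds` are illustrative choices for performance, not something LoopVectorization itself is known to require:

```julia
struct ArrayWrapper{T,N,A <: AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A
end
Base.parent(x::ArrayWrapper) = x.data
Base.size(x::ArrayWrapper) = size(parent(x))
Base.strides(x::ArrayWrapper) = strides(parent(x))
# Pointer access, so `pointer(x)` returns the parent's data pointer
Base.unsafe_convert(::Type{Ptr{T}}, x::ArrayWrapper{T}) where {T} =
    Base.unsafe_convert(Ptr{T}, parent(x))

# Use the parent's indexing style (IndexLinear() for a plain Array)
Base.IndexStyle(::Type{ArrayWrapper{T,N,A}}) where {T,N,A} = IndexStyle(A)

# Element access forwards to the parent; @propagate_inbounds lets a caller's
# @inbounds elide the parent's bounds check as well
Base.@propagate_inbounds Base.getindex(x::ArrayWrapper, i::Int...) = parent(x)[i...]
Base.@propagate_inbounds function Base.setindex!(x::ArrayWrapper, v, i::Int...)
    parent(x)[i...] = v
    return x
end

w = ArrayWrapper(reshape(collect(1.0:6.0), 2, 3))
w[2, 3]  # 6.0 (column-major layout, as in the parent Matrix)
sum(w)   # 21.0, using the generic AbstractArray fallbacks
```

With `size` and `getindex` in place, generic `AbstractArray` code (iteration, `sum`, `show`, broadcasting) works on the wrapper through Base's fallbacks.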