Linear algebra with array wrapper types

What is the “correct” way to define multiplication and division for AbstractArray types that are wrappers around other arrays?

So let’s say I have an array wrapper type:

struct MyArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    inner::A
end

I would like it to use BLAS when the inner array type is BLASable. I have Base.pointer and Base.unsafe_convert defined, but those only get called if my array type is a DenseArray. I could define MyArray as a subtype of DenseArray, but then I wouldn’t be able to hold, say, StaticArrays or other non-dense arrays since multiplication would dispatch in a way that isn’t compatible with the inner array.
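For reference, the DenseArray route can be sketched roughly as below. This is a minimal sketch, not the actual package code: the name DenseWrapper is made up for illustration, and it assumes the inner array is itself a DenseArray that supports pointer and strides.

```julia
using LinearAlgebra

# Hypothetical dense-only wrapper; the name DenseWrapper is illustrative.
struct DenseWrapper{T,N,A<:DenseArray{T,N}} <: DenseArray{T,N}
    inner::A
end

# Minimal AbstractArray interface, forwarded to the inner array.
Base.size(x::DenseWrapper) = size(x.inner)
Base.IndexStyle(::Type{<:DenseWrapper}) = IndexLinear()
Base.getindex(x::DenseWrapper, i::Int) = x.inner[i]
Base.setindex!(x::DenseWrapper, v, i::Int) = (x.inner[i] = v)

# Strided-array interface: memory layout and raw pointer come from the inner array.
Base.strides(x::DenseWrapper) = strides(x.inner)
Base.unsafe_convert(::Type{Ptr{T}}, x::DenseWrapper{T}) where {T} = pointer(x.inner)

A = DenseWrapper([1.0 2.0; 3.0 4.0])
v = DenseWrapper([1.0, 1.0])

A * v   # goes through the strided (BLAS-eligible) methods for Float64
A * A
```

Because DenseWrapper is a DenseArray, it also counts as a StridedArray, which is what the strided multiplication methods in LinearAlgebra dispatch on. The A<:DenseArray bound is exactly the restriction described above: a non-dense inner array cannot be wrapped this way.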

Another option would be to just forward all applicable methods to the inner array. But this leads to method ambiguity hell because LinearAlgebra is so heavily specialized for different array types. Not to mention this doesn’t play nicely with other custom array types that have their own *, /, and \ defined (like StaticArrays). You end up having to include every package that your array could possibly operate with and define specializations for everything.
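A small sketch of how quickly this goes wrong, assuming only two naive forwarding methods for * (one per argument position). Even without any other package involved, the wrapper is already ambiguous with itself:

```julia
using LinearAlgebra

struct MyArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    inner::A
end
Base.size(x::MyArray) = size(x.inner)
Base.getindex(x::MyArray, i::Int...) = x.inner[i...]

# Naive forwarding: unwrap whichever argument is a MyArray.
Base.:*(A::MyArray{<:Any,2}, B::AbstractMatrix) = A.inner * B
Base.:*(A::AbstractMatrix, B::MyArray{<:Any,2}) = A * B.inner

W = MyArray([1.0 2.0; 3.0 4.0])

# Neither method is more specific for (MyArray, MyArray), so this call is ambiguous.
err = try
    W * W
    nothing
catch e
    e
end
```

Here err ends up holding a MethodError that reports the ambiguity. Every other package that defines its own * for (TheirType, AbstractMatrix) or (AbstractMatrix, TheirType) creates the same kind of conflict with these two methods.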

I feel like there is something I’m missing here. I’ve looked through other array wrapper packages and they seem to either settle for the slower non-BLAS fallback (even when they are stored contiguously in memory), subtype DenseArray, make separate types for different inner types (I see this especially with StaticArray inner types), or overload/forward their multiplications and divisions to their inner arrays.

I feel like there should be a trait-based option or something to make this easier, but I don’t see anything like that in LinearAlgebra.

I think forwarding is the way to go: when method ambiguities arise, you just fix them.

But it is hard to say anything more concrete without seeing some more code. The key question is what your wrapper type will do differently (in the sense that, e.g., a LowerTriangular wrapper will treat elements above the diagonal as zero).
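To make "just fix them" concrete, here is a minimal sketch, assuming the wrapper forwards * with one method per argument position. The ambiguity between those two methods goes away once a method that is more specific than both is added:

```julia
using LinearAlgebra

struct MyArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    inner::A
end
Base.size(x::MyArray) = size(x.inner)
Base.getindex(x::MyArray, i::Int...) = x.inner[i...]

# Forwarding methods, one per argument position.
Base.:*(A::MyArray{<:Any,2}, B::AbstractMatrix) = A.inner * B
Base.:*(A::AbstractMatrix, B::MyArray{<:Any,2}) = A * B.inner

# Disambiguating method: more specific than both of the above.
Base.:*(A::MyArray{<:Any,2}, B::MyArray{<:Any,2}) = A.inner * B.inner

W = MyArray([1.0 2.0; 3.0 4.0])
P = W * W   # now resolves to the MyArray * MyArray method
```

This only settles the conflict inside the wrapper's own methods. A conflict with another package's specialized * would need one more such method per pair of types, which is the scaling concern raised in the question.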

I’m mostly interested in types that have some functionality that’s orthogonal to linear algebra but still need to be used in a linear algebra context. So nothing structural that would warrant custom algorithms. Just user-facing wrapperish things like fancy indexing.

ComponentArrays.jl is my specific use case. In older versions, I tried the shotgun blast approach of overloading everything that LinearAlgebra specialized on. And this worked fine until someone tried to multiply an SMatrix from StaticArrays with a ComponentVector. I could resolve the ambiguities by patching it with a Requires file on my end, but then I would have to do this with every array type that has custom multiplication/division defined. And worse, anyone who writes an array package that wants to work with ComponentArrays has to add in their own specializations for ComponentArrays. It spreads like a virus. The temporary solution I have is to subtype DenseArray, which gives me all the fast linear algebra without having to overload anything, but it comes at the cost of not being able to have things like StaticArrays as an inner array type.

For me, at least, I think the answer is going to eventually be to have DenseComponentArray and NonDenseComponentArray types under the hood and make ComponentArray be the Union of the two. Right now this causes the compiler to give up on constant propagation for my indexing (which is already at the edge of what it is willing to tolerate). But I’m sure I’ll be able to get there.
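Roughly what I have in mind, as a stripped-down sketch. The names DenseCA, NonDenseCA, AnyCA, and wrap are placeholders, and none of the actual ComponentArrays indexing machinery is shown:

```julia
using LinearAlgebra

# Dense variant: subtypes DenseArray, so strided/BLAS methods apply.
struct DenseCA{T,N,A<:DenseArray{T,N}} <: DenseArray{T,N}
    inner::A
end

# Non-dense variant: plain AbstractArray, uses whatever generic methods apply.
struct NonDenseCA{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    inner::A
end

# The user-facing "type" is the Union of the two.
const AnyCA{T,N} = Union{DenseCA{T,N}, NonDenseCA{T,N}}

# Shared behavior is written once against the Union.
Base.size(x::AnyCA) = size(x.inner)
Base.getindex(x::AnyCA, i::Int...) = x.inner[i...]

# Only the dense variant exposes strides and a pointer.
Base.strides(x::DenseCA) = strides(x.inner)
Base.unsafe_convert(::Type{Ptr{T}}, x::DenseCA{T}) where {T} = pointer(x.inner)

# Hypothetical constructor that picks the variant from the inner array type.
wrap(inner::DenseArray) = DenseCA(inner)
wrap(inner::AbstractArray) = NonDenseCA(inner)

d = wrap([1.0 2.0; 3.0 4.0])   # DenseCA
r = wrap(1:3)                  # NonDenseCA, since a range is not a DenseArray
```

Both d and r satisfy `isa AnyCA`, so user-facing methods can still be written once, while only d is a DenseArray as far as LinearAlgebra is concerned.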

I’m also wondering if it would be helpful going forward to introduce a Dense trait in Base or LinearAlgebra that could deal with this kind of stuff instead of an abstract type.
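Something along these lines is what I mean by a trait. This is purely a sketch: DenseStyle, IsDense, NotDense, and trait_mul are invented names, and nothing like them exists in Base or LinearAlgebra as far as I know.

```julia
using LinearAlgebra

struct MyArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    inner::A
end
Base.size(x::MyArray) = size(x.inner)
Base.getindex(x::MyArray, i::Int...) = x.inner[i...]

# Hypothetical trait, in the style of Base.IndexStyle.
abstract type DenseStyle end
struct IsDense <: DenseStyle end
struct NotDense <: DenseStyle end

DenseStyle(::Type{<:AbstractArray}) = NotDense()
DenseStyle(::Type{<:DenseArray}) = IsDense()
# A wrapper inherits the trait of whatever it wraps.
DenseStyle(::Type{<:MyArray{T,N,A}}) where {T,N,A} = DenseStyle(A)
DenseStyle(x::AbstractArray) = DenseStyle(typeof(x))

# Dispatch on the trait value instead of on the abstract supertype.
trait_mul(A, B) = trait_mul(DenseStyle(A), A, B)
trait_mul(::IsDense, A::MyArray, B) = A.inner * B   # dense inner array: use its fast path
trait_mul(::NotDense, A, B) = A * B                 # otherwise: generic fallback

M = [1.0 2.0; 3.0 4.0; 5.0 6.0]
dense = MyArray(M)
sparse_rows = MyArray(view(M, [1, 3], :))   # view with a vector index is not dense

trait_mul(dense, [1.0, 1.0])
trait_mul(sparse_rows, [1.0, 1.0])
```

The point is that MyArray stays a plain AbstractArray and still reports IsDense when its inner array is dense, so a single wrapper type can cover both cases.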

I think that the main reason LinearAlgebra doesn’t have traits is that it is one of the oldest parts of Julia, and when traits were introduced and became more widely used, the cost of breaking the API was already too high.

As you have noticed, wrappers can be somewhat brittle and make it hard to organize code. It would be great to see some experimentation with a traits-based API for linear algebra.


There was https://github.com/JuliaLang/julia/pull/25558, parts of which became ArrayLayouts.jl (JuliaLinearAlgebra/ArrayLayouts.jl), a Julia package for describing array layouts and more general fast linear algebra.

There is a much lighter proposal in https://github.com/JuliaLang/julia/pull/30432 to make strides return nothing when it doesn’t make sense.
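The idea can be illustrated with a separate helper function. This is only a sketch of the concept, not the implementation in that PR: maybe_strides and Wrap are made-up names, and Base.strides itself is left untouched.

```julia
# Hypothetical helper: strides when they make sense, nothing otherwise.
maybe_strides(::AbstractArray) = nothing
maybe_strides(x::StridedArray) = strides(x)

# A wrapper just forwards the question to its inner array.
struct Wrap{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    inner::A
end
Base.size(x::Wrap) = size(x.inner)
Base.getindex(x::Wrap, i::Int...) = x.inner[i...]
maybe_strides(x::Wrap) = maybe_strides(x.inner)

M = zeros(3, 4)
s1 = maybe_strides(Wrap(M))                     # (1, 3)
s2 = maybe_strides(Wrap(view(M, [1, 3], :)))    # nothing: rows 1 and 3 are not evenly strided
```

Generic code could then branch on `s === nothing` to choose between a pointer-based kernel and a generic fallback, without the wrapper having to subtype DenseArray.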

I was working on FluxML/NNlib.jl pull request #191 (“Allow batched_mul to work through PermutedDimsArray, II”) to make batched_mul work through various wrappers, without dispatch hell.
