mul! allocates for a custom struct that is a subtype of AbstractMatrix

I am trying to write an efficient matricization function following Kolda and Bader’s “Tensor Decompositions and Applications”. The struct is allocation-free to construct, since it is essentially a lazy view of the array it wraps. Here is my implementation.

# Inserts the mode‑n index into a tuple of non‑n indices.
@inline function _insert_index(nonmode_idx::NTuple{M,Int}, i::Int, n::Int) where {M}
    N = M + 1
    return @inbounds ntuple(k -> k == n ? i : (k < n ? nonmode_idx[k] : nonmode_idx[k-1]), N)
end

@inline function _tuple_prod(nonmode_dims::NTuple{M,Int}, n::Int) where {M}
    s = 1
    @inbounds for i in 1:n
        s *= nonmode_dims[i]
    end
    return s
end
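
For concreteness, here is what these two helpers compute (the example values are my own):

_insert_index((2, 4), 7, 2)   # -> (2, 7, 4): place 7 at position 2
_tuple_prod((3, 4, 5), 2)     # -> 12: product of the first 2 entries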

# --- Definition of ModeNMatrix ---

"""
    struct ModeNMatrix{T, N, M, A<:AbstractArray{T,N}} <: AbstractMatrix{T}

A lazy view that matricizes an N‑dimensional tensor `X` along mode `n`
without copying data. The resulting matrix has size
  (dims[n], (∏dims)/dims[n])
where `dims` are the dimensions of the original tensor.
The fields `nonmode` and `col_strides` are stored as NTuples (length M = N-1)
to avoid heap allocations.
"""
struct ModeNMatrix{T, N, M, A<:AbstractArray{T,N}} <: AbstractMatrix{T}
    X::A                         # The original tensor.
    n::Int                       # The mode along which we matricize.
    dims::NTuple{N,Int}          # The dimensions of the original tensor.
    nonmode::NTuple{M,Int}       # The modes other than n.
    col_strides::NTuple{M,Int}   # Precomputed strides for the non‑n modes.
end

"""
    ModeNMatrix(X::AbstractArray{T,N}, n::Int)

Construct a lazy matricized view of the N‑dimensional tensor `X` along mode `n`.
This computes the non‑n mode ordering and column strides as NTuples.
"""
function ModeNMatrix(X::AbstractArray{T,N}, n::Int) where {T, N}
    @assert 0 ≤ n ≤ N - 1 "Mode n must be between 0 and N-1 (modes are 0-based here)."
    n = n + 1  # convert the 0-based mode to a 1-based dimension index
    M = N - 1  # number of modes other than n
    dims = size(X)  # dims is a NTuple{N,Int}
    # The natural ordering for non‑n modes: (1, 2, …, n-1, n+1, …, N)
    nonmode = ntuple(i -> i < n ? i : i + 1, M)
    # Extract the dimensions for the non‑n modes.
    nonmode_dims = @inbounds ntuple(i -> dims[i < n ? i : i + 1], M)
    # Compute static column strides:
    # For the 1st non‑n mode, stride = 1; for k ≥ 2, stride[k] = prod(nonmode_dims[1:k-1]).
    col_strides = @inbounds ntuple(i -> i == 1 ? 1 : _tuple_prod(nonmode_dims, i - 1), M)
    return ModeNMatrix{T, N, M, typeof(X)}(X, n, dims, nonmode, col_strides)
end

# Convenience function.
matricize(X::AbstractArray, n::Int) = ModeNMatrix(X, n)

# size: rows = dims[n], columns = (∏dims)/dims[n]
Base.size(MM::ModeNMatrix{T, N, M, A}) where {T, N, M, A} =
    (MM.dims[MM.n], div(Base.prod(MM.dims), MM.dims[MM.n]))

# axes: we define the index ranges for rows and columns.
Base.axes(MM::ModeNMatrix) = (Base.OneTo(size(MM,1)), Base.OneTo(size(MM,2)))

# Let’s declare the preferred index style.
Base.IndexStyle(::Type{<:ModeNMatrix}) = IndexLinear()
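# Note: with IndexLinear, generic code such as sum indexes with a single
# integer, which the one-argument getindex at the bottom decodes through
# CartesianIndices; the two-argument method then decodes the column again.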

# Primary getindex: when given two integer indices.
function Base.getindex(MM::ModeNMatrix{T, N, M, A}, i::Int, j::Int) where {T, N, M, A}
    j_adj = j - 1  # convert j to 0-based indexing for the non‑n modes
    # Compute the indices for each non‑n mode (stored in a NTuple).
    nonmode_idx = @inbounds ntuple(k -> (j_adj ÷ MM.col_strides[k]) % MM.dims[MM.nonmode[k]] + 1, M)
    # Insert the mode‑n index into its proper place.
    full_idx = _insert_index(nonmode_idx, i, MM.n)
    return MM.X[full_idx...]
end

function Base.setindex!(MM::ModeNMatrix{T, N, M, A}, value::T, i::Int, j::Int) where {T, N, M, A}
    j_adj = j - 1  # Convert column index to 0-based for stride computation

    # Compute the non-mode indices using the column index
    nonmode_idx = ntuple(k -> (j_adj ÷ MM.col_strides[k]) % MM.dims[MM.nonmode[k]] + 1, M)

    # Reconstruct the full index for the original tensor
    full_idx = ntuple(k -> k == MM.n ? i : (k < MM.n ? nonmode_idx[k] : nonmode_idx[k-1]), N)

    # Assign the value to the original tensor
    @inbounds MM.X[full_idx...] = value
end

function Base.setindex!(MM::ModeNMatrix, value, i::Int)
    ci = CartesianIndices(MM)[i]
    MM[ci[1], ci[2]] = value
end

# Define getindex for CartesianIndex{2} (delegates to the two-integer version).
Base.getindex(MM::ModeNMatrix, I::CartesianIndex{2}) = MM[I[1], I[2]]

# Also define a fallback for tuple indexing.
function Base.getindex(MM::ModeNMatrix, inds::Tuple{Vararg{Int}})
    if length(inds) == 2
        return MM[inds[1], inds[2]]
    else
        return MM[CartesianIndex(inds)]
    end
end

# Finally, define linear indexing (a single integer) using CartesianIndices.
function Base.getindex(MM::ModeNMatrix, i::Int)
    return MM[CartesianIndices(MM)[i]]
end
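
As a sanity check, the view should match an eager mode-n unfolding built with permutedims and reshape (a minimal check, assuming the Kolda–Bader column ordering where the first non-n mode varies fastest):

X = rand(3, 4, 5)
n = 2                                   # 1-based mode
perm = (n, setdiff(1:ndims(X), n)...)   # mode n first, the rest in order
eager = reshape(permutedims(X, perm), size(X, n), :)
@assert matricize(X, n - 1) == eager    # matricize takes a 0-based mode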

The following snippet, which uses this struct together with mul! from LinearAlgebra, shows that my struct allocates and is considerably slower than regular matrices. Can you help me understand why that is?

using LinearAlgebra, BenchmarkTools

begin
    aa = rand(5, 5);
    bb = rand(5, 5);
    cc = rand(5, 5);
    a = matricize(aa, 0);
    b = matricize(bb, 0);
    c = matricize(cc, 0);
    @btime mul!($c, $a, $b);
    @btime mul!($cc, $aa, $bb);
end;

returns:

  1.753 μs (10 allocations: 20.95 KiB)
  178.536 ns (0 allocations: 0 bytes)

Your code actually does not allocate for me, though mul! for ModeNMatrix is still much slower:

  1.770 μs (0 allocations: 0 bytes)
  144.799 ns (0 allocations: 0 bytes)

The performance difference seems to simply be due to getindex and setindex! being much slower for ModeNMatrix than for Matrix:

julia> @btime sum($c)  # presumably uses linear indexing
  345.833 ns (0 allocations: 0 bytes)
35.801194754045646

julia> @btime sum($cc)  
  9.900 ns (0 allocations: 0 bytes)
35.801194754045646

julia> function my_sum_CI(x)
           s = 0.
           for i in CartesianIndices(x)
               s += x[i]
           end
           return s
       end;

julia> @btime my_sum_CI($c)  # 2D indexing, probably more relevant for mul!
  118.551 ns (0 allocations: 0 bytes)
35.801194754045646

julia> @btime my_sum_CI($cc)
  15.315 ns (0 allocations: 0 bytes)
35.801194754045646

julia> @btime $c .= 0;
  325.551 ns (0 allocations: 0 bytes)

julia> @btime $cc .= 0;
  11.812 ns (0 allocations: 0 bytes)
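
Since ModeNMatrix is not a strided matrix, mul! cannot dispatch to BLAS and falls back to LinearAlgebra's generic matrix-multiply loop, which goes through getindex and setindex! for every element, so the per-element decoding cost is paid O(n³) times. Roughly speaking (a sketch of mine, not the actual library code), the fallback behaves like this naive kernel:

# Naive triple-loop multiply; every access pays the div/rem decoding
# in ModeNMatrix's getindex.
function naive_mul!(C, A, B)
    fill!(C, zero(eltype(C)))
    for j in axes(B, 2), k in axes(A, 2), i in axes(A, 1)
        C[i, j] += A[i, k] * B[k, j]
    end
    return C
end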

As a sidenote:

julia> c[1, 1] = 0
ERROR: StackOverflowError:
Stacktrace:
      [1] setindex!(MM::ModeNMatrix{Float64, 2, 1, Matrix{Float64}}, value::Int64, i::Int64)
        @ Main .\REPL[20]:3
      [2] _setindex!
        @ .\abstractarray.jl:1436 [inlined]
      [3] setindex!(::ModeNMatrix{Float64, 2, 1, Matrix{Float64}}, ::Int64, ::Int64, ::Int64)
        @ Base .\abstractarray.jl:1413
--- the above 3 lines are repeated 39990 more times ---
 [119974] setindex!(MM::ModeNMatrix{Float64, 2, 1, Matrix{Float64}}, value::Int64, i::Int64)
        @ Main .\REPL[20]:3
 [119975] _setindex!
        @ .\abstractarray.jl:1436 [inlined]

c[1, 1] = 0. (i.e. with a Float64 value) does work fine.
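
The overflow comes from the value::T restriction on the two-index setindex!: an Int value does not match it, so Base's generic setindex! (abstractarray.jl:1413) takes over, and because the array advertises IndexLinear it converts to a single linear index and calls your one-integer method, which calls the two-index form with the same Int value, and the two keep bouncing off each other. A minimal fix (a sketch, untested) is to accept any value type and convert explicitly:

# Accept any value type and convert, so a mismatched value can no longer
# fall through to Base's generic method and recurse back here.
function Base.setindex!(MM::ModeNMatrix{T, N, M}, value, i::Int, j::Int) where {T, N, M}
    j_adj = j - 1
    nonmode_idx = ntuple(k -> (j_adj ÷ MM.col_strides[k]) % MM.dims[MM.nonmode[k]] + 1, M)
    full_idx = _insert_index(nonmode_idx, i, MM.n)
    @inbounds MM.X[full_idx...] = convert(T, value)
    return MM
end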

Version Info
julia> versioninfo()
Julia Version 1.11.2
Commit 5e9a32e7af (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Windows (x86_64-w64-mingw32)
  CPU: 8 × Intel(R) Core(TM) i7-7700K CPU @ 4.20GHz
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, skylake)
Threads: 8 default, 0 interactive, 4 GC (on 8 virtual cores)
Environment:
  JULIA_NUM_THREADS = auto

I expected the performance to be slower given the custom getindex, which requires some integer arithmetic per element, but I am still puzzled by the allocations. Being on an older Julia version might have something to do with it. I’ll upgrade and see what I can do, thank you!

Version Info
In [31]: versioninfo()
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × Intel(R) Core(TM) i5-6440HQ CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)