I am trying to write an efficient matricization function following Kolda and Bader's "Tensor Decompositions and Applications". Constructing the struct should not allocate, since it is essentially a lazy view of the array provided. Here is my implementation.
# Build the full tensor index from the indices of the other modes.
#
# `nonmode_idx` holds one index per non-mode dimension, in increasing dimension order.
# The result has one more entry: position `n` receives `i`, positions before `n` are
# copied as-is, and positions after `n` are taken from the entry one slot to the left.
@inline function _insert_index(nonmode_idx::NTuple{M,Int}, i::Int, n::Int) where {M}
    @inbounds begin
        full_idx = ntuple(M + 1) do pos
            if pos < n
                nonmode_idx[pos]
            elseif pos == n
                i
            else
                nonmode_idx[pos - 1]
            end
        end
    end
    return full_idx
end
# Product of the first `n` entries of `nonmode_dims`; the empty product (`n == 0`) is 1.
# The caller must guarantee `n <= M`, because the tuple reads are not bounds-checked.
@inline function _tuple_prod(nonmode_dims::NTuple{M,Int}, n::Int) where {M}
    acc = 1
    pos = 0
    @inbounds while pos < n
        pos += 1
        acc *= nonmode_dims[pos]
    end
    return acc
end
# --- Definition of ModeNMatrix ---
"""
    ModeNMatrix{T, N, M, A<:AbstractArray{T,N}} <: AbstractMatrix{T}

A lazy wrapper that presents an `N`-dimensional tensor `X` as its matricization
along one mode. The tensor is stored by reference in the field `X`; element access
is translated into an `N`-dimensional index into `X`.

The resulting matrix has size

    (dims[n], (∏dims)/dims[n])

where `dims` are the dimensions of the original tensor.

# Fields
- `X`: the wrapped tensor.
- `n`: the mode being unfolded. The `ModeNMatrix(X, n)` constructor stores this as a
  1-based dimension of `X` (its 0-based argument plus one).
- `dims`: `size(X)`.
- `nonmode`: the dimensions other than `n`, in increasing order.
- `col_strides`: for each entry of `nonmode`, the product of the sizes of the
  non-mode dimensions that precede it (1 for the first entry).

`nonmode` and `col_strides` are `NTuple`s of length `M`; the constructor uses
`M = N - 1`. Tuples are used instead of vectors so the wrapper carries no
heap-allocated index tables.
"""
struct ModeNMatrix{T, N, M, A<:AbstractArray{T,N}} <: AbstractMatrix{T}
    X::A # The original tensor.
    n::Int # The mode along which we matricize (1-based when built by the constructor).
    dims::NTuple{N,Int} # The dimensions of the original tensor.
    nonmode::NTuple{M,Int} # The modes other than n, in increasing order.
    col_strides::NTuple{M,Int} # Precomputed column strides for the non-n modes.
end
"""
    ModeNMatrix(X::AbstractArray{T,N}, n::Int)

Construct a lazy matricized view of the `N`-dimensional tensor `X` along mode `n`.

# Arguments
- `X`: the tensor to unfold; it is stored by reference, not copied.
- `n`: the **0-based** mode, so valid values are `0, 1, …, N-1`. It is stored in the
  struct as the 1-based dimension `n + 1`.

# Returns
A `ModeNMatrix{T, N, N-1, typeof(X)}` whose `nonmode` and `col_strides` fields are
precomputed as `NTuple`s.

# Throws
- `AssertionError` if `n` is outside `0:N-1`.
"""
function ModeNMatrix(X::AbstractArray{T,N}, n::Int) where {T, N}
    # The mode is 0-based, so N itself is NOT a valid mode: it would be stored as
    # dimension N + 1 and only fail later, when `size` is first queried.
    @assert 0 ≤ n < N "Mode n must be between 0 and N-1."
    # Bind the 1-based mode to a fresh local instead of reassigning `n`: a variable
    # that is reassigned and then captured by a closure is boxed, which makes the
    # closures below type-unstable and the constructor allocate.
    mode = n + 1
    dims = size(X)
    # Natural ordering of the remaining dimensions: (1, …, mode-1, mode+1, …, N).
    # `Val(N - 1)` fixes the tuple length at compile time.
    nonmode = ntuple(k -> k < mode ? k : k + 1, Val(N - 1))
    # Sizes of those dimensions, in the same order.
    nonmode_dims = ntuple(k -> dims[nonmode[k]], Val(N - 1))
    # Column strides: stride[k] = prod(nonmode_dims[1:k-1]); the empty product gives
    # stride[1] = 1.
    col_strides = ntuple(k -> _tuple_prod(nonmode_dims, k - 1), Val(N - 1))
    return ModeNMatrix{T, N, N - 1, typeof(X)}(X, mode, dims, nonmode, col_strides)
end
"""
    matricize(X::AbstractArray, n::Int)

Convenience wrapper: return `ModeNMatrix(X, n)`.
"""
function matricize(X::AbstractArray, n::Int)
    return ModeNMatrix(X, n)
end
# size: rows = dims[n], columns = product of the sizes of all other dimensions.
#
# The column count is accumulated directly over the non-mode dimensions rather than
# computed as `prod(dims) ÷ dims[n]`, which throws a DivideError when the unfolded
# mode has length 0. For non-empty tensors the result is the same.
function Base.size(MM::ModeNMatrix{T, N, M, A}) where {T, N, M, A}
    ncols = 1
    for k in 1:M
        ncols *= MM.dims[MM.nonmode[k]]
    end
    return (MM.dims[MM.n], ncols)
end
# axes: 1-based index ranges for the rows and the columns, derived from `size`.
function Base.axes(MM::ModeNMatrix)
    nrows, ncols = size(MM)
    return (Base.OneTo(nrows), Base.OneTo(ncols))
end
# Preferred index style. The native access of this wrapper is the two-index form
# `MM[i, j]`; the single-index methods first convert the linear index through
# `CartesianIndices`. Declaring `IndexCartesian` tells generic code (iteration,
# `eachindex`, `copyto!`, …) to use the two-index form directly instead of paying for
# that conversion on every element.
Base.IndexStyle(::Type{<:ModeNMatrix}) = IndexCartesian()
# Primary getindex: element (i, j) of the matricized view.
#
# `i` indexes the unfolded mode; `j` enumerates the remaining modes with the first
# non-mode dimension varying fastest. Throws a BoundsError when (i, j) is outside
# `axes(MM)`. Without that check, an out-of-range `j` can wrap around through the
# modulo below and silently return the wrong element.
Base.@propagate_inbounds function Base.getindex(
    MM::ModeNMatrix{T, N, M, A}, i::Int, j::Int
) where {T, N, M, A}
    @boundscheck checkbounds(MM, i, j)
    j0 = j - 1 # 0-based column offset for the stride arithmetic
    # Decompose the column offset into one 1-based index per non-mode dimension.
    nonmode_idx = ntuple(Val(M)) do k
        (j0 ÷ MM.col_strides[k]) % MM.dims[MM.nonmode[k]] + 1
    end
    # Insert the mode-n index into its proper place.
    full_idx = _insert_index(nonmode_idx, i, MM.n)
    return MM.X[full_idx...]
end
# Primary setindex!: store `value` at element (i, j) of the matricized view, i.e. at
# the corresponding position of the wrapped tensor. Returns `value`.
#
# `value` is deliberately left untyped. Restricting it to the element type `T` meant
# that an assignment such as `MM[i, j] = 1` on a Float64 tensor missed this method
# and bounced between the generic fallback and the linear-index method until the
# stack overflowed. Any conversion is left to the wrapped tensor's own `setindex!`.
#
# Throws a BoundsError when (i, j) is outside `axes(MM)`; the write is no longer
# forced `@inbounds` on unchecked indices.
Base.@propagate_inbounds function Base.setindex!(
    MM::ModeNMatrix{T, N, M, A}, value, i::Int, j::Int
) where {T, N, M, A}
    @boundscheck checkbounds(MM, i, j)
    j0 = j - 1 # 0-based column offset for the stride arithmetic
    # Decompose the column offset into one 1-based index per non-mode dimension.
    nonmode_idx = ntuple(Val(M)) do k
        (j0 ÷ MM.col_strides[k]) % MM.dims[MM.nonmode[k]] + 1
    end
    # Same index reconstruction as getindex, via the shared helper.
    full_idx = _insert_index(nonmode_idx, i, MM.n)
    MM.X[full_idx...] = value
    return value
end
# Linear-index assignment: translate the linear index `i` into (row, column) through
# `CartesianIndices(MM)` and delegate to the two-index assignment.
function Base.setindex!(MM::ModeNMatrix, value, i::Int)
    row, col = Tuple(CartesianIndices(MM)[i])
    return MM[row, col] = value
end
# getindex for a CartesianIndex{2}: unpack it and delegate to the two-integer method.
function Base.getindex(MM::ModeNMatrix, I::CartesianIndex{2})
    row, col = Tuple(I)
    return MM[row, col]
end
# Tuple indexing fallback: a 2-tuple is treated as (row, column); a tuple of any other
# length is wrapped in a CartesianIndex and indexed through that.
function Base.getindex(MM::ModeNMatrix, inds::Tuple{Vararg{Int}})
    length(inds) == 2 || return MM[CartesianIndex(inds)]
    row, col = inds
    return MM[row, col]
end
# Linear indexing (a single integer): convert `i` to a CartesianIndex over the matrix
# shape and index with that.
function Base.getindex(MM::ModeNMatrix, i::Int)
    cart = CartesianIndices(MM)[i]
    return MM[cart]
end
The following snippet, which uses this struct with the `mul!` function from LinearAlgebra.jl, shows that multiplying through my struct allocates and is considerably slower than multiplying regular matrices. Can you help me understand why that is so?
# Benchmark: in-place multiplication through the lazy wrappers vs. the plain matrices.
begin
# Three independent 5×5 random matrices (two operands and one destination).
aa = rand(5,5);
bb = rand(5,5);
cc = rand(5,5);
# Mode 0 is the first dimension (the constructor stores n + 1). For a 2-D input the
# wrapper is therefore also 5×5 and element (i, j) maps to element (i, j) of the
# wrapped matrix.
a = matricize(aa,0);
b = matricize(bb,0);
c = matricize(cc,0);
# NOTE(review): presumably these wrapper arguments reach LinearAlgebra's generic
# AbstractMatrix method rather than the BLAS path — confirm with `@which mul!(c, a, b)`.
@btime mul!($c,$a,$b);
# Baseline: the same product on the underlying matrices.
@btime mul!($cc,$aa,$bb);
end;
returns:
1.753 μs (10 allocations: 20.95 KiB)
178.536 ns (0 allocations: 0 bytes)