I’m making a custom array (`SparseSynapses{Npre,Npost}`) that simulates multidimensional sparse tensors on top of `SparseMatrixCSC` (for a specific case). To transform multidimensional `CartesianIndex`es into the linear indices that index into the underlying `SparseMatrixCSC`, I use `CartesianIndex` and `LinearIndices`. However, implementing only the scalar `Base.getindex(S::SparseSynapses{Npre,Npost}, I::Vararg{Int,N}) where {Npre,Npost,N}` method incurs a large performance cost in this case.
(I’m not sure whether this is due to the scalar indexing into the `LinearIndices` or into the `SparseMatrixCSC`.)
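The Cartesian-to-linear step described above can be checked in isolation. This is a minimal sketch that assumes 2×3 presynaptic and 4-element postsynaptic dimensions, chosen only for illustration. It shows how one `CartesianIndex` per side becomes one (row, column) pair of the underlying `SparseMatrixCSC`:

```julia
using SparseArrays

# Hypothetical sizes, chosen only for illustration
preDims  = (2, 3)   # presynaptic tensor shape  -> 6 rows
postDims = (4,)     # postsynaptic tensor shape -> 4 columns

preLin  = LinearIndices(preDims)
postLin = LinearIndices(postDims)

# Column-major order: (i, j) -> i + (j - 1) * 2
row = preLin[CartesianIndex(2, 3)]
col = postLin[CartesianIndex(4)]

data = spzeros(Int8, prod(preDims), prod(postDims))
data[row, col] = 1   # store one synapse from pre (2, 3) to post (4,)
value = data[row, col]
```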
I accordingly extended the general `getindex` method:
`Base.getindex(S::SparseSynapses{Npre,Npost}, I...) where {Npre,Npost}`
I relied on `CartesianIndices` for this (see code below), and now it works for scalar and range indices, but not for `Array` indices: `CartesianIndices` can’t be constructed from `Array`s (code), and its iteration (code) doesn’t delegate to an underlying iterator.
Where it works, this approach is indeed much faster: `SparseMatrixCSC` indexing becomes the bottleneck instead of constructing the linear indices.
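A small sketch of the `CartesianIndices` constructor behaviour discussed above, assuming Julia 1.x. Ranges are accepted and a `Vector` is rejected. A bare `Integer` entry is read as a dimension *size*, which is worth checking when scalar and range indices are mixed:

```julia
# Ranges are accepted: 2 × 2 = 4 Cartesian indices
ranges_ok = CartesianIndices((1:2, 2:3))

# A bare Integer is read as a size, so 5 means 1:5 (not 5:5)
int_as_size = CartesianIndices((5, 2:3))

# Wrapping the scalar as a one-element range gives the intended slice
int_wrapped = CartesianIndices((5:5, 2:3))

# A Vector entry has no matching constructor method
vector_rejected = try
    CartesianIndices((1:2, [1, 3]))
    false
catch err
    err isa MethodError
end
```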
I think my best option now is to write a custom type that imitates `CartesianIndices` for general iterables and delegates iteration to the wrapped iterables. Can you suggest a solution that reuses some of the `CartesianIndices` machinery?
I guess my use case is outside the scope of the design of `CartesianIndices`…
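One way to prototype such a type with existing Base tools is sketched below. It assumes that lazily generated `CartesianIndex` values are enough. `Iterators.product` already iterates over arbitrary per-dimension iterables with the first dimension varying fastest, as `CartesianIndices` does. `general_cartesian` is a hypothetical name, not a Base function:

```julia
# Hypothetical helper (not part of Base): lazily yields one CartesianIndex per
# combination of per-dimension indices, which may be Ints, ranges or Vectors.
# Iterators.product varies its first argument fastest, like CartesianIndices.
general_cartesian(idx::Tuple) =
    (CartesianIndex(t) for t in Iterators.product(map(i -> i isa Integer ? (i,) : i, idx)...))

lin = LinearIndices((2, 3))

# Vector index in the second dimension, which CartesianIndices((1:2, [1, 3])) rejects
linear_vec = vec([lin[c] for c in general_cartesian((1:2, [1, 3]))])

# Mixed scalar/range index
linear_mixed = vec([lin[c] for c in general_cartesian((2, 1:3))])
```

The result is a generator, not an `AbstractArray`. It therefore covers the iteration side of `CartesianIndices` but not indexing into the index set itself.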
Appendix: code
My custom array:
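```julia
using SparseArrays

# AbstractSynapses{Npre,Npost} is defined elsewhere (not shown)
struct SparseSynapses{Npre,Npost} <: AbstractSynapses{Npre,Npost}
    data::SparseMatrixCSC{Int8,Int}   # rows: presynaptic linear index, columns: postsynaptic
    preDims::NTuple{Npre,Int}
    postDims::NTuple{Npost,Int}
    # Concrete LinearIndices types: `LinearIndices{Npre}` alone is an abstract
    # field type, which makes every access to these fields type-unstable
    preLinIdx::LinearIndices{Npre,NTuple{Npre,Base.OneTo{Int}}}
    postLinIdx::LinearIndices{Npost,NTuple{Npost,Base.OneTo{Int}}}
    function SparseSynapses{Npre,Npost}(data, preDims, postDims) where {Npre,Npost}
        preLinIdx = LinearIndices(preDims)    # pre Cartesian index -> row
        postLinIdx = LinearIndices(postDims)  # post Cartesian index -> column
        new{Npre,Npost}(data, preDims, postDims, preLinIdx, postLinIdx)
    end
end
```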
Extending the general getindex:
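```julia
# All-scalar tuple -> single CartesianIndex; otherwise -> CartesianIndices.
# Scalars are wrapped as i:i, because a bare Integer inside CartesianIndices
# is read as a size (i.e. 1:i).
cartesianIdx(idx::NTuple{N,Int}) where {N} = CartesianIndex(idx)
cartesianIdx(idx::Tuple) = CartesianIndices(map(i -> i isa Integer ? (i:i) : i, idx))

# Cartesian -> linear index into the underlying SparseMatrixCSC
linearIdx(cidx::CartesianIndex, linTransform)::Int = linTransform[cidx]
linearIdx(cidx::CartesianIndices, linTransform)::Vector{Int} = vec(linTransform[cidx])

# The macro must be on the same line as `function` (Julia has no `\` line continuation)
Base.@propagate_inbounds function Base.getindex(S::SparseSynapses{Npre,Npost}, I...) where {Npre,Npost}
    idx = to_indices(S, I)   # e.g. Colon -> Base.Slice
    S.data[linearIdx(cartesianIdx(idx[1:Npre]), S.preLinIdx),
           linearIdx(cartesianIdx(idx[Npre+1:Npre+Npost]), S.postLinIdx)]
end
```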
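For comparison, here is a self-contained sketch of a variant that skips `CartesianIndices`. It indexes the stored `LinearIndices` directly with the converted indices, because generic `AbstractArray` indexing already accepts `Int`, range and `Vector` indices. Several details are assumptions: `AbstractSynapses` is a stand-in, the struct keeps only the fields `getindex` needs, and `flatLin` is a hypothetical helper name:

```julia
using SparseArrays

# Stand-in for the post's AbstractSynapses, whose definition is not shown (assumption)
abstract type AbstractSynapses{Npre,Npost} end

# Trimmed variant: only the fields getindex needs, with concrete LinearIndices types
struct SparseSynapses{Npre,Npost} <: AbstractSynapses{Npre,Npost}
    data::SparseMatrixCSC{Int8,Int}
    preLinIdx::LinearIndices{Npre,NTuple{Npre,Base.OneTo{Int}}}
    postLinIdx::LinearIndices{Npost,NTuple{Npost,Base.OneTo{Int}}}
end

SparseSynapses(data, preDims::NTuple{Npre,Int}, postDims::NTuple{Npost,Int}) where {Npre,Npost} =
    SparseSynapses{Npre,Npost}(data, LinearIndices(preDims), LinearIndices(postDims))

# Lets to_indices expand `:` into the full axis of each dimension
Base.axes(S::SparseSynapses) = (axes(S.preLinIdx)..., axes(S.postLinIdx)...)

# Hypothetical helper: scalar lookups give an Int, non-scalar ones an array to flatten
flatLin(x::Integer) = x
flatLin(x::AbstractArray) = vec(x)

Base.@propagate_inbounds function Base.getindex(S::SparseSynapses{Npre,Npost}, I...) where {Npre,Npost}
    idx = to_indices(S, axes(S), I)          # Colon -> Slice; Ints, ranges, Vectors pass through
    pre = ntuple(k -> idx[k], Val(Npre))     # type-stable split of the index tuple
    post = ntuple(k -> idx[Npre + k], Val(Npost))
    # Generic AbstractArray indexing of LinearIndices accepts Ints, ranges and Vectors alike
    S.data[flatLin(S.preLinIdx[pre...]), flatLin(S.postLinIdx[post...])]
end

# Example: 2×3 presynaptic by 4 postsynaptic; entry (r, c) stores r + 10c
S = SparseSynapses(sparse(Int8[r + 10c for r in 1:6, c in 1:4]), (2, 3), (4,))
```

This relies on the generic indexing code of `LinearIndices` rather than on `CartesianIndices`. Whether it matches the speed of the `CartesianIndices` path would need benchmarking.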