# How to extract sub-matrices (or their indexes) of Block Diagonal Matrix?

Is there a function in Julia that returns (extracts) the sub-matrices (or their indexes) of a block-diagonal matrix, similar to the question "Extract matrices from a Block Diagonal Matrix"?
For example, I need to return `A[1:3,1:3]`, `A[4:5,4:5]`, and `A[6:end,6:end]`, or only their indexes, e.g. (3,3) and (5,5).

```julia
julia> using SparseArrays

julia> A = sparse([5.0 1 2 0 0 0 0; 2 1 5 0 0 0 0; 5 6 7 0 0 0 0; 0 0 0 2 -1 0 0; 0 0 0 5 6 0 0; 0 0 0 0 0 10 11; 0 0 0 0 0 2 3])
7×7 SparseMatrixCSC{Float64, Int64} with 17 stored entries:
 5.0  1.0  2.0   ⋅     ⋅     ⋅     ⋅
 2.0  1.0  5.0   ⋅     ⋅     ⋅     ⋅
 5.0  6.0  7.0   ⋅     ⋅     ⋅     ⋅
  ⋅    ⋅    ⋅   2.0  -1.0    ⋅     ⋅
  ⋅    ⋅    ⋅   5.0   6.0    ⋅     ⋅
  ⋅    ⋅    ⋅    ⋅     ⋅   10.0  11.0
  ⋅    ⋅    ⋅    ⋅     ⋅    2.0   3.0
```

I don’t know of any built-in function to extract the locations of the diagonal blocks, but you could always write your own, e.g. here is something simple (that assumes `A` is block-diagonal):

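```julia
using SparseArrays, LinearAlgebra

function diagblocks(A::SparseArrays.AbstractSparseMatrixCSC)
    m = LinearAlgebra.checksquare(A)
    rows = rowvals(A)  # row indices are sorted within each column
    # minrow[j] = smallest stored row index in columns j:m (m + 1 if none)
    minrow = fill(m + 1, m + 1)
    for j = m:-1:1
        r = nzrange(A, j)
        minrow[j] = isempty(r) ? minrow[j+1] : min(minrow[j+1], rows[first(r)])
    end
    blocks = Matrix{eltype(A)}[]
    prevj = 1   # index of the current block's first column
    maxrow = 0  # largest stored row index in columns 1:j-1
    for j = 1:m
        # new block starting: nothing left of column j reaches row j or below,
        # and nothing in column j or to its right lies above row j
        if j > prevj && maxrow < j && minrow[j] >= j
            push!(blocks, Matrix(@view A[prevj:j-1, prevj:j-1]))
            prevj = j
        end
        r = nzrange(A, j)
        isempty(r) || (maxrow = max(maxrow, rows[last(r)]))  # empty columns are skipped
    end
    push!(blocks, Matrix(@view A[prevj:m, prevj:m])) # last block
    return blocks
end
```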

which gives:

```julia
julia> display.(diagblocks(A));
3×3 Matrix{Float64}:
 5.0  1.0  2.0
 2.0  1.0  5.0
 5.0  6.0  7.0
2×2 Matrix{Float64}:
 2.0  -1.0
 5.0   6.0
2×2 Matrix{Float64}:
 10.0  11.0
  2.0   3.0
```
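Since the question also asks for the indexes, here is a minimal sketch of a variant that returns the index range of each diagonal block instead of copying the blocks. `diagblockranges` is a hypothetical name (it is not part of SparseArrays); it assumes that every stored entry, including explicitly stored zeros, couples its row and column, and it finds the finest split into contiguous diagonal blocks.

```julia
using SparseArrays, LinearAlgebra

# Hypothetical helper (not part of SparseArrays): return the index range of each
# diagonal block, using the finest split into contiguous diagonal blocks.
function diagblockranges(A::SparseArrays.AbstractSparseMatrixCSC)
    m = LinearAlgebra.checksquare(A)
    I, J, _ = findnz(A)  # every stored entry, including explicit zeros
    # reach[k] = largest index tied to k by a stored entry A[i,j] with min(i,j) == k
    reach = collect(1:m)
    for (i, j) in zip(I, J)
        lo, hi = minmax(i, j)
        reach[lo] = max(reach[lo], hi)
    end
    ranges = UnitRange{Int}[]
    start, cur = 1, 0
    for k in 1:m
        cur = max(cur, reach[k])
        if cur == k  # no stored entry couples 1:k with k+1:m, so a block ends at k
            push!(ranges, start:k)
            start = k + 1
        end
    end
    return ranges
end

A = sparse([5.0 1 2 0 0 0 0; 2 1 5 0 0 0 0; 5 6 7 0 0 0 0; 0 0 0 2 -1 0 0;
            0 0 0 5 6 0 0; 0 0 0 0 0 10 11; 0 0 0 0 0 2 3])
ranges = diagblockranges(A)                      # [1:3, 4:5, 6:7]
ends = last.(ranges)                             # block corners (3,3), (5,5), (7,7)
subs = [Matrix(view(A, r, r)) for r in ranges]   # the sub-matrices themselves
```

A cut after index `k` is valid exactly when no stored entry has one index `≤ k` and the other `> k`; the running maximum `cur` of `reach` equals `k` precisely in that case, so a single pass over the entries plus one pass over `1:m` is enough.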


If you know the sizes of the blocks, you may use existing packages to simplify your code:

```julia
julia> using BlockArrays, LinearAlgebra

julia> B = BlockMatrix(A, [3,2,2], [3,2,2]);

julia> diag(blocks(B))
3-element Vector{SparseMatrixCSC{Float64, Int64}}:
 sparse([1, 2, 3, 1, 2, 3, 1, 2, 3], [1, 1, 1, 2, 2, 2, 3, 3, 3], [5.0, 2.0, 5.0, 1.0, 1.0, 6.0, 2.0, 5.0, 7.0], 3, 3)
 sparse([1, 2, 1, 2], [1, 1, 2, 2], [2.0, 5.0, -1.0, 6.0], 2, 2)
 sparse([1, 2, 1, 2], [1, 1, 2, 2], [10.0, 2.0, 11.0, 3.0], 2, 2)
```
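If you would rather not add a dependency, the same slicing can be done with plain indexing once the block sizes are known. A minimal sketch; `blockranges` is a hypothetical helper name, not a function from SparseArrays or BlockArrays:

```julia
using SparseArrays

# Hypothetical helper: turn known block sizes into index ranges,
# e.g. [3, 2, 2] -> [1:3, 4:5, 6:7]
function blockranges(sizes::AbstractVector{<:Integer})
    stops = cumsum(sizes)  # last index of each block
    return [(s - n + 1):s for (s, n) in zip(stops, sizes)]
end

A = sparse([5.0 1 2 0 0 0 0; 2 1 5 0 0 0 0; 5 6 7 0 0 0 0; 0 0 0 2 -1 0 0;
            0 0 0 5 6 0 0; 0 0 0 0 0 10 11; 0 0 0 0 0 2 3])
rs = blockranges([3, 2, 2])
subs = [A[r, r] for r in rs]  # each block is again a SparseMatrixCSC
```

Entries that fall outside the diagonal blocks are simply ignored here, so this relies on the block sizes being correct.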

Got it, thank you very much