Amazing Jared.
In prune_matrix
you can leverage the fact that the SparseMatrixCSC you are using only contains ones.
Instead of this
sum(X_csc; dims=1)
You can look at the column pointers of the array and subtract consecutive values. This tells you the number of stored (nonzero) entries in each column, which in your case equals the sum of the elements in that column, because they are all 1s.
For example consider the following X_csc
X_csc = sparse([0 1 0; 1 0 1; 0 1 1; 1 1 0; 1 1 1; 1 1 0])
6×3 SparseMatrixCSC{Int64, Int64} with 12 stored entries:
⋅ 1 ⋅
1 ⋅ 1
⋅ 1 1
1 1 ⋅
1 1 1
1 1 ⋅
This has
X_csc.colptr
4-element Vector{Int64}:
1
5
10
13
Since your matrix contains only ones, the differences 5-1 = 4, 10-5 = 5 and 13-10 = 3 give you the column sums without looking at the values.
This is the implementation
"""
    get_supports_from_transaction_csc(X_csc::SparseMatrixCSC)

Return a `Vector{Int64}` with the number of stored entries in each column of `X_csc`.

The count for column `j` is `colptr[j + 1] - colptr[j]`, so the stored values are never
read. The result equals the column sums only when every stored value is `1`.
"""
function get_supports_from_transaction_csc(X_csc::SparseMatrixCSC)
    colptr = X_csc.colptr
    n_result = size(X_csc, 2)
    vec_result = Vector{Int64}(undef, n_result)
    # Index `colptr` directly; slicing it with `colptr[2:end]` would allocate a copy.
    for j in eachindex(vec_result)
        vec_result[j] = colptr[j + 1] - colptr[j]
    end
    return vec_result
end
Benchmark
# Create a 100000 x 100 sparse matrix with a density of 0.01 (about 1% of the entries are stored)
m = 100000
n = 100
density = 0.01
A_csc = sprand(m, n, density)
# Overwrite every stored value with 1 so the matrix contains only ones.
A_csc.nzval .= 1;
@btime get_supports_from_transaction_csc(A_csc)
132.106 ns (2 allocations: 1.75 KiB)
@btime sum(A_csc; dims=1)
18.208 μs (4 allocations: 1.02 KiB)
Option 1) Finding sorted items without storing all supports
You can use the idea above to change this
supports = sum(matrix, dims=1)
sorted_items = [i for i in axes(matrix,2) if supports[1,i] >= min_support]
into this
"""
    get_sorted_items(X_csc::SparseMatrixCSC, min_support)

Return `(items, supports)` for the columns of `X_csc` that have at least `min_support`
stored entries.

`items` holds the qualifying column indices in increasing column order (they are not
ordered by support), and `supports[k]` is the stored-entry count of column `items[k]`.
Counts come from differences of consecutive column pointers, so they equal the column
sums only when every stored value is `1`.
"""
function get_sorted_items(X_csc::SparseMatrixCSC, min_support)
    colptr = X_csc.colptr
    sorted_items = Vector{Int64}()
    supports = Vector{Int64}()
    # Index `colptr` directly; slicing it with `colptr[2:end]` would allocate a copy.
    for i in axes(X_csc, 2)
        current_support = colptr[i + 1] - colptr[i]
        if current_support >= min_support
            push!(sorted_items, i)
            push!(supports, current_support)
        end
    end
    return sorted_items, supports
end
Option 2) Finding sorted items storing all supports
"""
    get_sorted_items2(X_csc::SparseMatrixCSC, min_support)

Return `(items, supports)`, where `supports` is the full per-column vector produced by
`get_supports_from_transaction_csc(X_csc)` and `items` lists, in column order, the column
indices whose support is at least `min_support`.
"""
function get_sorted_items2(X_csc::SparseMatrixCSC, min_support)
    all_supports = get_supports_from_transaction_csc(X_csc)
    # `all_supports` has one entry per column, so its indices are the column indices.
    frequent_columns = findall(>=(min_support), all_supports)
    return frequent_columns, all_supports
end
Option 2 seems a bit faster to compute, but at the expense of storing the supports of all columns.
Benchmark
Current version
"""
    find_sorted_items_current(matrix::SparseMatrixCSC, min_support)

Baseline implementation: sum `matrix` along its first dimension and return
`(items, column_sums)`, where `column_sums` is the `1 × n` result of `sum(matrix; dims=1)`
and `items` lists, in column order, the columns whose sum is at least `min_support`.
"""
function find_sorted_items_current(matrix::SparseMatrixCSC, min_support)
    column_sums = sum(matrix; dims=1)
    kept_columns = Int[]
    for col in axes(matrix, 2)
        if column_sums[1, col] >= min_support
            push!(kept_columns, col)
        end
    end
    return kept_columns, column_sums
end
@btime find_sorted_items_current(A_csc, 4)
18.833 μs (10 allocations: 3.03 KiB)
Options proposed
@btime get_sorted_items(A_csc, 4)
950.818 ns (10 allocations: 4.75 KiB)
@btime get_sorted_items2(A_csc, 4)
645.078 ns (8 allocations: 3.77 KiB)
If you want, I can open a PR during the week, but feel free to make the change yourself if this makes sense to you.