Inplace matrix sampling

I wonder whether there is an algorithm, an established package, or some other way to sample a matrix in place. Here is an example of what I am trying to achieve:

using StableRNGs
using StatsBase
using BenchmarkTools

rng = StableRNG(2020)
a = rand(rng, 100, 500);
a1 = copy(a)
ids = sample(rng, axes(a, 2), size(a, 2));

It is easy to build a new sampled matrix in the following manner

f1(a, ids) = a[:, ids]

but it allocates

julia> @btime f1($a, $ids);
  39.015 μs (2 allocations: 390.70 KiB)

I can use @views to get a view of the sampled matrix, but later I need to modify this matrix with each column treated independently. In a view, a column that was sampled more than once refers to the same memory every time, so @views is not quite what I need.
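
To make the problem with views concrete, here is a minimal sketch on a toy matrix (the values and the index vector are made up for illustration). It shows that writing to one sampled column of a view also changes its duplicate and the parent matrix, while an allocating copy keeps the columns independent.

```julia
a = [1.0 2.0 3.0;
     4.0 5.0 6.0]
ids = [1, 1, 3]            # column 1 is sampled twice

v = @view a[:, ids]        # no copy: v[:, 1] and v[:, 2] both refer to a[:, 1]
v[1, 1] = 100.0            # modify the first sampled column

@assert v[1, 2] == 100.0   # the duplicate changed as well
@assert a[1, 1] == 100.0   # and so did the parent matrix

c = a[:, ids]              # allocating copy: columns are independent
c[1, 1] = -1.0
@assert c[1, 2] == 100.0   # the duplicate is untouched
```
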
On the other hand, simple assignment is too slow, and it allocates about as much as f1

function f2!(a, ids)
    b = @views a[:, ids]
    a .= b
end

julia> @btime f2!(x, $ids) setup=(x = copy(a)) evals = 1;
  196.959 μs (3 allocations: 394.77 KiB)

and a naive handwritten loop does not work, of course, because it overwrites columns that are still needed as sources later

function f3!(a, ids)
    b = @views a[:, ids]
    for j in axes(b, 2)
        for i in axes(b, 1)
            a[i, j] = b[i, j]
        end
    end
end

julia> b = f1(a, ids);
julia> f3!(a, ids);
julia> @assert a == b
ERROR: AssertionError: a == b
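
Why the naive loop fails is easiest to see on a tiny case. The sketch below uses a made-up one-row matrix and index vector and traces the loop by hand: column 1 is overwritten in the first step, and the later steps then read the new value instead of the original one.

```julia
a = [1 2 3]            # one row, three columns
ids = [3, 1, 1]        # wanted result: [3 1 1]
expected = a[:, ids]   # allocating reference result

for j in axes(a, 2)
    a[1, j] = a[1, ids[j]]
end
# j = 1: a[1, 1] = a[1, 3] = 3, so a is now [3 2 3]
# j = 2: a[1, 2] = a[1, 1] = 3, but the original value 1 is already gone
# j = 3: a[1, 3] = a[1, 1] = 3

@assert expected == [3 1 1]
@assert a == [3 3 3]   # not what we wanted
```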

It seems that it can be solved, but I can’t wrap my head around it.

Well, it looks ugly, but it works :)
The algorithm is the following

  1. We sort the indices.
  2. We walk over the sorted indices. If the current position is less than the index (i.e. we have to copy a column that lies further to the right), we copy it immediately.
  3. If the current position is greater than the index (we have to copy from a column to the left), we remember the position where this run starts and go to the next index.
  4. When the current position is again less than or equal to the index, or we hit the end of the matrix, we copy the remembered columns in backward order (from right to left).

function f4(a, ids)
    a1 = copy(a)
    ids1 = copy(ids)
    f4!(a1, ids1)
end

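function f4!(a, ids)
    sort!(ids)       # note: this also reorders the caller's ids
    rev = false      # true while we are inside a run of positions with i > ids[i]
    start = 0        # first position of the current run
    @inbounds for (i, id) in enumerate(ids)
        if i < id
            if !rev
                # source column lies to the right and has not been touched yet
                for k in axes(a, 1)
                    a[k, i] = a[k, id]
                end
            else
                # the run ends here: flush it from right to left (position i included)
                rev = false
                for m in i:-1:start
                    for k in axes(a, 1)
                        a[k, m] = a[k, ids[m]]
                    end
                end
            end
        elseif i > id
            # source column lies to the left: postpone the copy
            if !rev
                rev = true
                start = i
            end
        else # i == id
            if rev
                # column i stays as it is, but the pending run has to be flushed
                rev = false
                for m in i:-1:start
                    for k in axes(a, 1)
                        a[k, m] = a[k, ids[m]]
                    end
                end
            end
        end
    end

    # flush a run that reaches the last column
    @inbounds if rev
        for m in size(a, 2):-1:start
            for k in axes(a, 1)
                a[k, m] = a[k, ids[m]]
            end
        end
    end

    return a
end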

and we can verify that it works

rng = StableRNG(2020)
a = rand(rng, 100, 500);
ids = sample(rng, axes(a, 2), size(a, 2));

b1 = f1(a, sort(ids)); # note the `sort` here: f4 returns the columns in sorted order, not in the order of ids
b2 = f4(a, ids);
@assert b1 == b2

And the benchmark shows an almost 2x speedup over f1 (21.1 μs vs. 39.0 μs), with zero allocations!

julia> @btime f4!(x, y) setup=((x, y) = (copy(a), copy(ids))) evals = 1;
  21.120 μs (0 allocations: 0 bytes)
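
For what it's worth, I think the same idea can also be written as two plain passes over sorted indices: first left to right for the columns whose source lies to the right, then right to left for the columns whose source lies to the left. This is only a sketch of a variant, not the f4! above; the name resample_sorted! is made up here, it assumes 1-based columns and already sorted ids, and it uses the Random standard library instead of StableRNGs to stay self-contained. I have not benchmarked it.

```julia
using Random

# Hypothetical two-pass variant; expects `ids` sorted with one index per column.
function resample_sorted!(a::AbstractMatrix, ids::AbstractVector{<:Integer})
    length(ids) == size(a, 2) || throw(ArgumentError("need one index per column"))
    issorted(ids) || throw(ArgumentError("ids must be sorted"))
    # Pass 1, left to right: the source column lies to the right of column i,
    # so it has not been overwritten yet.
    for i in axes(a, 2)
        id = ids[i]
        if id > i
            for k in axes(a, 1)
                a[k, i] = a[k, id]
            end
        end
    end
    # Pass 2, right to left: the source column lies to the left of column m.
    # Because ids is sorted, pass 1 never wrote to it, and this pass has not
    # reached it yet.
    for m in reverse(axes(a, 2))
        id = ids[m]
        if id < m
            for k in axes(a, 1)
                a[k, m] = a[k, id]
            end
        end
    end
    return a
end

rng = MersenneTwister(1)
a = rand(rng, 4, 10)
ids = sort(rand(rng, 1:10, 10))
expected = a[:, ids]          # allocating reference result
resample_sorted!(a, ids)
@assert a == expected
```

The early argument checks are there because an unsorted ids would silently give a wrong result.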

On a side note: I wanted to understand how this could work because the operation is common in particle filters (particularly in the resampling step). So this approach can speed them up, which is a good thing, I suppose.
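
As far as I know, resampling schemes such as systematic resampling already produce the ancestor indices in nondecreasing order, so for them the sort! step would have nothing left to do. Below is a rough sketch of that, not taken from any package: systematic_indices is a made-up helper name, the weights are random placeholders, and it uses the Random standard library instead of StableRNGs.

```julia
using Random

# Hypothetical helper: systematic resampling over (unnormalized) weights `w`.
# Returns one ancestor index per particle, in nondecreasing order.
function systematic_indices(rng::AbstractRNG, w::AbstractVector{<:Real})
    n = length(w)
    total = sum(w)
    ids = Vector{Int}(undef, n)
    u = rand(rng) / n        # single random offset in [0, 1/n)
    c = w[1] / total         # running cumulative normalized weight
    j = 1
    for i in 1:n
        target = u + (i - 1) / n
        # advance until the cumulative weight covers the target;
        # `j < n` guards against floating-point round-off at the end
        while c < target && j < n
            j += 1
            c += w[j] / total
        end
        ids[i] = j
    end
    return ids
end

rng = MersenneTwister(2020)
w = rand(rng, 500)
ids = systematic_indices(rng, w)
@assert issorted(ids)        # already in the order the in-place copy needs
```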
