In-place matrix sampling

Well, it looks ugly, but it works :slight_smile:
The algorithm is as follows:

  1. We sort the indices.
  2. We run over the indices. If the current position is less than the new index (i.e. we should copy a column from the future), we do exactly that.
  3. If the current position is greater than the new index (i.e. we should copy from the past), we remember that position as the start of a backward run and go to the next index.
  4. When the current position is again less than or equal to the new index, or we run out of indices, we copy the columns of that run in backward order (from right to left), so that no source column is overwritten before it has been read.
"""
    f4(a, ids)

Non-mutating counterpart of `f4!`: run `f4!` on fresh copies of `a` and `ids` and
return the resulting copy of `a`. Neither argument is modified.
"""
function f4(a, ids)
    return f4!(copy(a), copy(ids))
end

"""
    f4!(a, ids)

Resample the columns of matrix `a` in place and return `a`.

`ids` is first sorted in place. Afterwards, column `j` of `a` holds what was
originally column `ids[j]`, for `j = 1:length(ids)`. The columns therefore come out
in sorted-index order, not in the original order of `ids`. Columns after
`length(ids)` are left unchanged. Both `a` and `ids` are assumed to use 1-based
indexing.

Columns are processed left to right:

- A column whose source lies to its right (or is itself) can be copied immediately,
  because that source has not been written yet.
- A run of consecutive columns whose sources lie to their left is deferred. It is
  then copied right to left, so that every source column is read before it is
  overwritten.

# Throws
- `DimensionMismatch` if `length(ids) > size(a, 2)`.
- `BoundsError` if any entry of `ids` is not a valid column index of `a`.
"""
function f4!(a, ids)
    sort!(ids)
    n = length(ids)
    # The copy loops run under `@inbounds`, so reject bad input up front instead of
    # silently reading or writing outside `a`.
    n <= size(a, 2) || throw(DimensionMismatch(
        "length(ids) = $n exceeds the number of columns of a ($(size(a, 2)))"))
    # `ids` is sorted, so checking the span between its extremes covers every entry.
    isempty(ids) || checkbounds(a, :, first(ids):last(ids))

    rev = false   # inside a deferred run of columns whose sources lie to the left?
    start = 0     # first column of the current deferred run (meaningful only if `rev`)
    for (i, id) in enumerate(ids)
        if i < id
            if rev
                # Close the run. Column `i` reads from the right and columns
                # start:i-1 read from the left, so one right-to-left pass is safe.
                rev = false
                _copy_columns_backward!(a, ids, start, i)
            else
                # Source column `id` lies to the right and is still untouched.
                @inbounds for k in axes(a, 1)
                    a[k, i] = a[k, id]
                end
            end
        elseif i > id
            if !rev
                rev = true
                start = i
            end
        elseif rev # i == id: column `i` already holds its own data; just close the run
            rev = false
            _copy_columns_backward!(a, ids, start, i)
        end
    end

    # A run still open at the end extends to the last index, i.e. `length(ids)`,
    # which may be smaller than `size(a, 2)`.
    if rev
        _copy_columns_backward!(a, ids, start, n)
    end

    return a
end

# Set column `m` of `a` to column `ids[m]` for `m = stop:-1:start` (right to left).
# Callers must have validated `ids` against `a`; bounds checks are skipped.
function _copy_columns_backward!(a, ids, start, stop)
    @inbounds for m in stop:-1:start
        src = ids[m]
        for k in axes(a, 1)
            a[k, m] = a[k, src]
        end
    end
    return a
end

And we can verify that it works:

# Reproducible check of `f4` against `f1`.
# NOTE(review): `f1` is not shown in this section; it is presumably a reference
# implementation equivalent to `a[:, ids]` — confirm. `StableRNG` and `sample` are
# assumed to come from StableRNGs.jl and StatsBase.jl.
rng = StableRNG(2020)
a = rand(rng, 100, 500);
ids = sample(rng, axes(a, 2), size(a, 2));

b1 = f1(a, sort(ids)); # `sort` here because f4 returns the columns in sorted-index order, not in the order of `ids`
b2 = f4(a, ids);
@assert b1 == b2

And the benchmark shows a 2x speedup!

julia> @btime f4!(x, y) setup=((x, y) = (copy(a), copy(ids))) evals = 1;
  21.120 μs (0 allocations: 0 bytes)

On a side note: I wanted to understand how this could work because this operation is common in particle filters (particularly in the resampling step). So this approach can speed them up, which is a good thing, I suppose.

1 Like