Help writing a `collect` for iterators

Here is a task that I am having a hard time accomplishing with Julia Base:

Given an iterator `iter` and a set of linear indices `inds`, write a function `collectat(iter, inds)` that collects the iterator at the indices without materializing the full list of items.

I asked this question on Zulip some time ago and @Mason helped with a solution that depends on Transducers.jl:

using Transducers: Filter, Map, TakeWhile, tcollect, ⨟

function collectat(iter, inds)
  if isempty(inds)
    eltype(iter)[]
  else
    selectat(inds) = enumerate ⨟ TakeWhile(x -> first(x) ≤ last(inds)) ⨟ Filter(y -> first(y) ∈ inds) ⨟ Map(last)
    iter |> selectat(inds) |> tcollect
  end
end

Here is an example of usage:

julia> iter = (rand() for _ in 1:200_000_000)
Base.Generator{UnitRange{Int64}, var"#8#9"}(var"#8#9"(), 1:200000000)

julia> inds = [1,1_000_000,1_999_999]
3-element Vector{Int64}:
       1
 1000000
 1999999

julia> @allocated collectat(iter, inds)
4608

Is it possible to write something that works in most cases without Transducers.jl?

Is the Transducers solution different from this one using just Iterators?

using DataPipes  # provides the @p macro

collectat_itr(itr, idxs) = @p let
    enumerate(itr)
    Iterators.takewhile(first(_) ≤ last(idxs))
    Iterators.filter(first(_) ∈ idxs)
    map(last)
end

(copied from that Zulip thread)

@aplavin, how would you write that function without DataPipes.jl?

Like this?

function collectat_itr2(itr, idxs)
    e = enumerate(itr)
    it = Iterators.takewhile(x->(first(x) ≤ last(idxs)), e)
    f = Iterators.filter(x->(first(x) ∈ idxs), it)
    return map(last, f)
end

Yeah, exactly – the translation back and forth is really straightforward (:
Note that some parens in your example aren’t necessary, e.g., x->(first(x) ≤ last(idxs)) could be just x->first(x) ≤ last(idxs).

That is perfect! Thank you all!

I think the solution returns a different number of elements depending on the order of the indices: takewhile stops at last(idxs), which is not the largest index when the indices are unsorted. If that’s not what you want, I guess you’d do something like collectat_itr3? It returns the correct number of elements but ignores the order of the indices (preserving the order of elements in the iterator).

Edit: I guess when the OP said linear indices, they meant ordered indices, in which case collectat_itr2 makes sense. I’ll still leave my comment here to draw attention to it.

struct I
    n::Int
    I(n::Int) = n >= 0 ? new(n) : error("Expected integer >=0. Got: $n")
end
# Counts down from n to 1; extend Base.iterate, the function the iteration protocol calls
Base.iterate(i::I, state::Int = i.n) = state == 0 ? nothing : (state, state - 1)


function collectat_itr2(itr, idxs)
    e = enumerate(itr)
    it = Iterators.takewhile(x->(first(x) ≤ last(idxs)), e)
    f = Iterators.filter(x->(first(x) ∈ idxs), it)
    return map(last, f)
end

function collectat_itr3(itr, idxs)
    e = enumerate(itr)
    f = Iterators.filter(x->(first(x) ∈ idxs), e)
    return map(last, f)
end

@assert collectat_itr2(I(9), [1, 2]) == [9, 8]
@assert collectat_itr2(I(9), [2, 1]) == [9]

@assert collectat_itr3(I(9), [1, 2]) == [9, 8]
@assert collectat_itr3(I(9), [2, 1]) == [9, 8]

Good point, for unordered indices, I would probably do something like this:

function collectat_itr4(itr, idxs)
    idxs = Set(idxs) # Want to check unordered membership
    itr |> it ->
      Iterators.take(it, maximum(idxs)) |>  # just take what we might need
      enumerate |> it ->
      Iterators.filter(x->(first(x) ∈ idxs), it) |> it ->
      map(last, it)
end

Using take or takewhile has another nice property: the function also works on infinite iterators, because it only consumes the iterator up to the largest requested index.
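
To illustrate the point about infinite iterators, here is a minimal sketch. It restates collectat_itr4 in a flatter form so the snippet runs on its own; the generator name squares is only an illustration and not part of the original answer.

```julia
# Same pipeline as collectat_itr4 above, written without the nested pipes.
function collectat_itr4(itr, idxs)
    idxs = Set(idxs)                                  # unordered membership test
    taken = Iterators.take(itr, maximum(idxs))        # stop at the largest index
    hits = Iterators.filter(x -> first(x) ∈ idxs, enumerate(taken))
    return map(last, hits)
end

# An infinite generator (illustrative): 1, 4, 9, 16, 25, ...
squares = (n^2 for n in Iterators.countfrom(1))

# Only the first 5 elements are ever produced.
collectat_itr4(squares, [5, 2, 3])  # == [4, 9, 25], in iterator order
```

Without the take step, the filter would keep scanning the infinite generator and never return.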

Thank you for raising the issue @sadish-d. By linear indices I meant LinearIndices. I will update the answer to make sure that all elements are returned.

Final solution:

function collectat(iter, inds)
  if isempty(inds)
    eltype(iter)[]
  else
    m = maximum(inds)
    e = Iterators.enumerate(iter)
    w = Iterators.takewhile(x -> (first(x) ≤ m), e)
    f = Iterators.filter(x -> (first(x) ∈ inds), w)
    map(last, f)
  end
end

This still preserves the order of iter and ignores the order of the indices in inds. It does not iterate over inds one by one; it iterates over iter.
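
A small sketch of that behavior, with the final solution repeated in condensed form so the snippet is self-contained (the range letters is just an illustrative input):

```julia
# Condensed copy of the final solution from the thread.
function collectat(iter, inds)
    isempty(inds) && return eltype(iter)[]
    m = maximum(inds)
    w = Iterators.takewhile(x -> first(x) ≤ m, enumerate(iter))
    f = Iterators.filter(x -> first(x) ∈ inds, w)
    return map(last, f)
end

letters = 'a':'e'

collectat(letters, [1, 3])     # == ['a', 'c']
collectat(letters, [3, 1])     # == ['a', 'c'], same result: iterator order wins
collectat(letters, [3, 3, 1])  # == ['a', 'c'], a repeated index yields one element
```

So the result always follows the order of iter, and its length can be smaller than length(inds) when indices repeat.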

You are correct, @sadish-d. I removed the word “ordered” in my previous comment before the solution to avoid confusion.

In case the order of the indices is important (and with a slight improvement for the non-ordered case):

function collectat2(iter, inds)
    # Assumes the indices are distinct and within the length of iter.
    isempty(inds) && return eltype(iter)[]
    wassorted = issorted(inds)
    if !wassorted
        # perm[j] is the position in inds of the j-th smallest index
        perm = sortperm(inds)
    end
    L, T, j = length(inds), eltype(iter), 1
    M = wassorted ? last(inds) : inds[last(perm)]
    res = Vector{T}(undef, L)
    for (i,v) in enumerate(Iterators.take(iter, M))
        if i == inds[wassorted ? j : perm[j]]
            # store the value in the slot matching its position in inds
            res[wassorted ? j : perm[j]] = v
            j == L && break
            j += 1
        end
    end
    return res
end

and some timings:

julia> @btime collectat(1:10,[2,3,4,5]);
  104.165 ns (4 allocations: 304 bytes)

julia> @btime collectat2(1:10,[2,3,4,5]);
  34.675 ns (2 allocations: 192 bytes)
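
If the result should follow the order of inds and repeated indices should be allowed, one possible alternative is to record the hits in a Dict and then read them back in the order of inds. This is only a sketch, not a tuned implementation: the name collectat_ordered is hypothetical, and it assumes every index is within the length of iter (otherwise the final lookup throws a KeyError).

```julia
# Hypothetical helper: returns the elements in the order given by inds.
function collectat_ordered(iter, inds)
    isempty(inds) && return eltype(iter)[]
    wanted = Set(inds)
    found = Dict{eltype(inds), eltype(iter)}()
    # Walk the iterator once, only as far as the largest requested index.
    for (i, v) in enumerate(Iterators.take(iter, maximum(inds)))
        i ∈ wanted && (found[i] = v)
    end
    # Read the hits back in the order (and multiplicity) of inds.
    return [found[i] for i in inds]
end

collectat_ordered(10:10:50, [4, 1, 4])  # == [40, 10, 40]
```

The trade-off is an extra Set and Dict allocation, so this is likely slower than collectat2 for small inputs.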