# Help writing a `collect` for iterators

Here is a task that I am having a hard time accomplishing with Julia Base:

Given an iterator `iter` and a set of linear indices `inds`, write a function `collectat(iter, inds)` that collects the iterator at the indices without materializing the full list of items.
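To pin down what the result should be, here is a naive reference (the name `collectat_reference` is hypothetical) that *does* materialize the whole iterator. It assumes the result should behave like indexing a `Vector`, i.e. follow the order of `inds` with repeats included. It is only meant for checking faster versions on small, finite inputs.

```julia
# Naive reference: materializes the whole iterator, then indexes it.
# Only suitable for small, finite iterators (e.g. to test faster versions).
collectat_reference(iter, inds) = collect(iter)[inds]

squares = (i^2 for i in 1:10)
collectat_reference(squares, [2, 5])     # [4, 25]
collectat_reference(squares, [5, 2, 5])  # [25, 4, 25]
```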

I asked this question on Zulip some time ago and @Mason helped with a solution that depends on Transducers.jl:

```julia
using Transducers: Filter, Map, TakeWhile, tcollect, ⨟

function collectat(iter, inds)
    if isempty(inds)
        eltype(iter)[]
    else
        # enumerate, stop after the last index, keep requested positions, drop the index
        selectat(inds) = enumerate ⨟ TakeWhile(x -> first(x) ≤ last(inds)) ⨟ Filter(y -> first(y) ∈ inds) ⨟ Map(last)
        iter |> selectat(inds) |> tcollect
    end
end
```

Here is an example of usage:

```julia
julia> iter = (rand() for _ in 1:200_000_000)
Base.Generator{UnitRange{Int64}, var"#8#9"}(var"#8#9"(), 1:200000000)

julia> inds = [1,1_000_000,1_999_999]
3-element Vector{Int64}:
       1
 1000000
 1999999

julia> @allocated collectat(iter, inds)
4608
```

Is it possible to write something that works in most cases without Transducers.jl?

Is the Transducers solution different from this one using just `Iterators`?

```julia
using DataPipes  # provides the @p macro

collectat_itr(itr, idxs) = @p let
    enumerate(itr)
    Iterators.takewhile(first(_) ≤ last(idxs))
    Iterators.filter(first(_) ∈ idxs)
    map(last)
end
```


@aplavin, how would you write that function without DataPipes.jl?

Like this?

```julia
function collectat_itr2(itr, idxs)
    e = enumerate(itr)
    it = Iterators.takewhile(x->(first(x) ≤ last(idxs)), e)
    f = Iterators.filter(x->(first(x) ∈ idxs), it)
    return map(last, f)
end
```

Yeah, exactly: the translation back and forth is really straightforward (:
Note that some parens in your example aren’t necessary, e.g. `x->(first(x) ≤ last(idxs))` could be just `x->first(x) ≤ last(idxs)`.


That is perfect! Thank you all!

I think `collectat_itr2` returns a different number of elements depending on the order of the indices. If that’s not what you want, I guess you’d do something like `collectat_itr3`, which returns the correct number of elements but ignores the order of the indices (it preserves the order of elements in the iterator).

Edit: I guess when the OP said linear indices, they meant ordered indices, in which case `collectat_itr2` makes sense. I’ll still leave my comment here to draw attention to it.

```julia
# A small countdown iterator: I(n) yields n, n-1, ..., 1
struct I
    n::Int
    I(n::Int) = n >= 0 ? new(n) : throw(ArgumentError("Expected integer >= 0. Got: $n"))
end
Base.iterate(i::I, state::Int = i.n) = state == 0 ? nothing : (state, state - 1)
Base.length(i::I) = i.n
Base.eltype(::Type{I}) = Int

function collectat_itr2(itr, idxs)
    e = enumerate(itr)
    it = Iterators.takewhile(x->(first(x) ≤ last(idxs)), e)  # stops at the *last* index given
    f = Iterators.filter(x->(first(x) ∈ idxs), it)
    return map(last, f)
end

function collectat_itr3(itr, idxs)
    e = enumerate(itr)
    f = Iterators.filter(x->(first(x) ∈ idxs), e)  # no early stop: scans all of `itr`
    return map(last, f)
end

@assert collectat_itr2(I(9), [1, 2]) == [9, 8]
@assert collectat_itr2(I(9), [2, 1]) == [9]     # stopped after index 1

@assert collectat_itr3(I(9), [1, 2]) == [9, 8]
@assert collectat_itr3(I(9), [2, 1]) == [9, 8]
```

Good point. For unordered indices, I would probably do something like this:

```julia
function collectat_itr4(itr, idxs)
    idxs = Set(idxs)  # want to check unordered membership
    # Each `it -> ...` body extends to the end of the chain, so the steps nest.
    itr |> it ->
        Iterators.take(it, maximum(idxs)) |>  # just take what we might need
        enumerate |> it ->
        Iterators.filter(x->(first(x) ∈ idxs), it) |> it ->
        map(last, it)
end
```

Using `take` or `takewhile` also means the function works on infinite iterators: iteration stops after a finite number of items, determined by the requested indices.
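As a quick check of the infinite-iterator point, here is the same take/enumerate/filter/map pipeline written as nested calls under the hypothetical name `collectat_lazy`, applied to `Iterators.countfrom` and `Iterators.cycle`, neither of which ends on its own. The empty-index guard is an addition of this sketch, since `maximum` throws on an empty collection.

```julia
# Hypothetical name; same take -> enumerate -> filter -> map pipeline as above.
function collectat_lazy(itr, idxs)
    wanted = Set(idxs)                           # unordered, fast membership tests
    isempty(wanted) && return eltype(itr)[]      # `maximum` would throw on empty input
    upto = Iterators.take(itr, maximum(wanted))  # never look past the largest index
    hits = Iterators.filter(x -> first(x) ∈ wanted, enumerate(upto))
    return map(last, hits)                       # keep the values, drop the positions
end

collectat_lazy(Iterators.countfrom(10), [3, 1])       # [10, 12] (10, 11, 12, ... is infinite)
collectat_lazy(Iterators.cycle([:a, :b]), [4, 2, 5])  # [:b, :b, :a]
```

The results come back in iterator order, not in the order of the requested indices.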


Thank you for raising the issue @sadish-d. By linear indices I meant `LinearIndices`. I will update the answer to make sure that all elements are returned.

Final solution:

```julia
function collectat(iter, inds)
    if isempty(inds)
        eltype(iter)[]
    else
        m = maximum(inds)                                # largest requested index
        e = Iterators.enumerate(iter)
        w = Iterators.takewhile(x -> (first(x) ≤ m), e)  # stop after index m
        f = Iterators.filter(x -> (first(x) ∈ inds), w)  # keep requested positions
        map(last, f)                                     # drop the index
    end
end
```

This still preserves the order of `iter` and ignores the order of the indices in `inds`: it does not iterate over `inds` one by one; it iterates over `iter`.
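If the order of `inds` (and any repeated indices) should be respected, one alternative to sorting is to record the requested values in a `Dict` during a single pass and then read them back in the order of `inds`. This is a sketch under the hypothetical name `collectat_inorder`, not something proposed in the thread.

```julia
# Hypothetical helper: one pass over `iter` (stopping at the largest requested
# index), then the values are returned in the order given by `inds`.
function collectat_inorder(iter, inds)
    isempty(inds) && return eltype(iter)[]
    wanted = Set(inds)       # distinct requested indices
    found = Dict{Int,Any}()  # index => value, filled while iterating
    for (i, v) in enumerate(Iterators.take(iter, maximum(wanted)))
        i in wanted && (found[i] = v)
    end
    # Look values up in the caller's order; repeated indices repeat the value.
    # A missing index (iterator too short) raises a KeyError here.
    return [found[k] for k in inds]
end

collectat_inorder(1:10, [5, 2, 5])              # [5, 2, 5]
collectat_inorder((i^2 for i in 1:10), [3, 1])  # [9, 1]
```

Compared with a sort-based approach, this trades the `sortperm` for a hash lookup per visited element; which one is faster will likely depend on the sizes involved, so it is worth benchmarking on real inputs.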

You are correct, @sadish-d. I removed the word “ordered” from my previous comment (the one before the solution) to avoid confusion.

In case the order of the indices is important (and with a slight improvement for the case where it is not):

```julia
function collectat2(iter, inds)
    isempty(inds) && return eltype(iter)[]
    wassorted = issorted(inds)
    if !wassorted
        perm = sortperm(inds)  # perm[j] = position in `inds` of the j-th smallest index
    end
    L, T, j = length(inds), eltype(iter), 1
    M = wassorted ? last(inds) : inds[last(perm)]  # largest requested index
    res = Vector{T}(undef, L)
    for (i, v) in enumerate(Iterators.take(iter, M))
        # `while` rather than `if`, so that repeated indices are all filled
        while j ≤ L && i == inds[wassorted ? j : perm[j]]
            res[wassorted ? j : perm[j]] = v  # store at the caller's position in `inds`
            j += 1
        end
        j > L && break
    end
    return res
end
```

And some timings (`@btime` comes from BenchmarkTools.jl):

```julia
julia> @btime collectat(1:10,[2,3,4,5]);
  104.165 ns (4 allocations: 304 bytes)

julia> @btime collectat2(1:10,[2,3,4,5]);
  34.675 ns (2 allocations: 192 bytes)
```