Here is a task that I am having a hard time accomplishing with Julia Base:
Given an iterator `iter` and a set of linear indices `inds`, write a function `collectat(iter, inds)` that collects the iterator at the indices without materializing the full list of items.
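To pin down the intended semantics, here is a naive reference version. It assumes the indices are valid positive integers and that the result should follow the order of `inds`. It materializes the whole iterator, which is exactly what the question wants to avoid, but any lazy `collectat` should agree with it on small inputs. The name `collectat_reference` is hypothetical.

```julia
# Hypothetical reference implementation: materializes everything, so it is only
# suitable for small, finite iterators. A lazy `collectat` should match it.
collectat_reference(iter, inds) = collect(iter)[inds]

iter = (i^2 for i in 1:10)
collectat_reference(iter, [2, 5, 3])   # [4, 25, 9]
```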
I asked this question on Zulip some time ago, and @Mason helped with a solution that depends on Transducers.jl:
```julia
using Transducers: Filter, Map, TakeWhile, tcollect, ⨟

function collectat(iter, inds)
    if isempty(inds)
        eltype(iter)[]
    else
        selectat(inds) = enumerate ⨟ TakeWhile(x -> first(x) ≤ last(inds)) ⨟ Filter(y -> first(y) ∈ inds) ⨟ Map(last)
        iter |> selectat(inds) |> tcollect
    end
end
```
Here is an example of usage:
```julia
julia> iter = (rand() for _ in 1:200_000_000)
Base.Generator{UnitRange{Int64}, var"#8#9"}(var"#8#9"(), 1:200000000)

julia> inds = [1,1_000_000,1_999_999]
3-element Vector{Int64}:
       1
 1000000
 1999999

julia> @allocated collectat(iter, inds)
4608
```
Is it possible to write something that works in most cases without Transducers.jl?
@aplavin, how would you write that function without DataPipes.jl?
Like this?
```julia
function collectat_itr2(itr, idxs)
    e = enumerate(itr)
    it = Iterators.takewhile(x->(first(x) ≤ last(idxs)), e)
    f = Iterators.filter(x->(first(x) ∈ idxs), it)
    return map(last, f)
end
```
Yeah, exactly – the translation back and forth is really straightforward (:
Note that some parens in your example aren’t necessary, e.g. `x->(first(x) ≤ last(idxs))` could be just `x->first(x) ≤ last(idxs)`.
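To see what each stage of the `collectat_itr2` pipeline produces, one can collect the intermediate iterators on a tiny input (harmless here because the input is small). The input `'a':'j'` and the indices are just illustrative choices.

```julia
itr  = 'a':'j'      # a small, finite iterator of Chars
idxs = [2, 4, 5]

e  = enumerate(itr)                                       # (1, 'a'), (2, 'b'), ...
it = Iterators.takewhile(x -> first(x) ≤ last(idxs), e)   # stops once the index exceeds 5
f  = Iterators.filter(x -> first(x) ∈ idxs, it)           # keeps only the requested indices

collect(it)     # [(1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'), (5, 'e')]
collect(f)      # [(2, 'b'), (4, 'd'), (5, 'e')]
map(last, f)    # ['b', 'd', 'e']
```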
I think `collectat_itr2` returns a different number of elements depending on the order of the indices, because `takewhile` stops at `last(idxs)`, which is the largest index only when `idxs` is sorted. If that’s not what you want, I guess you’d do something like `collectat_itr3`? It returns the correct number of elements but ignores the order of the indices (preserving the order of elements in the iterator).
Edit: I guess when the OP said linear indices, they meant ordered indices, in which case collectat_itr2 makes sense. I’ll still leave my comment here to draw attention to it.
```julia
struct I
    n::Int
    I(n::Int) = n >= 0 ? new(n) : error("Expected integer >=0. Got: $n")
end

# Counts down from n to 1
Base.iterate(i::I, state::Int = i.n) = state == 0 ? nothing : (state, state - 1)

function collectat_itr2(itr, idxs)
    e = enumerate(itr)
    it = Iterators.takewhile(x->(first(x) ≤ last(idxs)), e)
    f = Iterators.filter(x->(first(x) ∈ idxs), it)
    return map(last, f)
end

function collectat_itr3(itr, idxs)
    e = enumerate(itr)
    f = Iterators.filter(x->(first(x) ∈ idxs), e)
    return map(last, f)
end

@assert collectat_itr2(I(9), [1, 2]) == [9, 8]
@assert collectat_itr2(I(9), [2, 1]) == [9]
@assert collectat_itr3(I(9), [1, 2]) == [9, 8]
@assert collectat_itr3(I(9), [2, 1]) == [9, 8]
```
Good point. For unordered indices, I would probably do something like this:
```julia
function collectat_itr4(itr, idxs)
    idxs = Set(idxs) # Want to check unordered membership
    itr |> it ->
        Iterators.take(it, maximum(idxs)) |> # just take what we might need
        enumerate |> it ->
        Iterators.filter(x->(first(x) ∈ idxs), it) |> it ->
        map(last, it)
end
```
Using `take` or `takewhile` also has the nice feature that the function works on infinite iterators, since only the elements up to the largest requested index are ever consumed.
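A small way to check that claim: wrap an infinite iterator in a generator that counts how many elements it yields. The snippet restates `collectat_itr4` in an equivalent form without the chained lambdas so it runs on its own; the counter `produced` is just an illustrative device.

```julia
# Same logic as collectat_itr4, written without the piped lambdas.
function collectat_itr4(itr, idxs)
    idxs = Set(idxs)                            # O(1) membership checks
    taken = Iterators.take(itr, maximum(idxs))  # never look past the largest index
    f = Iterators.filter(x -> first(x) ∈ idxs, enumerate(taken))
    return map(last, f)
end

produced = Ref(0)   # counts how many elements the generator has yielded
# Infinite generator over 10, 11, 12, ... that records each element it yields
iter = ((produced[] += 1; x) for x in Iterators.countfrom(10))

res = collectat_itr4(iter, [5, 1, 3])   # [10, 12, 14], in iterator order
n_produced = produced[]                 # 5: only the first maximum(idxs) elements were produced
```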
Thank you for raising the issue, @sadish-d. By linear indices, I meant `LinearIndices`. I will update the answer to make sure that all elements are returned.
```julia
function collectat(iter, inds)
    if isempty(inds)
        eltype(iter)[]
    else
        m = maximum(inds)
        e = Iterators.enumerate(iter)
        w = Iterators.takewhile(x -> (first(x) ≤ m), e)
        f = Iterators.filter(x -> (first(x) ∈ inds), w)
        map(last, f)
    end
end
```
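A quick check of this version on a matrix, whose iteration order is column-major and therefore matches `LinearIndices`. The definition is repeated in a slightly condensed form so the snippet runs on its own; the matrix `A` is only an illustrative input. Note that the result follows the iterator's order, not the order of `inds`.

```julia
function collectat(iter, inds)
    isempty(inds) && return eltype(iter)[]
    m = maximum(inds)
    w = Iterators.takewhile(x -> first(x) ≤ m, enumerate(iter))
    f = Iterators.filter(x -> first(x) ∈ inds, w)
    return map(last, f)
end

A = [10 20 30;
     40 50 60]                       # 2×3 matrix, iterated column-major

LinearIndices(A) == [1 3 5; 2 4 6]   # true: linear index of each entry
collectat(A, [2, 3])                 # [40, 20], same as A[[2, 3]]
collectat(A, [3, 2])                 # still [40, 20], whereas A[[3, 2]] == [20, 40]
```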
In case the order of the indices matters (and with a slight improvement for the unordered case):
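```julia
function collectat2(iter, inds)
    isempty(inds) && return eltype(iter)[]
    wassorted = issorted(inds)
    if !wassorted
        perm = sortperm(inds)   # inds[perm[j]] is the j-th smallest requested index
    end
    L, T, j = length(inds), eltype(iter), 1
    M = wassorted ? last(inds) : inds[last(perm)]   # largest requested index
    res = Vector{T}(undef, L)
    for (i, v) in enumerate(Iterators.take(iter, M))
        # Fill every output slot that requests index i (the while loop handles repeated indices).
        # The j-th smallest index sits at position perm[j] of inds, so that is where v belongs.
        while j ≤ L && i == inds[wassorted ? j : perm[j]]
            res[wassorted ? j : perm[j]] = v
            j += 1
        end
        j > L && break
    end
    return res
end
```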
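For comparison, here is an alternative sketch that avoids sorting altogether, assuming `inds` holds positive integers: a `Dict` maps each requested index to the output positions that want it, and a single pass over the first `maximum(inds)` elements fills them in. The name `collectat_dict` is hypothetical, and this is a swapped-in technique rather than anything proposed in the thread.

```julia
# Hypothetical alternative: one pass, no sorting; handles unsorted and repeated indices.
# Assumes `inds` contains positive integers and `iter` has at least maximum(inds) elements.
function collectat_dict(iter, inds)
    isempty(inds) && return eltype(iter)[]
    slots = Dict{Int,Vector{Int}}()          # requested index => output positions
    for (k, i) in enumerate(inds)
        push!(get!(Vector{Int}, slots, i), k)
    end
    res = Vector{eltype(iter)}(undef, length(inds))
    # Iterators.take stops right after the largest requested index
    for (i, v) in enumerate(Iterators.take(iter, maximum(inds)))
        ks = get(slots, i, nothing)
        ks === nothing && continue           # index i was not requested
        for k in ks
            res[k] = v
        end
    end
    return res
end

collectat_dict(10:10:100, [3, 1, 2, 3])          # [30, 10, 20, 30]
collectat_dict(Iterators.countfrom(1), [5, 2])   # [5, 2]
```

Compared with `collectat2`, this trades the `sortperm` for one `Dict` lookup per consumed element; which is faster probably depends on how many indices are requested relative to the largest one.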