I have the following goal:
Take in an iterator which does not necessarily have eltype(itr) defined.
- If the iterator is looping through numbers, return a vector of those numbers.
- If the iterator is not looping through numbers, return a matrix where each row of the matrix is a collected element of the iterator.
The most naive implementation is
_collect(x::Number) = x
_collect(x) = collect(x)
function mycollect(itr)
    mapreduce(transpose ∘ _collect, vcat, itr)
end
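To make the two cases concrete, here is the same definition applied to a generator of numbers and a generator of tuples. The inputs are made up for illustration, and the filter is deterministic so the results are reproducible.

```julia
_collect(x::Number) = x
_collect(x) = collect(x)

mycollect(itr) = mapreduce(transpose ∘ _collect, vcat, itr)

# Iterator of numbers: vcat of scalars builds up a vector.
nums = (i^2 for i in Iterators.filter(isodd, 1:6))
v = mycollect(nums)   # [1, 9, 25]

# Iterator of tuples: each element becomes one row of the matrix.
tups = ((1, i, i^2) for i in Iterators.filter(isodd, 1:6))
M = mycollect(tups)   # rows are (1,1,1), (1,3,9), (1,5,25)
```

Every vcat call in the reduction allocates a new array, which is where most of the cost of this version comes from.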
Does anyone have an idea for how to do this more efficiently?
Do the iterator and the iterators inside it have known lengths?
No, but we can assume the lengths are all equal.
Then I have a package which will do this:
julia> tups = ((1,i,i^2) for i in Iterators.filter(_->rand(Bool), 1:1000));
julia> Base.haslength(tups)
false
julia> eltype(tups)
Any
julia> using LazyStack
julia> stack(_collect(t) for t in tups)
3×515 Array{Int64,2}:
1 1 1 1 1 1 1 1 1 1 … 1 1 1 1
2 3 8 9 10 11 13 14 15 16 988 989 991 999
4 9 64 81 100 121 169 196 225 256 976144 978121 982081 998001
julia> stack(last(t) for t in tups)' # numbers
1×495 LinearAlgebra.Adjoint{Int64,Array{Int64,1}}:
4 100 121 144 289 529 625 841 900 … 982081 990025 994009 996004
julia> @btime mycollect($tups);
374.261 μs (1770 allocations: 2.38 MiB)
julia> @btime stack(_collect(t) for t in $tups);
36.588 μs (445 allocations: 71.83 KiB)
The need for _collect here (with tuples) is a bug, and it only affects the case of unknown iterator length: stack((1,i,i^2) for i in 1:10) works, but stack(tups) does not, right now.
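For readers on newer Julia versions: a stack function is also available in Base from Julia 1.9 onwards. It is separate from the LazyStack package shown above. The sketch below assumes such a version and uses a deterministic filter as a stand-in for tups.

```julia
# Assumes Julia ≥ 1.9, where `stack` is defined in Base.
tups = ((1, i, i^2) for i in Iterators.filter(isodd, 1:6))

Base.IteratorSize(tups)     # Base.SizeUnknown()

cols = Base.stack(tups)     # 3×3 matrix, one column per tuple
rows = permutedims(cols)    # one row per tuple
```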
Here is a solution:
function _c(itr)
    n = length(first(itr))
    [xj[i] for xj in itr, i in 1:n]
end
EDIT: weirdly enough, this does not work with tups defined above. Not sure why, though.
julia> t = ((1, i, i^2) for i in 1:4)
Base.Generator{UnitRange{Int64},var"#37#38"}(var"#37#38"(), 1:4)
julia> _c(t)
4×3 Array{Int64,2}:
1 1 1
1 2 4
1 3 9
1 4 16
julia> t = ((1, i, i^2) for i in Iterators.filter(isodd, 1:4))
Base.Generator{Base.Iterators.Filter{typeof(isodd),UnitRange{Int64}},var"#39#40"}(var"#39#40"(), Base.Iterators.Filter{typeof(isodd),UnitRange{Int64}}(isodd, 1:4))
julia> _c(t)
6-element Array{Int64,1}:
1
1
1
3
1
9
You can just reshape it:
function c_cols(itr)
    n = length(first(itr))
    reshape([xj[i] for i in 1:n, xj in itr], n, :)
end
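A short check of this on an iterator of unknown length. The input is an illustrative stand-in with a deterministic filter, so the output is reproducible.

```julia
function c_cols(itr)
    n = length(first(itr))
    # With an unknown-length itr the comprehension yields a flat vector,
    # so the reshape restores the n × N layout.
    reshape([xj[i] for i in 1:n, xj in itr], n, :)
end

t = ((1, i, i^2) for i in Iterators.filter(isodd, 1:6))

cols = c_cols(t)           # 3×3: each column is one tuple
rows = permutedims(cols)   # each row is one tuple
```

Note that first(itr) starts its own pass over the iterator, so this assumes itr can be iterated more than once.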
Applied to tuples, this turns out to be much quicker than what stack is doing (which I think is copyto!), but it is slower for vectors.
I still don’t fully understand why adding an Iterators.Filter messes with the output. It appears to be splatting the tuples, filtering out all the elements of 1, i, i^2 and iterating through the values themselves…
Ah sorry. My implementation gives what I want. I mean that the first row of the matrix is collect(first(itr)).
There is no splatting. It just doesn’t give shapes to generators built from others whose shape is unknown:
t1 = ((1,2,3) for i in 2:2:10)
t2 = ((1,2,3) for i in Iterators.Filter(iseven, 1:10))
t3 = ((1,2,3) for i in Iterators.Filter(iseven, hcat(1:5, 6:10)))
Base.IteratorSize(t1) # Base.HasShape{1}()
Base.IteratorSize(t2) # Base.SizeUnknown()
Base.IteratorSize(t3) # Base.SizeUnknown()
m1 = (t[i] for i in 1:3, t in t1) # Iterators.product
m2 = (t[i] for i in 1:3, t in t2)
m3 = (t[i] for i in 1:3, t in t3)
Base.IteratorSize(m1) # Base.HasShape{2}()
Base.IteratorSize(m2) # Base.SizeUnknown()
Base.IteratorSize(m3) # Base.SizeUnknown()
You are collecting something like m2.
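To see the effect directly, the following sketch collects a two-index generator built on a filtered iterator. The tuples are illustrative; the point is only that collect returns a flat vector when the size is unknown, and that the shape can be restored by hand.

```julia
t2 = ((1, i, i^2) for i in Iterators.filter(isodd, 1:4))
m2 = (t[i] for i in 1:3, t in t2)

Base.IteratorSize(m2)        # Base.SizeUnknown()

flat = collect(m2)           # 6-element Vector, not a 3×2 Matrix
mat  = reshape(flat, 3, :)   # 3×2, one column per tuple
```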
But does your real problem contain tuples? These were just the first non-vector objects which came to mind.
Iterators of tuples and named tuples would be useful. Just taking any collection of “observations” and putting it into a matrix where each row is an observation. So it’s helpful to be agnostic about what constitutes an observation.
For things which aren’t tuples, notice that the wrong iteration order is fairly expensive:
@btime c_cols(collect(t) for t in $tups); # 31.931 μs (1016 allocations: 103.06 KiB)
@btime c_rows(collect(t) for t in $tups); # 65.194 μs (1517 allocations: 196.84 KiB)
@btime permutedims(c_cols(collect(t) for t in $tups)); # 32.206 μs (1018 allocations: 103.16 KiB)
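c_rows is not defined anywhere in the thread. One guess is that it is the row-wise counterpart of c_cols, that is, _c from above with a reshape added. The definition below is that assumption, not necessarily what was benchmarked.

```julia
# Hypothetical definition: the thread never shows c_rows.
function c_rows(itr)
    n = length(first(itr))
    # itr is the fast axis here, so it is traversed once for every i in 1:n
    reshape([xj[i] for xj in itr, i in 1:n], :, n)
end

function c_cols(itr)
    n = length(first(itr))
    # itr is the slow axis here, so it is traversed once in total
    reshape([xj[i] for i in 1:n, xj in itr], n, :)
end

t = (collect((1, i, i^2)) for i in Iterators.filter(isodd, 1:6))

same = c_rows(t) == permutedims(c_cols(t))   # true
```

Under this definition the row-wise version re-runs the iterator, including each inner collect, once per column, which may account for part of the gap in the timings. It also relies on every pass producing the same elements, which the random filter in tups does not guarantee.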
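julia> function _flat(x)
           n = length(first(x))
           # flatten yields one observation after another, so each observation
           # fills one column (n × N); apply permutedims to get one per row
           reshape(collect(Iterators.flatten(x)), n, :)
       end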
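A flatten-based variant avoids indexing into the elements altogether. The sketch below uses a hypothetical name, flat_rows, and spells out the reshape order needed to end up with one observation per row; the input is an illustrative stand-in for tups.

```julia
# Hypothetical helper name; it assumes all elements have equal length.
function flat_rows(x)
    n = length(first(x))
    # flatten yields the values observation by observation, so in
    # column-major storage each observation fills one column of an n × N matrix
    cols = reshape(collect(Iterators.flatten(x)), n, :)
    permutedims(cols)   # N × n: one row per observation
end

tups = ((1, i, i^2) for i in Iterators.filter(isodd, 1:6))
R = flat_rows(tups)
```

For an iterator of plain numbers this returns an N×1 matrix rather than a vector, so that case would still need separate handling.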