I have a collection of (potentially infinite) sorted iterators I1, I2, I3, … and a value function sort_by((x1, x2, x3, ...)) that is nondecreasing in each argument. I want the Cartesian product of these iterators, sorted by sort_by.
I think I can do this with a heap (plus a set of already-queued entries, which may give a speedup in certain edge cases), but I’d like to know whether anyone has already implemented a solution to this problem.
(If I1 and I2 are both the natural numbers and sort_by is the sum x1 + x2, then the output enumerates ℕ × ℕ, which gives a constructive proof that ℕ × ℕ is countable.)
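In that special case the sorted product walks ℕ × ℕ one anti-diagonal at a time. A minimal sketch that writes this enumeration down directly, rather than going through a general sorted-product routine (the name `pairs_by_sum` is hypothetical, and ℕ is taken to start at 1):

```julia
# Enumerate ℕ × ℕ (taking ℕ = 1, 2, 3, ... here) in nondecreasing order of a + b,
# one anti-diagonal s = a + b at a time. Each diagonal is finite, so every pair
# is reached after finitely many steps.
pairs_by_sum() = ((a, s - a) for s in Iterators.countfrom(2) for a in 1:s-1)

first_pairs = collect(Iterators.take(pairs_by_sum(), 6))
# first_pairs == [(1, 1), (1, 2), (2, 1), (1, 3), (2, 2), (3, 1)]
```

The order within a diagonal (increasing `a`) is an arbitrary tie-break; any order of pairs with equal sums is equally sorted.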
I’m a little unclear on what you mean by “sorted iterators” and “sorted product”: sorted by what metric, in each case?
If I understand correctly, sort_by is your metric, and you want the tuples of the product to be produced in order of that metric?
For finite iterators, by “sorted” I mean that issorted(I) returns true, and by “sorted by sort_by” I mean that issorted(product, by=sort_by) returns true. For infinite iterators, I mean that every finite prefix is sorted.
I want a result that is equivalent to sort(vec(collect(product(I1, I2, I3, ...))), by=sort_by), but produced lazily, without collecting (which is impossible for infinite inputs anyway).
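For finite inputs, that specification can be written down directly and used as a test oracle for a lazy implementation. A minimal sketch (the name `sorted_product_reference` is hypothetical):

```julia
# Reference implementation of the spec for *finite* inputs: materialize the full
# Cartesian product, flatten it, and sort by the key. Only usable on small cases.
sorted_product_reference(sort_by, iters...) =
    sort(vec(collect(Iterators.product(iters...))); by = sort_by)

ref = sorted_product_reference(sum, [1, 2, 4], [0, 1])
# map(sum, ref) == [1, 2, 2, 3, 4, 5]
```

Because ties in sort_by can legitimately come out in a different order, it is safer to compare a lazy implementation against this with `map(sort_by, ...)` or `issorted` rather than exact equality.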
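We can assume that the input iterators are sorted and that sort_by is nondecreasing in each argument. Then the first element of the output must be the tuple of the inputs’ first elements, and more generally, advancing any one input by a step can never decrease sort_by, so a tuple only needs to become a candidate once one of its immediate predecessors has been emitted. I think a heap-based approach (essentially a lazy heapsort) can produce the first n outputs in O(n log n) time, treating the number of inputs as a constant.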
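A self-contained sketch of the heap-plus-set idea for finite, sorted, 1-based vectors. Assumptions: the name `sorted_product_frontier` is hypothetical, and a `Vector` kept in descending key order stands in for a binary heap, so each push costs O(k) for a queue of size k rather than O(log k); swapping in `BinaryHeap` from DataStructures.jl should recover the bound discussed above.

```julia
# Heap-plus-set frontier expansion for finite, sorted, 1-based vectors.
# Assumption: a Vector kept in *descending* key order stands in for a min-heap,
# so the smallest key is always at the end and `pop!` removes it.
function sorted_product_frontier(by, vs::AbstractVector...)
    T = Tuple{map(eltype, vs)...}
    out = T[]
    any(isempty, vs) && return out            # an empty factor gives an empty product
    d = length(vs)
    val(idx) = ntuple(j -> vs[j][idx[j]], d)   # index tuple -> element tuple
    key(idx) = by(val(idx))
    start = ntuple(_ -> 1, d)
    queue = [start]        # candidates, sorted by descending key
    seen = Set([start])    # every index tuple ever queued, to avoid duplicates
    while !isempty(queue)
        idx = pop!(queue)  # candidate with the smallest key
        push!(out, val(idx))
        for i in 1:d       # successors: one step further along a single axis
            idx[i] < length(vs[i]) || continue
            nb = ntuple(j -> idx[j] + (j == i), d)
            nb in seen && continue
            push!(seen, nb)
            insert!(queue, searchsortedfirst(queue, nb; by = key, rev = true), nb)
        end
    end
    return out
end

r = sorted_product_frontier(sum, [1, 2, 4], [0, 1])
# r == [(1, 0), (2, 0), (1, 1), (2, 1), (4, 0), (4, 1)]
```

Keys are recomputed on every comparison, which is fine for a sketch but worth caching in real use.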
I’ve implemented it. The important bits are in the implementation of the two-argument iterate. I’ve included a bit more for context, and you can see the complete package here.
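```julia
using DataStructures: BinaryHeap

# `cached` comes from the complete package (not shown); it is assumed to wrap each
# iterator so that its elements can be indexed lazily and checked with `checkbounds`.
function SortedIteratorProduct(by::Function, iterators...)
    sources = cached.(iterators)
    SortedIteratorProduct(sources, by)
end

# Map a tuple of indices to the tuple of corresponding elements.
lookup(sip, x) = tuple((s[i] for (s, i) in zip(sip.sources, x))...)

function Base.iterate(sip::SortedIteratorProduct)
    # If any source is empty, the product is empty.
    all(x -> checkbounds(Bool, x, 1), sip.sources) || return nothing
    start = map(_ -> 1, sip.sources)
    # Heap ordered by `by` of the looked-up elements; `reverse(x)` breaks ties.
    heap = BinaryHeap(Base.By(x -> (sip.by(lookup(sip, x)), reverse(x))), [start])
    iterate(sip, (Set((start,)), heap))
end

function Base.iterate(sip::SortedIteratorProduct, (set, heap))
    isempty(heap) && return nothing
    indices = pop!(heap)  # index tuple with the smallest key
    # Queue each neighbor one step further along a single axis, unless it is
    # out of bounds or has already been queued.
    for i in eachindex(indices)
        new = ntuple(j -> indices[j] + (j == i), length(indices))
        if checkbounds(Bool, sip.sources[i], indices[i] + 1) && new ∉ set
            push!(set, new)
            push!(heap, new)
        end
    end
    lookup(sip, indices), (set, heap)
end
```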
Nice! With a bit more bookkeeping you can discard an index tuple from set once all of its immediate successors (one step higher along each axis) are in set. That is, only keep the frontier. I assume this would typically save a factor of about n^{1/d} in memory, for n outputs and d iterators.
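A different way to stop the set from growing (a swapped-in technique, not the bookkeeping described above): give every index tuple a unique parent, so no tuple can be queued twice and no set is needed. The sketch below assumes finite, sorted, 1-based vectors; the name `sorted_product_tree` is hypothetical, and a descending-sorted `Vector` again stands in for a heap.

```julia
# Set-free variant: each index tuple x ≠ (1, ..., 1) has exactly one parent,
# x - e_m, where m is the first coordinate with x[m] > 1. So from a popped tuple y
# we only push y + e_k for k ≤ p, where p is the first coordinate with y[p] > 1
# (or every k if y is all ones). Children never have smaller keys than their
# parent (sorted inputs, `by` nondecreasing in each argument), and every
# not-yet-emitted tuple has an ancestor in the queue, so popping the minimum
# still yields a sorted stream, and each tuple is queued exactly once.
function sorted_product_tree(by, vs::AbstractVector...)
    T = Tuple{map(eltype, vs)...}
    out = T[]
    any(isempty, vs) && return out
    d = length(vs)
    val(idx) = ntuple(j -> vs[j][idx[j]], d)   # index tuple -> element tuple
    key(idx) = by(val(idx))
    queue = [ntuple(_ -> 1, d)]  # sorted by descending key; minimum at the end
    while !isempty(queue)
        y = pop!(queue)
        push!(out, val(y))
        p = something(findfirst(>(1), y), d)
        for k in 1:p
            y[k] < length(vs[k]) || continue
            child = ntuple(j -> y[j] + (j == k), d)
            insert!(queue, searchsortedfirst(queue, child; by = key, rev = true), child)
        end
    end
    return out
end

r = sorted_product_tree(sum, [1, 2, 4], [0, 1])
# r == [(1, 0), (2, 0), (1, 1), (2, 1), (4, 0), (4, 1)]
```

The only storage is the queue of not-yet-emitted candidates; there is no separate record of which tuples have been visited.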