How to get first or last n elements from SortedMultiDict?

Let’s say I want to call a function on the first n elements of a sorted multi-dictionary. I can only come up with this clumsy way of doing it:

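using DataStructures

smdict = SortedMultiDict{Float64, Int64}(0.7=>7, 0.1=>1, 0.4=>4, 0.2=>2)
i = startof(smdict)  # semitoken of the first item in sorted order
n = 3

for _ in 1:n
  global i
  println(deref((smdict, i)))  # print the key => value pair at the semitoken
  i = advance((smdict, i))     # move to the next item
end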

Is there a more compact way of iterating over the first or last n elements of a SortedMultiDict?

The following works, as long as you don’t mind using break:

julia> smdict = SortedMultiDict{Float64, Int64}(0.7=>7, 0.1=>1, 0.4=>4, 0.2=>2)
SortedMultiDict(Base.Order.ForwardOrdering(),0.1 => 1, 0.2 => 2, 0.4 => 4, 0.7 => 7)

julia> for (count,i) in enumerate(smdict)
         if count > 3
           break
         end
         println(i)
       end
0.1 => 1
0.2 => 2
0.4 => 4

If you’d prefer not to use break, then perhaps you can find a more elegant solution using IterTools.jl.
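
As one possible alternative that avoids both break and an extra package, `Iterators.take` from Base Julia wraps any iterable and stops after n items. A minimal sketch, assuming DataStructures.jl is installed and reusing the same example data (the name `first_n` is only for illustration):

```julia
using DataStructures

smdict = SortedMultiDict{Float64, Int64}(0.7=>7, 0.1=>1, 0.4=>4, 0.2=>2)
n = 3

# Iterators.take lazily yields at most n items from the underlying iterator,
# so the loop ends on its own after the n smallest keys.
for p in Iterators.take(smdict, n)
    println(p)
end

# Materialize the same n pairs into a vector if they are needed later.
first_n = collect(Iterators.take(smdict, n))
```

This only helps at the front of the container, because it relies on the ordinary forward iteration of the SortedMultiDict.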

But once the ordering of smdict is fixed, I can’t get the last n elements with this method, right?

Iterating backwards through a SortedMultiDict requires invoking lower-level primitives like deref and regress. If you need to do this operation often, then you could make your code cleaner by implementing an iterate protocol for stepping backwards through a SortedMultiDict.
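
A rough sketch of the backward walk with those primitives might look like the following. The helper `last_n` is hypothetical and not part of DataStructures.jl, and the sketch assumes a DataStructures.jl release in which the semitoken of the last item is returned by `lastindex`:

```julia
using DataStructures

# Hypothetical helper: collect up to n pairs from the end of the container,
# largest key first.
function last_n(sc::SortedMultiDict{K,V}, n::Integer) where {K,V}
    out = Pair{K,V}[]
    count = min(n, length(sc))
    count <= 0 && return out           # nothing to collect
    st = lastindex(sc)                 # semitoken of the last item (assumed API)
    for _ in 1:count
        push!(out, deref((sc, st)))    # key => value pair at the semitoken
        st = regress((sc, st))         # step back by one item
    end
    return out
end

smdict = SortedMultiDict{Float64, Int64}(0.7=>7, 0.1=>1, 0.4=>4, 0.2=>2)
tail = last_n(smdict, 3)
println(tail)
```

Capping the loop at `length(sc)` keeps the walk from regressing past the before-start position when n is larger than the container.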

In the C++ standard library, wherever an iterator is defined, the language also defines a reverse iterator if it makes sense. As far as I know, nobody has proposed a similar rule for Julia.

Stepping forward or backward one element at a time is not what I would prefer. I would like to be able to do iterator arithmetic and use something like this:

for (k,v) in inclusive(smdict, startof(smdict), startof(smdict)+n-1)
  println((k, v))
end

or

for (k,v) in inclusive(smdict, endof(smdict)-n+1, endof(smdict))
  println((k, v))
end

or even better:

for (k,v) in inclusive(smdict, 1:n)
  println((k, v))
end

The current library does not support efficient integer indexing. With a nontrivial amount of additional code, this feature could be supported by adding a field to each internal search-tree node that stores how many descendants it has.
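
To illustrate the idea only, here is a toy sketch of a search tree whose nodes carry a subtree size, which is what makes "give me the i-th smallest key" possible without visiting every element. Everything in it (`Node`, `insert_key`, `select_kth`) is hypothetical and unrelated to the library's internal balanced tree; the tree below is not even balanced.

```julia
# Hypothetical node type; `size` counts the nodes in the subtree rooted here,
# including the node itself.
mutable struct Node{K}
    key::K
    left::Union{Node{K}, Nothing}
    right::Union{Node{K}, Nothing}
    size::Int
end

Node(key::K) where {K} = Node{K}(key, nothing, nothing, 1)

subtree_size(::Nothing) = 0
subtree_size(node::Node) = node.size

# Plain (unbalanced) insert that keeps `size` up to date; duplicates go right.
insert_key(::Nothing, key) = Node(key)
function insert_key(node::Node, key)
    if key < node.key
        node.left = insert_key(node.left, key)
    else
        node.right = insert_key(node.right, key)
    end
    node.size = 1 + subtree_size(node.left) + subtree_size(node.right)
    return node
end

# Return the i-th smallest key (1-based); assumes 1 <= i <= node.size.
function select_kth(node::Node, i::Integer)
    nleft = subtree_size(node.left)
    if i <= nleft
        return select_kth(node.left, i)
    elseif i == nleft + 1
        return node.key
    else
        return select_kth(node.right, i - nleft - 1)
    end
end

root = foldl(insert_key, (0.7, 0.1, 0.4, 0.2); init=nothing)
first_three = [select_kth(root, i) for i in 1:3]
```

Each lookup descends a single root-to-leaf path, so in a balanced tree it would cost time proportional to the tree height rather than to i. The extra work is keeping `size` correct through every insertion, deletion, and rebalancing step.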

In my use case today, efficiency is not an issue; convenience is more important. I just realized that I can fill an array with tuples, sort it, and take n elements from whichever end I need:

a = [(0.7, 7), (0.1, 1), (0.4, 4), (0.2, 2), (0.1, 1)]
n = 3
sort!(a)                 # tuples sort by first element, then by second
println(a[1:n])          # first n elements
println(a[end-n+1:end])  # last n elements