Iterators.flatmap optimization attempt

Following the nice conversation at #44792, I’m trying to investigate whether flatmap functions might offer performance benefits. I suspect that in general flatten(x) cannot do better than flatmap(identity, x), because flatmap solves a larger task. Unless, that is, the compiler is sophisticated enough to actually treat a flatten(map(...)) as a flatmap. That is possible, but if it is the case, I would imagine having explicit flatmap functions somewhere still makes sense.

I also suspect that most of the time this will come down to specializations for different data structures, and iterators might actually be the least interesting case. I also imagine Julia is already heavily optimized for the simple tasks I’m going to investigate, so it won’t be easy to find an optimization opportunity. Still, it’s probably a great learning opportunity! I definitely want to know whether flatmap functions should always be expected to behave exactly like flatten(map(...)).
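For concreteness, here is the reference behavior I am assuming throughout (the name flatmap_ref is mine, just for illustration): a flatmap should yield exactly the elements of flatten(map(...)), in order.

```julia
# Reference semantics assumed in this post: flatmap(f, xs) should produce
# the same elements, in the same order, as flatten(map(f, xs)).
flatmap_ref(f, xs) = Iterators.flatten(Iterators.map(f, xs))

collect(flatmap_ref(x -> x:x+1, 1:3))  # [1, 2, 2, 3, 3, 4]
```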

So I set out to write a Flatmap iterator, and this is what I came up with so far. The original version still beats my code by 15% to 55%. (I basically wrote a double while-loop and “opened it up”… It’s really annoying.)

I guess there is one thing we might try in order to offer a benefit: a faster check at myit = myf(it) for whether the inner iterator is empty. Such a check could in theory be cheaper than calling iterate. Essentially a flatten(filter(!isempty, map(...)))? Just one idea…
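A minimal sketch of that empty-skipping idea, using Base’s isempty (whether this is actually cheaper than calling iterate depends on the inner type; for ranges and arrays isempty is cheap):

```julia
# Skip empty inner iterators before flatten ever calls iterate on them.
inner = [1:3, 1:0, 5:6]               # 1:0 is an empty range
nonempty = Iterators.filter(!isempty, inner)
collect(Iterators.flatten(nonempty))  # [1, 2, 3, 5, 6]
```

For a fully general iterator, though, isempty may itself have to call iterate, so the benefit would come from types that can answer emptiness without doing that.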

Anyways, I’m curious to hear. What problems in my clunky code might be keeping it from having the same performance as the original code? Or what optimizations might be happening in the original code that the new one doesn’t benefit from?

The only hint I found in @code_lowered is that we get these type annotations like Core.PartialStruct(..., Any[...]). Could that be an issue?

using BenchmarkTools
using ProfileView

struct Flatmap{FuncType, Eltype, S}
    myf::FuncType
    input::S
    Flatmap{Eltype}(myf::FuncType, input::S) where {FuncType, Eltype, S} =
        new{typeof(myf), Eltype, Core.Typeof(input)}(myf, input)
end
Flatmap(f, input) = Flatmap{nothing}(f, input)

# Eltype === nothing marks "eltype unknown"; note that nothing is the second
# type parameter, not the first (the first is the function type).
Base.eltype(::Type{Flatmap{F, Eltype, S}}) where {F, Eltype, S} =
    Eltype === nothing ? Any : Eltype
Base.IteratorEltype(::Type{Flatmap{F, Eltype, S}}) where {F, Eltype, S} =
    Eltype === nothing ? Base.EltypeUnknown() : Base.HasEltype()
Base.IteratorSize(::Type{<:Flatmap}) = Base.SizeUnknown()

function Base.iterate(myfinput::Flatmap{F, Eltype, S}) where {F, Eltype, S}
    myf = myfinput.myf
    input = myfinput.input
    
    itst = iterate(input)
    if isnothing(itst)
        return nothing
    end
    itst = itst::Tuple{Any, Any}
    it, st = itst
    myit = myf(it)
    newoutxst = iterate(myit)
    while isnothing(newoutxst)
        itst = iterate(input, st)
        if isnothing(itst)
            return nothing
        end
        it, st = itst
        myit = myf(it)
        newoutxst = iterate(myit)
    end
    
    newout, xst = newoutxst
    return newout, (st,myit,xst)    
end            

function Base.iterate(myfinput::Flatmap{F, Eltype, S}, (st,myit,xst)) where {F, Eltype, S}
    @inbounds begin
        newoutxst = iterate(myit, xst)

        if !isnothing(newoutxst)
            newout, xst = newoutxst
            return newout, (st,myit,xst)
        end

        myf = myfinput.myf
        input = myfinput.input

        newoutxst = nothing
        while isnothing(newoutxst)
            itst = iterate(input, st)
            itst === nothing && return nothing
            itst = itst::Tuple{Any, Any}
            it, st = itst
            myit = myf(it)
            newoutxst = iterate(myit)
        end
        
        newout, xst = newoutxst
        return newout, (st,myit,xst)
    end
end            

@inline myf((a,b,c)) = a:b:c

input = map(1:1111) do _
    # a,b = sort(randn(2))
    a,b = randn(2)
    (a, rand()*3, b)
end

@btime sum(Iterators.flatten(Iterators.map(myf, input)))
@btime sum(Flatmap{Float64}(myf, input))

input = map(1:1111) do _
    a,b = sort(rand(1:1111, 2))
    (a, rand(1:1111), b)
end

# @btime sum(y for x in input for y in myf(x))
@btime sum(Iterators.flatten(Iterators.map(myf, input)))
@btime sum(Flatmap{Int64}(myf, input))
  201.143 μs (3 allocations: 48 bytes)
  233.336 μs (2 allocations: 32 bytes)
  11.110 μs (3 allocations: 48 bytes)
  17.115 μs (2 allocations: 32 bytes)

I don’t know about the code you implemented here, but in general, what I’ve observed with flatten and friends regarding inference is that the compiler really has a hard time when the types get huge. That’s why here I had to add type parameters to help inference along, and add the requirement that each iterator return the same element type (performance of shrinking test cases was… not good before I did that).

To give some context: when inference tries to find the return type of a function, it has to either typejoin the possible types to find a common parent type, or create a Union of them. This is only efficient up to a limit; at some point, inference just “gives up”, inserts Any with a dynamic lookup, and more or less defers the work until some concrete type actually hits a code path. So what can you do? You can help inference along by asserting the type, or converting to something you already know. This keeps Any in check and prevents it from propagating outward, making the function locally unstable but stable from the outside. The trouble is that you can then run into broken conversions/runtime errors if the types unexpectedly don’t match up…
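A minimal sketch of that “assert to keep Any local” pattern (the function names here are made up for illustration):

```julia
# The helper is type-unstable: inference sees Union{Int, String}.
unstable_pick(flag) = flag ? 1 : "one"

# The assertion keeps the instability contained: callers see a plain Int,
# but we get a runtime TypeError if the types unexpectedly don't match up.
function stable_wrapper(flag)
    x = unstable_pick(flag)
    return x::Int
end

stable_wrapper(true)    # returns 1
# stable_wrapper(false) # would throw a TypeError at runtime
```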

Maybe some of the trouble I had was due to wanting to nest Flatten iterators recursively - inference really doesn’t like that without some help :slight_smile:

It may just be that map is already heavily optimized with regard to its return type(s), and flatten can take full advantage of that. I’m honestly not sure where a combined approach would work/gain performance :person_shrugging: You may want to compare to a non-ideal case for flatten(map(...)) though, e.g. when your iterators are type-unstable. Maybe your implementation does better for them?
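A sketch of such a non-ideal test case (unstable_f is hypothetical, constructed purely to be type-unstable):

```julia
# The inner iterator's type depends on the value of x, so inference sees
# Union{UnitRange{Int}, Vector{Int}} instead of a single concrete type.
unstable_f(x) = isodd(x) ? (x:x) : [x, x]

sum(Iterators.flatten(Iterators.map(unstable_f, 1:100)))  # 7650
```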

That’s a great thought; maybe there’s something interesting regarding type inference. It’s really such an important topic. At first, though, I just want to understand what happens with known and simple types everywhere. Like in my examples, it’s just appending numeric ranges. I can imagine the compiler might be able to take this flatten(map(...)), open it up, and write something optimal. I want to look at the full code to see if I have any ideas… And right now I seem to be struggling with type inference, actually, but I don’t see any type instability. There are just these Core.PartialStruct(..., Any[...]) annotations, and I don’t know what they mean.
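One way I can check whether any instability actually leaks out to callers, sketched here with the built-in pipeline as a stand-in for the custom iterator:

```julia
# Ask inference what iterate returns for this iterator type.
it = Iterators.flatten(Iterators.map(x -> x:2x, [1, 2, 3]))
rts = Base.return_types(iterate, (typeof(it),))
only(rts)  # a small Union{Nothing, Tuple{...}} is fine; a bare Any would be a red flag
```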