Following the nice conversation at #44792, I'm trying to investigate whether `flatmap` functions might offer performance benefits. I suspect that in general `flatten(x)` cannot do better than `flatmap(identity, x)`, because `flatmap` solves a more general task. The exception would be a compiler sophisticated enough to actually treat a `flatten(map(...))` as a `flatmap`. That is possible, but if it is the case, I imagine having explicit `flatmap` functions somewhere would just make sense.
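To make the relationship concrete, here is a minimal sketch using a hypothetical `naive_flatmap` defined as the composition `flatten ∘ map` (it is an illustration, not an existing Base function). `flatten(x)` is then just the `f = identity` case.

```julia
# Hypothetical helper for illustration: flatmap defined as the composition
# flatten ∘ map, which is the baseline any dedicated flatmap has to match.
naive_flatmap(f, itr) = Iterators.flatten(Iterators.map(f, itr))

x = [[1, 2], Int[], [3]]

# flatten(x) is the special case f = identity.
a = collect(Iterators.flatten(x))
b = collect(naive_flatmap(identity, x))
println(a == b)  # true

# A non-trivial f: each n expands to the range 1:n (n = 0 gives an empty range).
println(collect(naive_flatmap(n -> 1:n, [2, 0, 3])))  # [1, 2, 1, 2, 3]
```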
I also suspect that most of the time this will come down to specializations for different data structures, and iterators might actually be the least interesting case. I also imagine Julia is already heavily optimized for the simple tasks I'm going to investigate, so it won't be easy to find an optimization opportunity. It's still probably a great learning opportunity! I definitely want to know whether `flatmap` functions should always be expected to behave exactly like a `flatten(map(...))`.
So I set out to write a `Flatmap` iterator, and this is what I have come up with so far. The original version, `Iterators.flatten(Iterators.map(...))`, still beats my code: mine takes roughly 15% to 55% longer in the benchmarks below. (I basically wrote a double while loop and “opened it up”… It's really annoying.)
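For reference, the double while loop before being “opened up” might look like the following sketch, specialized to summation. The name `flatmap_sum` and its explicit `init` argument are hypothetical, chosen for illustration. The `Flatmap` iterator has to split exactly this loop into resumable `iterate` calls, carrying `(st, myit, xst)` from one call to the next.

```julia
# Hypothetical helper: sum over all elements of f(x) for x in input,
# written as the nested iterate-protocol loop that Flatmap unrolls.
function flatmap_sum(f, input, init)
    acc = init
    outer = iterate(input)
    while outer !== nothing
        x, st = outer
        myit = f(x)                 # inner iterator for this x
        inner = iterate(myit)
        while inner !== nothing     # an empty inner iterator never enters this loop
            y, xst = inner
            acc += y
            inner = iterate(myit, xst)
        end
        outer = iterate(input, st)
    end
    return acc
end

println(flatmap_sum(n -> 1:n, [1, 2, 3], 0))  # 10
```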
One thing that might offer a benefit is a faster emptiness check right after `myit = myf(it)`. In theory, checking whether the inner iterator is empty could be cheaper than calling `iterate` on it. Essentially a `flatten(filter(!isempty, map(...)))`? Just one idea…
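A minimal sketch of that idea with Base iterators, assuming `isempty` is cheap for whatever `f` returns (it is O(1) for ranges like the ones benchmarked below). The name `filtered_flatmap` is hypothetical, and whether this is actually faster would need benchmarking.

```julia
# Hypothetical helper: drop empty inner iterators with `isempty` before
# `flatten` ever calls `iterate` on them.
filtered_flatmap(f, itr) =
    Iterators.flatten(Iterators.filter(!isempty, Iterators.map(f, itr)))

# Same shape as the benchmark's `myf`: (start, step, stop) -> range.
range_of((a, b, c)) = a:b:c
xs = [(1, 1, 3), (5, 1, 2), (2, 2, 6)]   # the middle range, 5:1:2, is empty

println(collect(filtered_flatmap(range_of, xs)))  # [1, 2, 3, 2, 4, 6]
```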
Anyway, I'm curious to hear your thoughts. What problems in my clunky code might be keeping it from matching the performance of the original? Or what optimizations might the original code benefit from that the new one doesn't?
The only hint I found in `@code_lowered` is that we get type annotations like `Core.PartialStruct(..., Any[...])`. Could that be an issue?
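A more direct way to look for type instabilities than reading lowered code is to ask inference for the return type of `iterate`, or to use `@code_warntype`. The sketch below does this for the Base version on small stand-in data (the data and the anonymous function are placeholders); the same calls should work on a `Flatmap` value.

```julia
using InteractiveUtils  # provides @code_warntype when not running in the REPL

itr = Iterators.flatten(Iterators.map(n -> 1:n, [2, 0, 3]))

# Inferred return type(s) of the first `iterate` call. A small concrete
# Union{Nothing, Tuple{...}} is expected; `Any` would signal an instability.
rts = Base.return_types(iterate, Tuple{typeof(itr)})
println(rts)

# Prints the typed IR; non-concrete types are highlighted.
@code_warntype iterate(itr)
```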
```julia
using BenchmarkTools
using ProfileView

# Lazily yields every element of `myf(x)` for each `x` in `input`,
# i.e. the same sequence as `Iterators.flatten(Iterators.map(myf, input))`.
struct Flatmap{FuncType, Eltype, S}
    myf::FuncType
    input::S
    # Called as `Flatmap{Eltype}(f, input)`: the caller fixes the element type,
    # and the function and input types are inferred.
    Flatmap{Eltype}(myf::FuncType, input::S) where {FuncType, Eltype, S} =
        new{typeof(myf), Eltype, Core.Typeof(input)}(myf, input)
end
# Without an explicit element type, fall back to `Any`.
Flatmap(f, input) = Flatmap{Any}(f, input)

Base.eltype(::Type{Flatmap{F, Eltype, S}}) where {F, Eltype, S} = Eltype
# Treat the default `Any` element type as unknown; otherwise report the declared one.
Base.IteratorEltype(::Type{<:Flatmap{<:Any, Any}}) = Base.EltypeUnknown()
Base.IteratorEltype(::Type{<:Flatmap}) = Base.HasEltype()
Base.IteratorSize(::Type{<:Flatmap}) = Base.SizeUnknown()

# First call: advance the outer iterator until some `myf(it)` is non-empty.
function Base.iterate(myfinput::Flatmap{F, Eltype, S}) where {F, Eltype, S}
    myf = myfinput.myf
    input = myfinput.input
    itst = iterate(input)
    if isnothing(itst)
        return nothing
    end
    itst = itst::Tuple{Any, Any}
    it, st = itst
    myit = myf(it)
    newoutxst = iterate(myit)
    # Skip over empty inner iterators.
    while isnothing(newoutxst)
        itst = iterate(input, st)
        if isnothing(itst)
            return nothing
        end
        it, st = itst
        myit = myf(it)
        newoutxst = iterate(myit)
    end
    newout, xst = newoutxst
    # State: (outer state, current inner iterator, inner state).
    return newout, (st, myit, xst)
end

# Later calls: continue the current inner iterator; once it is exhausted,
# move on to the next non-empty `myf(it)`.
function Base.iterate(myfinput::Flatmap{F, Eltype, S}, (st, myit, xst)) where {F, Eltype, S}
    @inbounds begin
        newoutxst = iterate(myit, xst)
        if !isnothing(newoutxst)
            newout, xst = newoutxst
            return newout, (st, myit, xst)
        end
        myf = myfinput.myf
        input = myfinput.input
        newoutxst = nothing
        while isnothing(newoutxst)
            itst = iterate(input, st)
            itst === nothing && return nothing
            itst = itst::Tuple{Any, Any}
            it, st = itst
            myit = myf(it)
            newoutxst = iterate(myit)
        end
        newout, xst = newoutxst
        return newout, (st, myit, xst)
    end
end

# Inner iterator: the range `a:b:c` built from each input tuple.
@inline myf((a, b, c)) = a:b:c

# Float64 ranges
input = map(1:1111) do _
    # a,b = sort(randn(2))
    a, b = randn(2)
    (a, rand()*3, b)
end
@btime sum(Iterators.flatten(Iterators.map(myf, input)))
@btime sum(Flatmap{Float64}(myf, input))

# Int64 ranges
input = map(1:1111) do _
    a, b = sort(rand(1:1111, 2))
    (a, rand(1:1111), b)
end
# @btime sum(y for x in input for y in myf(x))
@btime sum(Iterators.flatten(Iterators.map(myf, input)))
@btime sum(Flatmap{Int64}(myf, input))
```
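Results, in the same order as the `@btime` calls above:

```
# Float64 input
201.143 μs (3 allocations: 48 bytes)   # Iterators.flatten(Iterators.map(myf, input))
233.336 μs (2 allocations: 32 bytes)   # Flatmap{Float64}(myf, input)
# Int64 input
11.110 μs (3 allocations: 48 bytes)    # Iterators.flatten(Iterators.map(myf, input))
17.115 μs (2 allocations: 32 bytes)    # Flatmap{Int64}(myf, input)
```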
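As a side note on the commented-out benchmark line: Julia's nested generator syntax builds a `Flatten` wrapping a `Generator`, i.e. essentially the same structure as `Iterators.flatten(Iterators.map(...))`. A quick check on small stand-in data (`range_of` and `xs` are placeholders mirroring `myf` and `input`):

```julia
range_of((a, b, c)) = a:b:c
xs = [(1, 1, 3), (5, 1, 2), (2, 2, 6)]

# Nested generator syntax: for each x, iterate over every y in range_of(x).
gen = (y for x in xs for y in range_of(x))
println(gen isa Iterators.Flatten)                                         # true
println(sum(gen) == sum(Iterators.flatten(Iterators.map(range_of, xs))))  # true
```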