Failure to optimize due to type recursion

I find it relatively easy to end up with non-inferred code whenever the `mapfoldl` family is involved, because its implementation has so many layers. A simple example:

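```julia
using InteractiveUtils  # provides @code_warntype outside the REPL

# Count the Real leaves of an arbitrarily nested structure
_length(c::Real) = 1
_length(c::Tuple) = mapfoldl(_length, +, c; init = 0)
_length(c::NamedTuple) = _length(values(c))
_length(c::AbstractArray) = sum(_length, c; init = 0)

x = (a = 1, b = [2.0, 3.0], d = (4, 5f0))
@code_warntype _length(x)

using JET
@report_opt _length(x)
```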

My understanding is that #48059 is meant to address this, but it is currently dormant.

Am I abusing the inference engine with simple code like the above? Should I do something differently, e.g. use generated functions for the tuples?

Just to clarify: I am looking for general advice; the above is only an MWE. I routinely write code that fails inference because of deeply nested `mapXXX` calls, and I am wondering whether I should give that up or do it differently.
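For comparison, here is a sketch of one non-generated alternative: peel the tuple one element at a time with `Base.tail`, the same style Base uses for many small-tuple methods (for example `map` on tuples). Treat this as an assumption about a reasonable workaround, not a measured result; whether inference fully succeeds may depend on the Julia version, so it is worth checking with `@code_warntype` as above.

```julia
# Sketch (assumption): recurse on strictly shorter tuple types via Base.tail
# instead of going through mapfoldl.
_length(c::Real) = 1
_length(::Tuple{}) = 0                                          # base case: empty tuple
_length(c::Tuple) = _length(first(c)) + _length(Base.tail(c))  # head + rest
_length(c::NamedTuple) = _length(values(c))
_length(c::AbstractArray) = sum(_length, c; init = 0)

x = (a = 1, b = [2.0, 3.0], d = (4, 5f0))
_length(x)  # returns 5 (1 + 2 + 2)
```

Because each recursive call sees a shorter tuple type, the recursion terminates at the `Tuple{}` method, which is more specific than `Tuple` and so is selected without ambiguity.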


Just as an example, a heavy-handed fix in the above case is

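```julia
# Unroll the sum over the tuple's fields at compile time;
# `init = 0` keeps the empty tuple `()` from throwing in the generator.
@generated function _length(c::T) where {T<:Tuple}
    mapfoldl(i -> :(_length(c[$i])), (a, b) -> :($a + $b), 1:fieldcount(T); init = 0)
end
```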
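To see what the generated function actually does, it helps to look at the expression its generator returns. The helper below, `unrolled_length_expr`, is hypothetical and only reproduces that generator body for a given field count `n`. The result is a fully unrolled, left-nested chain of `_length(c[i])` calls, so inference sees one concrete call per field instead of the layered `mapfoldl` machinery.

```julia
# Hypothetical helper: builds the same expression the generated method returns
# for a tuple with `n` fields. With `init = 0`, n == 0 yields the literal 0.
unrolled_length_expr(n) =
    mapfoldl(i -> :(_length(c[$i])), (a, b) -> :($a + $b), 1:n; init = 0)

println(unrolled_length_expr(0))  # the literal 0
println(unrolled_length_expr(3))  # left-nested sum of _length(c[1]) .. _length(c[3])
```

The leading `0 +` introduced by `init` is a constant the compiler can fold away, so it should cost nothing at run time.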

Yeah, recursive inference is a major issue for writing performant, composable code, and it comes up a lot!
Even in very simple code:

```julia
julia> f(X) = sum(X) do x
           sum(x) do y
               y
           end
       end
f (generic function with 1 method)

julia> @code_warntype f([[1]])
MethodInstance for f(::Vector{Vector{Int64}})
  from f(X) @ Main REPL[1]:1
Arguments
  #self#::Core.Const(f)
  X::Vector{Vector{Int64}}
Locals
  #6::var"#6#8"
Body::Any
1 ─      (#6 = %new(Main.:(var"#6#8")))
│   %2 = #6::Core.Const(var"#6#8"())
│   %3 = Main.sum(%2, X)::Any
└──      return %3
```

This is the shortest example I’m aware of.
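For this particular nested `sum`, a plain loop avoids the nested `mapreduce` call chain entirely. The sketch below is an assumption rather than a general fix: `f_loop` is a hypothetical name, it is written for vectors of vectors of numbers, and unlike `sum` it does not widen small integer types such as `Int8`.

```julia
# Sketch (assumption): the same nested sum as `f`, written as an explicit loop.
function f_loop(X)
    s = zero(eltype(eltype(X)))  # accumulator typed from the inner element type
    for x in X, y in x           # outer vectors, then their elements
        s += y
    end
    return s
end

f_loop([[1]])  # returns 1

# Internal reflection helper; compare with `@code_warntype f([[1]])` above
Base.return_types(f_loop, (Vector{Vector{Int}},))
```

For `f_loop` the reported return type should be `Int64`, in contrast to the `Body::Any` shown for `f`.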


@aplavin This isn’t about recursion. This is JuliaLang/julia issue #15276, “performance of captured variables in closures”.

No variables are captured in these examples though.
To make it even clearer, try

```julia
f(X) = sum(Base.Fix1(sum, identity), X)
```

It gives the same `::Any` inference result.
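`Base.Fix1(f, x)` is a small callable struct: applied to `y`, it calls `f(x, y)`. So `Base.Fix1(sum, identity)` behaves like `y -> sum(identity, y)`, but it is not a closure over any local variable. A quick check of what the one-liner computes:

```julia
# Base.Fix1(sum, identity)(y) calls sum(identity, y); no local variables are captured
inner = Base.Fix1(sum, identity)
inner([2, 3])     # returns 5, same as sum(identity, [2, 3])

f(X) = sum(Base.Fix1(sum, identity), X)
f([[1], [2, 3]])  # returns 6
```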
