Why does findfirst(==(T), ...) on a tuple of types only constant fold for the first element?

Consider:

position(::T) where T = findfirst(==(T), (Int, Float64, Char))

You might expect this to either constant fold for every input or for none.

But it actually constant folds for Int inputs, and not for Float64 or Char.
And if I reorder the tuple, it constant folds for whichever type I put first.

julia> @code_typed position(10)
CodeInfo(
1 ─     return 1
) => Int64

julia> @code_typed position(2.5)
CodeInfo(
1 ──       goto #12 if not true
2 ┄─ %2  = φ (#1 => 1, #11 => %24)::Int64
│    %3  = φ (#1 => Int64, #11 => %25)::DataType
│    %4  = φ (#1 => 1, #11 => %26)::Int64
│    %5  = $(Expr(:foreigncall, :(:jl_types_equal), Int32, svec(Any, Any), 0, :(:ccall), :(%3), Float64))::Int32
│    %6  = Core.sext_int(Core.Int64, %5)::Int64
│    %7  = (%6 === 0)::Bool
│    %8  = Base.not_int(%7)::Bool
└───       goto #4 if not %8
3 ──       goto #13
4 ── %11 = (%4 === 3)::Bool
└───       goto #6 if not %11
5 ──       goto #7
6 ── %14 = Base.add_int(%4, 1)::Int64
└───       goto #7
7 ┄─ %16 = φ (#5 => true, #6 => false)::Bool
│    %17 = φ (#6 => %14)::Int64
│    %18 = φ (#6 => %14)::Int64
│    %19 = φ (#5 => true)::Bool
└───       goto #9 if not %16
8 ──       goto #10
9 ── %22 = Base.getfield((Int64, Float64, Char), %17, true)::DataType
└───       goto #10
10 ┄ %24 = φ (#9 => %17)::Int64
│    %25 = φ (#9 => %22)::DataType
│    %26 = φ (#9 => %18)::Int64
│    %27 = φ (#8 => %19, #9 => false)::Bool
│    %28 = Base.not_int(%27)::Bool
└───       goto #12 if not %28
11 ─       goto #2
12 ┄ %31 = Base.nothing::Core.Const(nothing)
└───       goto #13
13 ┄ %33 = φ (#3 => %2, #12 => %31)::Union{Nothing, Int64}
└───       return %33
) => Union{Nothing, Int64}
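A minimal script to reproduce the observation, assuming `Base.return_types` (which reports what inference concludes for a given argument-type signature). The function is renamed `position_loop` only to avoid clashing with the exported `Base.position` when run as a script. The exact types printed depend on the Julia version, so treat this as a way to check, not as a fixed result.

```julia
# Same definition as in the question, renamed to avoid clashing with Base.position.
position_loop(::T) where {T} = findfirst(==(T), (Int, Float64, Char))

# Print the inferred return type for each kind of input.
# A folded call infers a concrete Int64; an unfolded one infers
# Union{Nothing, Int64}.
for x in (10, 2.5, 'a')
    rt = only(Base.return_types(position_loop, (typeof(x),)))
    println(rpad(string(typeof(x)), 8), " => ", rt)
end
```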

Two things help explain this behavior:

  • constant propagation happens at inference time, not during optimization
    (although things may look easier to constant-fold after inlining)
  • a loop is lowered into two parts: 1) the head (first) iteration, and 2) the succeeding iterations
    (and loops often prevent constant folding)
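The second point can be made concrete by writing out by hand what a `for (i, a) in pairs(A)` loop becomes under Julia's iteration protocol: one `iterate(itr)` call for the head iteration, then `iterate(itr, state)` for every succeeding one. This is a sketch: `findfirst_lowered` is a hypothetical name, and it assumes (as the Cthulhu output suggests) that Base's generic `findfirst` is such a loop.

```julia
# Hypothetical hand-lowered version of
#     for (i, a) in pairs(A)
#         testf(a) && return i
#     end
#     return nothing
function findfirst_lowered(testf, A)
    itr = pairs(A)
    next = iterate(itr)              # 1) head iteration: iterate(itr)
    while next !== nothing
        (pair, state) = next
        (i, a) = pair                # each element is the Pair `i => a`
        testf(a) && return i
        next = iterate(itr, state)   # 2) succeeding iterations: iterate(itr, state)
    end
    return nothing
end

findfirst_lowered(==(Float64), (Int, Float64, Char))  # 2
```

Only the first call to `iterate` sees the constant initial state; every later iteration flows back through the shared loop header, which is where inference has to merge states.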

If you descend into findfirst, it might be clearer why constant folding succeeds only for the first element:

julia> position(::T) where T = findfirst(==(T), (Int, Float64, Char))
position (generic function with 1 method)

julia> using Cthulhu

julia> @descend optimize=false position(42.)
position(::T) where T in Main at REPL[21]:1
Variables
  #self#::Core.Const(position)
  _::Float64

│ ─ %-1  = invoke position(::Float64)::Union{Nothing, Int64}
    @ REPL[21]:1 within `position`
1 ─ %1 = (==)($(Expr(:static_parameter, 1)))::Core.Const(Base.Fix2{typeof(==), DataType}(==, Float64))
│   %2 = Core.tuple(Main.Int, Main.Float64, Main.Char)::Core.Const((Int64, Float64, Char))
│   %3 = Main.findfirst(%1, %2)::Union{Nothing, Int64}
└──      return %3
Select a call to descend into or ↩ to ascend. [q]uit. [b]ookmark.
Toggles: [o]ptimize, [w]arn, [h]ide type-stable statements, [d]ebuginfo, [r]emarks, [i]nlining costs, [s]yntax highlight for Source/LLVM/Native.
Show: [S]ource code, [A]ST, [T]yped code, [L]LVM IR, [N]ative code
Actions: [E]dit source code, [R]evise and redisplay
Advanced: dump [P]arams cache.
   %1  = ==(::Type{Float64})::Core.Const(Base.Fix2{typeof(==), DataType}(==, Float64))
 • %3  = < constprop > findfirst(::Core.Const(Base.Fix2{typeof(==), DataType}(==, Float64)),::Core.Const((Int64, Float64, Char)))::Union{Nothing, Int64}
   ↩
findfirst(testf::Function, A) in Base at array.jl:2058
Variables
  #self#::Core.Const(findfirst)
  testf::Core.Const(Base.Fix2{typeof(==), DataType}(==, Float64))
  A::Core.Const((Int64, Float64, Char))
  @_4::Union{Nothing, Tuple{Pair{Int64, DataType}, Int64}}
  @_5::Int64
  a::DataType
  i::Int64

│ ─ %-1  = invoke findfirst(::Fix2{…},::Tuple{…})::Union{Nothing, Int64}
    @ array.jl:2059 within `findfirst`
1 ─ %1  = Base.pairs(A)::Core.Const(Base.Pairs{Int64, DataType, Base.OneTo{Int64}, Tuple{DataType, DataType, DataType}}(1 => Int64, 2 => Float64, 3 => Char))
│         (@_4 = Base.iterate(%1))::Core.Const((1 => Int64, 1))
│   %3  = (@_4::Core.Const((1 => Int64, 1)) === nothing)::Core.Const(false)
│   %4  = Base.not_int(%3)::Core.Const(true)
└──       goto #6 if not %4
2 ┄ %6  = @_4::Tuple{Pair{Int64, DataType}, Int64}
│   %7  = Core.getfield(%6, 1)::Pair{Int64, DataType}
│   %8  = Base.indexed_iterate(%7, 1)::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
│         (i = Core.getfield(%8, 1))::Int64
│         (@_5 = Core.getfield(%8, 2))::Core.Const(2)
│   %11 = Base.indexed_iterate(%7, 2, @_5::Core.Const(2))::Core.PartialStruct(Tuple{DataType, Int64}, Any[DataType, Core.Const(3)])
│         (a = Core.getfield(%11, 1))::DataType
│   %13 = Core.getfield(%6, 2)::Int64
│   @ array.jl:2060 within `findfirst`
│   %14 = (testf)(a)::Bool
└──       goto #4 if not %14
3 ─       return i
    @ array.jl:2061 within `findfirst`
4 ─       (@_4 = Base.iterate(%1, %13))::Union{Nothing, Tuple{Pair{Int64, DataType}, Int64}}
│   %18 = (@_4 === nothing)::Bool
│   %19 = Base.not_int(%18)::Bool
└──       goto #6 if not %19
5 ─       goto #2
    @ array.jl:2062 within `findfirst`
6 ┄       return Base.nothing
Select a call to descend into or ↩ to ascend. [q]uit. [b]ookmark.
Toggles: [o]ptimize, [w]arn, [h]ide type-stable statements, [d]ebuginfo, [r]emarks, [i]nlining costs, [s]yntax highlight for Source/LLVM/Native.
Show: [S]ource code, [A]ST, [T]yped code, [L]LVM IR, [N]ative code
Actions: [E]dit source code, [R]evise and redisplay
Advanced: dump [P]arams cache.
 • %1  = < constprop > pairs(::Core.Const((Int64, Float64, Char)))::Core.Const(Base.Pairs{Int64, DataType, Base.OneTo{Int64}, Tuple{DataType, DataType, DataType}}(1 => Int64, 2 => Float64, 3 => Char))
   %2  = < constprop > iterate(::Core.Const(Base.Pairs{Int64, DataType, Base.OneTo{Int64}, Tuple{DataType, DataType, DataType}}(1 => Int64, 2 => Float64, 3 => Char)),::Tuple{})::…
   %8  = < constprop > indexed_iterate(::Pair{Int64, DataType},::Core.Const(1))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(2)])
   %11  = < constprop > indexed_iterate(::Pair{Int64, DataType},::Core.Const(2),::Core.Const(2))::Core.PartialStruct(Tuple{DataType, Int64}, Any[DataType, Core.Const(3)])
   %14  = < constprop > Fix2(::DataType)::Bool
   %17  = < constprop > iterate(::Core.Const(Base.Pairs{Int64, DataType, Base.OneTo{Int64}, Tuple{DataType, DataType, DataType}}(1 => Int64, 2 => Float64, 3 => Char)),::Tuple{Int64})::…
   ↩

We can confirm that the first iteration of findfirst is separated from the others, so the first one has a chance to be constant-folded, but the others don’t.
In particular, when constants are propagated into findfirst, if %14 = (testf)(a) can be constant-folded at the first iteration (basic block #2), then inference doesn’t need to account for the succeeding iterations (which are handled in basic block #4), and everything is folded. But if the first iteration does not find a match, our inference algorithm “mixes up” the first iteration with all the succeeding iterations, and then there is no chance for folding. Technically this process is called “widening”: to ensure that inference terminates, the previous iteration index Const(1) and the new index Const(2) are merged into the abstract index Int.
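The widening argument can be illustrated with a toy model. Everything below is hypothetical: `ConstIndex`, `WideIndex`, `join_index`, and `toy_infer_findfirst` are made-up names loosely modelled on `Core.Const(i)` and the widened type `Int`. This is not how Julia's compiler is implemented, only a sketch of why a miss on the first iteration loses the constant index.

```julia
# Toy abstract values for the loop index (not Julia's real lattice types).
abstract type AbsIndex end
struct ConstIndex <: AbsIndex   # like Core.Const(i): the index value is known
    val::Int
end
struct WideIndex <: AbsIndex end  # like Int: only "some Int" is known

# Toy join at the loop header: two different constants widen to "some Int".
join_index(a::ConstIndex, b::ConstIndex) = a.val == b.val ? a : WideIndex()
join_index(::AbsIndex, ::AbsIndex) = WideIndex()

# Toy "inference" of findfirst(==(T), types) for a constant tuple `types`.
# A ConstIndex result means the call folds to that index; WideIndex means the
# result is only known to be Union{Nothing, Int}.
function toy_infer_findfirst(T, types::Tuple)
    length(types) >= 2 || throw(ArgumentError("toy model assumes at least two elements"))
    # Head iteration: index 1 is a constant, so `types[1] == T` can be
    # evaluated; on a match, later iterations are never analyzed.
    types[1] == T && return ConstIndex(1)
    # No match: the succeeding iteration reaches the loop header with index 2,
    # and the header state becomes the join of index 1 and index 2.
    # From a widened index, `types[i] == T` is only known to be a Bool.
    return join_index(ConstIndex(1), ConstIndex(2))
end

toy_infer_findfirst(Int, (Int, Float64, Char))      # ConstIndex(1)
toy_infer_findfirst(Float64, (Int, Float64, Char))  # WideIndex()
```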


Should we define a recursive findfirst implementation for tuples?

julia> _findfirst(f, i, ::Tuple{}) = nothing
_findfirst (generic function with 1 method)

julia> _findfirst(f, i, x) = f(first(x)) ? i : _findfirst(f, i+1, Base.tail(x))
_findfirst (generic function with 2 methods)

julia> myfindfirst(f::F, x::Tuple) where {F} = _findfirst(f, 1, x)
myfindfirst (generic function with 1 method)

julia> position(::T) where {T} = myfindfirst(==(T), (Int, Float64, Char))
position (generic function with 1 method)

julia> @code_typed position(10)
CodeInfo(
1 ─     return 1
) => Int64

julia> @code_typed position(2.4)
CodeInfo(
1 ─     return 2
) => Int64

julia> @code_typed position('a')
CodeInfo(
1 ─     return 3
) => Int64

julia> @code_typed position(1f0)
CodeInfo(
1 ─     return nothing
) => Nothing
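For a short tuple of known length, another option is to unroll the search by hand so that there is no loop for inference to widen. A sketch, assuming `Base.Cartesian.@nexprs`, which expands `i -> expr` into N copies of `expr` with `i` replaced by the literals 1 through N. `TYPES` and `unrolled_position` are illustrative names, the count `3` must match `length(TYPES)`, and whether each comparison actually folds may depend on the Julia version, so it is worth checking with `@code_typed`.

```julia
const TYPES = (Int, Float64, Char)

function unrolled_position(::T) where {T}
    # Expands to three straight-line checks:
    #   TYPES[1] == T && return 1
    #   TYPES[2] == T && return 2
    #   TYPES[3] == T && return 3
    Base.Cartesian.@nexprs 3 i -> (TYPES[i] == T && return i)
    return nothing
end
```

Compared with the recursive `_findfirst`, this trades generality (the length is fixed at macro-expansion time) for code that is plainly loop-free.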