Inference with Tuple{Union} vs. Union{Tuple}

While working on
https://github.com/JuliaLang/julia/pull/20154
I hit some peculiar inference behavior that I’d like some input on.

Let’s consider the following, which simplifies the situation a bit:

int_or_range() = rand() < 0.5 ? 1 : 1:1

function foo()
    x = ones(3,3)
    i = int_or_range()
    is = (i,i)
    x[is...]
end

Obviously, this will return either a Float64 (if i::Int) or a Matrix{Float64} (if i::UnitRange{Int}). Can inference figure this out? Let’s see:

julia> @code_warntype foo()
Variables:
  #self#::#foo
  x::Array{Float64,2}
  i::Union{Int64,UnitRange{Int64}}
  is::Tuple{Union{Int64,UnitRange{Int64}},Union{Int64,UnitRange{Int64}}}
  #temp#@_5::Core.MethodInstance
  #temp#@_6::Union{Array{Float64,1},Array{Float64,2},Float64}

Body:
  begin 
      $(Expr(:inbounds, false))
      # meta: location array.jl ones 227
      # meta: location array.jl ones 225
      # meta: location array.jl ones 224
      SSAValue(4) = 3
      SSAValue(5) = 3
      # meta: pop location
      # meta: pop location
      # meta: pop location
      $(Expr(:inbounds, :pop))
      x::Array{Float64,2} = $(Expr(:invoke, MethodInstance for fill!(::Array{Float64,2}, ::Float64), :(Base.fill!), :($(Expr(:foreigncall, :(:jl_alloc_array_2d), Array{Float64,2}, svec(Any,Int64,Int64), Array{Float64,2}, 0, SSAValue(4), 0, SSAValue(5), 0))), :((Base.sitofp)(Float64,1)::Float64))) # line 3:
      i::Union{Int64,UnitRange{Int64}} = $(Expr(:invoke, MethodInstance for int_or_range(), :(Main.int_or_range))) # line 4:
      SSAValue(6) = i::Union{Int64,UnitRange{Int64}}
      SSAValue(7) = i::Union{Int64,UnitRange{Int64}} # line 5:
      unless (SSAValue(7) isa Int64)::Any goto 28
      unless (SSAValue(6) isa Int64)::Any goto 22
      #temp#@_5::Core.MethodInstance = MethodInstance for getindex(::Array{Float64,2}, ::Int64, ::Int64)
      goto 44
      22: 
      unless (SSAValue(6) isa UnitRange{Int64})::Any goto 26
      #temp#@_5::Core.MethodInstance = MethodInstance for getindex(::Array{Float64,2}, ::UnitRange{Int64}, ::Int64)
      goto 44
      26: 
      goto 41
      28: 
      unless (SSAValue(7) isa UnitRange{Int64})::Any goto 39
      unless (SSAValue(6) isa Int64)::Any goto 33
      #temp#@_5::Core.MethodInstance = MethodInstance for getindex(::Array{Float64,2}, ::Int64, ::UnitRange{Int64})
      goto 44
      33: 
      unless (SSAValue(6) isa UnitRange{Int64})::Any goto 37
      #temp#@_5::Core.MethodInstance = MethodInstance for getindex(::Array{Float64,2}, ::UnitRange{Int64}, ::UnitRange{Int64})
      goto 44
      37: 
      goto 41
      39: 
      goto 41
      41: 
      #temp#@_6::Union{Array{Float64,1},Array{Float64,2},Float64} = (Main.getindex)(x::Array{Float64,2},SSAValue(6),SSAValue(7))::Union{Array{Float64,1},Array{Float64,2},Float64}
      goto 46
      44: 
      #temp#@_6::Union{Array{Float64,1},Array{Float64,2},Float64} = $(Expr(:invoke, :(#temp#@_5), :(Main.getindex), :(x), SSAValue(6), SSAValue(7)))
      46: 
      return #temp#@_6::Union{Array{Float64,1},Array{Float64,2},Float64}
  end::Union{Array{Float64,1},Array{Float64,2},Float64}

Unfortunately, inference loses the information that typeof(is[1])==typeof(is[2]). But it then actually considers all four cases to come up with a return type that is pretty good!
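As a side note, if only the inferred return type is of interest, it can be queried directly instead of reading the full @code_warntype output. The following is a minimal sketch using the unexported reflection helper Base.return_types; the exact type it reports depends on the Julia version, so no particular output is claimed here.

```julia
int_or_range() = rand() < 0.5 ? 1 : 1:1

function foo()
    x = ones(3, 3)
    i = int_or_range()
    is = (i, i)
    x[is...]
end

# Base.return_types gives one inferred return type per matching method.
# foo has a single zero-argument method, so the vector has one entry.
rts = Base.return_types(foo, ())
println(rts)
```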

But maybe we can help inference by restricting the type of is to the two possible cases like so:

function bar()
    x = ones(3,3)
    i = int_or_range()
    is = (i,i)::Union{Tuple{Int,Int},Tuple{UnitRange{Int},UnitRange{Int}}}
    x[is...]
end

Let’s see…

julia> @code_warntype bar()
Variables:
  #self#::#bar
  x::Array{Float64,2}
  i::Union{Int64,UnitRange{Int64}}
  is::Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}}

Body:
  begin 
      $(Expr(:inbounds, false))
      # meta: location array.jl ones 227
      # meta: location array.jl ones 225
      # meta: location array.jl ones 224
      SSAValue(4) = 3
      SSAValue(5) = 3
      # meta: pop location
      # meta: pop location
      # meta: pop location
      $(Expr(:inbounds, :pop))
      x::Array{Float64,2} = $(Expr(:invoke, MethodInstance for fill!(::Array{Float64,2}, ::Float64), :(Base.fill!), :($(Expr(:foreigncall, :(:jl_alloc_array_2d), Array{Float64,2}, svec(Any,Int64,Int64), Array{Float64,2}, 0, SSAValue(4), 0, SSAValue(5), 0))), :((Base.sitofp)(Float64,1)::Float64))) # line 3:
      i::Union{Int64,UnitRange{Int64}} = $(Expr(:invoke, MethodInstance for int_or_range(), :(Main.int_or_range))) # line 4:
      is::Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}} = (Core.typeassert)((Core.tuple)(i::Union{Int64,UnitRange{Int64}},i::Union{Int64,UnitRange{Int64}})::Tuple{Union{Int64,UnitRange{Int64}},Union{Int64,UnitRange{Int64}}},Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}})::Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}} # line 5:
      return (Core._apply)(Main.getindex,(Core.tuple)(x::Array{Float64,2})::Tuple{Array{Float64,2}},is::Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}})::Any
  end::Any

Now, instead of four cases, considering two would suffice. Unfortunately, that does not happen. Note the counterintuitive behavior: although the bound on the type of is is tighter, the bound on the type of getindex(x, is...) is looser. I don’t claim to understand anything about how inference works, but this surprised me. Is there some (more or less) easy tweaking of inference that could improve this case? Or is there another way I could write that code to get inference to find the return type Union{Array{Float64,2},Float64}? (Note that in the actual use case, the element types of the tuple are not all the same, but the type of the length-n tuple is still one of n+1 possibilities (Union{Tuple...}), while inference can only limit this to 2^n possibilities (Tuple{Union...}).)
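To make the difference between the two types concrete, here is a small check for the n = 2 example above. The alias names are made up for illustration only. The union of tuples is a strict subtype of the tuple of unions, because the latter also admits the mixed combinations.

```julia
# Alias names are hypothetical, chosen for this illustration only.
const IntOrRange = Union{Int,UnitRange{Int}}
const TupleOfUnions = Tuple{IntOrRange,IntOrRange}
const UnionOfTuples = Union{Tuple{Int,Int},Tuple{UnitRange{Int},UnitRange{Int}}}

println(UnionOfTuples <: TupleOfUnions)   # true: every member tuple fits
println(TupleOfUnions <: UnionOfTuples)   # false: mixed tuples are not covered

# A mixed tuple belongs to the wider type only.
println((1, 1:1) isa TupleOfUnions)       # true
println((1, 1:1) isa UnionOfTuples)       # false
```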

Any feedback to explain the situation and/or improve the code appreciated!

Is there some (more or less) easy tweaking of inference that could improve this case?

If you’re interested in learning more, there are certainly some improvements that could be made to handle this case much better. Right now, the union splitting only happens for one specific case (https://github.com/JuliaLang/julia/blob/master/base/reflection.jl#L408), and doesn’t happen for the tuple() constructor case (https://github.com/JuliaLang/julia/blob/3fbe8305f90187d885fea7046d5a9419a6bef8f8/base/inference.jl#L948-L954) or the getfield case (getfield_tfunc).

Thus, one valuable and straightforward project that you could probably cut your teeth on would be to abstract that union-splitting logic and move it to a higher level in the inference pipeline, such that all of the inference primitives (“tfuncs”) can generically make use of this divide-and-conquer approach.

You actually made me believe that inference might be approachable even by us mere mortals. So I’ve spent some time trying to get a grip on things. Unfortunately, your breadcrumbs didn’t really lead me anywhere. BUT I think I have nailed down at least part of the problem, and that is abstract_apply / precise_container_types.

The getindex is invoked with _apply. But for a Union argument type, precise_container_types returns nothing, and abstract_apply then infers the call like getindex(::Any...), which doesn’t get us very far. Patching precise_container_types to effectively treat a Union of Tuples of the same length like a Tuple of Unions (which will, in general, be a broader type, of course) indeed makes the return type of bar be inferred the same as that of foo. The typed code is different, as no union splitting happens, but for now I’m happy to have the return type improved. (Actually I’m pretty happy to have modified inference without making anything explode.)

I’ll try to improve this further, helpful input still welcome.

That sounds like a good summary. It sounds like you’ve found the right spot in the code too. For the getindex(precise_container_types) case, the wider type will give the same answer. For the more general case, I think you could split the union at a higher level and express this as:

abstract_apply(t::Union) =
    tmerge( abstract_apply(t.a), abstract_apply(t.b) )
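To make the idea concrete outside of the compiler, here is a toy sketch of the same divide-and-conquer pattern. It is not the inference code: split_union and toy_apply_rettype are hypothetical helpers, the unexported Base.return_types stands in for abstract_apply, and a plain Union{...} stands in for tmerge.

```julia
# Hypothetical helper: flatten a (possibly nested) Union into its members.
split_union(T) = T isa Union ? vcat(split_union(T.a), split_union(T.b)) : Any[T]

# Hypothetical helper: "apply" f to leading argument types plus a splatted
# tuple type T. Each Union member is handled separately and the results are
# merged, mirroring tmerge(abstract_apply(t.a), abstract_apply(t.b)).
function toy_apply_rettype(f, prefix::Tuple, T)
    result = Union{}
    for member in split_union(T)
        # Splat the member tuple's element types into the argument list.
        argtypes = Tuple{prefix..., member.parameters...}
        for rt in Base.return_types(f, argtypes)
            result = Union{result, rt}
        end
    end
    return result
end

IsType = Union{Tuple{Int,Int},Tuple{UnitRange{Int},UnitRange{Int}}}
println(split_union(IsType))
println(toy_apply_rettype(getindex, (Matrix{Float64},), IsType))
```

Only the two member tuples are visited here, not the four combinations a Tuple of Unions would require.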

Glad to know I could be of some encouragement in your learning. Starting with specific changes like this one is how I learned my way around that code also.

If you’re interested in how our inference code works at a higher level, I made a short animation showing how it collects type information at each step: http://juliacomputing.com/blog/2016/04/04/inference-convergence.html#basic-algorithm. It’s not necessary to understand that part of the algorithm to make significant and beneficial contributions to the base cases, so don’t be overwhelmed by it. I didn’t write much about the tfunc primitives, since my primary goal was to write about how it computes the final answer.
