# Inference with Tuple{Union} vs. Union{Tuple}

#1

While working on

I hit some peculiar behavior of inference I’d like to get some input about.

Let’s consider the following, which simplifies the situation a bit:

``````
int_or_range() = rand() < 0.5 ? 1 : 1:1

function foo()
    x = ones(3,3)
    i = int_or_range()
    is = (i,i)
    x[is...]
end
``````

Obviously, this will return either a `Float64` (if `i::Int`) or a `Matrix{Float64}` (if `i::UnitRange{Int}`). Can inference figure this out? Let’s see:

``````
julia> @code_warntype foo()
Variables:
#self#::#foo
x::Array{Float64,2}
i::Union{Int64,UnitRange{Int64}}
is::Tuple{Union{Int64,UnitRange{Int64}},Union{Int64,UnitRange{Int64}}}
#temp#@_5::Core.MethodInstance
#temp#@_6::Union{Array{Float64,1},Array{Float64,2},Float64}

Body:
begin
$(Expr(:inbounds, false))
# meta: location array.jl ones 227
# meta: location array.jl ones 225
# meta: location array.jl ones 224
SSAValue(4) = 3
SSAValue(5) = 3
# meta: pop location
# meta: pop location
# meta: pop location
$(Expr(:inbounds, :pop))
x::Array{Float64,2} = $(Expr(:invoke, MethodInstance for fill!(::Array{Float64,2}, ::Float64), :(Base.fill!), :($(Expr(:foreigncall, :(:jl_alloc_array_2d), Array{Float64,2}, svec(Any,Int64,Int64), Array{Float64,2}, 0, SSAValue(4), 0, SSAValue(5), 0))), :((Base.sitofp)(Float64,1)::Float64))) # line 3:
i::Union{Int64,UnitRange{Int64}} = $(Expr(:invoke, MethodInstance for int_or_range(), :(Main.int_or_range))) # line 4:
SSAValue(6) = i::Union{Int64,UnitRange{Int64}}
SSAValue(7) = i::Union{Int64,UnitRange{Int64}} # line 5:
unless (SSAValue(7) isa Int64)::Any goto 28
unless (SSAValue(6) isa Int64)::Any goto 22
#temp#@_5::Core.MethodInstance = MethodInstance for getindex(::Array{Float64,2}, ::Int64, ::Int64)
goto 44
22:
unless (SSAValue(6) isa UnitRange{Int64})::Any goto 26
#temp#@_5::Core.MethodInstance = MethodInstance for getindex(::Array{Float64,2}, ::UnitRange{Int64}, ::Int64)
goto 44
26:
goto 41
28:
unless (SSAValue(7) isa UnitRange{Int64})::Any goto 39
unless (SSAValue(6) isa Int64)::Any goto 33
#temp#@_5::Core.MethodInstance = MethodInstance for getindex(::Array{Float64,2}, ::Int64, ::UnitRange{Int64})
goto 44
33:
unless (SSAValue(6) isa UnitRange{Int64})::Any goto 37
#temp#@_5::Core.MethodInstance = MethodInstance for getindex(::Array{Float64,2}, ::UnitRange{Int64}, ::UnitRange{Int64})
goto 44
37:
goto 41
39:
goto 41
41:
#temp#@_6::Union{Array{Float64,1},Array{Float64,2},Float64} = (Main.getindex)(x::Array{Float64,2},SSAValue(6),SSAValue(7))::Union{Array{Float64,1},Array{Float64,2},Float64}
goto 46
44:
#temp#@_6::Union{Array{Float64,1},Array{Float64,2},Float64} = $(Expr(:invoke, :(#temp#@_5), :(Main.getindex), :(x), SSAValue(6), SSAValue(7)))
46:
return #temp#@_6::Union{Array{Float64,1},Array{Float64,2},Float64}
end::Union{Array{Float64,1},Array{Float64,2},Float64}
``````

Unfortunately, inference loses the information that `typeof(is[1]) == typeof(is[2])`. But it then actually considers all four combinations and comes up with a return type that is pretty good!
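To see where the extra `Array{Float64,1}` in that union comes from, the index combinations can be enumerated at runtime. This is only a small sketch: the names `candidates`, `all_combos`, and `reachable` are made up for illustration, and it inspects runtime types rather than anything inference does internally.

```julia
x = ones(3, 3)
candidates = (1, 1:1)  # the two values int_or_range() can return

# All 2^2 index combinations permitted by Tuple{Union{...},Union{...}}
all_combos = Set(typeof(x[a, b]) for a in candidates for b in candidates)

# Only the combinations where both indices have the same type, as in `(i, i)`
reachable = Set(typeof(x[a, a]) for a in candidates)

# Types that appear only in the mixed (Int, UnitRange) cases
only_mixed = setdiff(all_combos, reachable)
```

Here `only_mixed` contains just `Array{Float64,1}`: it arises from the mixed `Int`/`UnitRange{Int}` cases, which `foo` can never actually hit.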

But maybe we can help inference by restricting the type of `is` to the two possible cases like so:

``````
function bar()
    x = ones(3,3)
    i = int_or_range()
    is = (i,i)::Union{Tuple{Int,Int},Tuple{UnitRange{Int},UnitRange{Int}}}
    x[is...]
end
``````

Let’s see…

``````
julia> @code_warntype bar()
Variables:
#self#::#bar
x::Array{Float64,2}
i::Union{Int64,UnitRange{Int64}}
is::Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}}

Body:
begin
$(Expr(:inbounds, false))
# meta: location array.jl ones 227
# meta: location array.jl ones 225
# meta: location array.jl ones 224
SSAValue(4) = 3
SSAValue(5) = 3
# meta: pop location
# meta: pop location
# meta: pop location
$(Expr(:inbounds, :pop))
x::Array{Float64,2} = $(Expr(:invoke, MethodInstance for fill!(::Array{Float64,2}, ::Float64), :(Base.fill!), :($(Expr(:foreigncall, :(:jl_alloc_array_2d), Array{Float64,2}, svec(Any,Int64,Int64), Array{Float64,2}, 0, SSAValue(4), 0, SSAValue(5), 0))), :((Base.sitofp)(Float64,1)::Float64))) # line 3:
i::Union{Int64,UnitRange{Int64}} = $(Expr(:invoke, MethodInstance for int_or_range(), :(Main.int_or_range))) # line 4:
is::Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}} = (Core.typeassert)((Core.tuple)(i::Union{Int64,UnitRange{Int64}},i::Union{Int64,UnitRange{Int64}})::Tuple{Union{Int64,UnitRange{Int64}},Union{Int64,UnitRange{Int64}}},Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}})::Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}} # line 5:
return (Core._apply)(Main.getindex,(Core.tuple)(x::Array{Float64,2})::Tuple{Array{Float64,2}},is::Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}})::Any
end::Any
``````

Now two cases would suffice instead of four. Unfortunately, that is not what happens. Note the counterintuitive behavior: although the bound on the type of `is` is tighter, the bound on the type of `getindex(x, is...)` is looser (`Any`). I don’t claim to understand anything about how inference works, but this surprised me. Is there some (more or less) easy tweaking of inference that could improve this case? Or is there another way I could write that code to get inference to find the return type `Union{Array{Float64,2},Float64}`? (Note that in the actual use case, the element types of the tuple are not all the same, but the type of the length-n tuple is still one of n+1 possibilities (`Union{Tuple...}`), while inference can only limit this to 2^n possibilities (`Tuple{Union...}`).)
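One possible way to keep the two cases apart in user code is a function barrier: build the tuple inside a helper that receives `i` as an argument, so that each specialization of the helper sees a single concrete type for `i`. The sketch below is an assumption-laden illustration, not something taken from this thread: `index_both` and `baz` are hypothetical names, and whether inference actually reports the tighter union for `baz` depends on the Julia version (recent versions split small unions in call arguments; the version used above may not).

```julia
int_or_range() = rand() < 0.5 ? 1 : 1:1

# Hypothetical helper: within one specialization, `i` has a single concrete
# type, so `(i, i)` is either Tuple{Int,Int} or a tuple of two UnitRange{Int}.
index_both(x, i) = x[(i, i)...]

function baz()
    x = ones(3, 3)
    i = int_or_range()
    return index_both(x, i)   # dispatch happens on the type of `i`
end

result = baz()
```

The result can be checked with `@code_warntype baz()` in the same way as for `foo` and `bar`.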

Any feedback to explain the situation and/or improve the code appreciated!

#2

> Is there some (more or less) easy tweaking of inference that could improve this case?

If you’re interested in learning more, there are certainly some improvements that could make this case much better. Right now, the union separation only happens for one specific case (https://github.com/JuliaLang/julia/blob/master/base/reflection.jl#L408), and doesn’t happen for the `tuple()` constructor case (https://github.com/JuliaLang/julia/blob/3fbe8305f90187d885fea7046d5a9419a6bef8f8/base/inference.jl#L948-L954) or the getfield case (`getfield_tfunc`).

Thus, one valuable and fairly straightforward project that you could probably cut your teeth on would be to abstract that union-splitting logic and move it to a higher level in the inference pipeline, such that all of the inference primitives (“tfuncs”) can generically make use of this divide-and-conquer approach.

#3

You actually made me believe that inference might be approachable even by us mere mortals. So I’ve spent some time trying to get a grip on things. Unfortunately, your breadcrumbs didn’t really lead me anywhere. BUT I think I have nailed down at least part of the problem, and that is `abstract_apply` / `precise_container_types`.

The `getindex` is invoked with `_apply`. But for a `Union` argument type, `precise_container_types` returns `nothing`, and `abstract_apply` then infers this like `getindex(::Any...)`, which doesn’t get us very far. Patching `precise_container_types` to effectively treat a `Union` of `Tuple`s of the same length like a `Tuple` of `Union`s (which will, in general, be a broader type, of course) indeed makes the return type of `bar` be inferred to be the same as that of `foo`. The typed code is different, as no union splitting actually happens, but for now I’m happy to have the return type improved. (Actually I’m pretty happy to have modified inference without making anything explode.)
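The widening described here can be illustrated at the type level. What follows is a sketch only and not the actual patch to `precise_container_types`: `widen_tuple_union` is a hypothetical helper, it relies on the internal function `Base.uniontypes`, and it assumes fixed-length tuple types (no `Vararg`).

```julia
# Hypothetical helper: turn a Union of equal-length Tuple types into a single
# Tuple whose k-th element is the Union of the members' k-th elements.
# Returns `nothing` when the input does not have that shape.
function widen_tuple_union(T::Type)
    members = Base.uniontypes(T)   # Union{A,B,...} -> [A, B, ...]
    all(m -> m <: Tuple, members) || return nothing
    n = length(fieldtypes(first(members)))
    all(m -> length(fieldtypes(m)) == n, members) || return nothing
    return Tuple{(Union{(fieldtypes(m)[k] for m in members)...} for k in 1:n)...}
end

IsType = Union{Tuple{Int,Int},Tuple{UnitRange{Int},UnitRange{Int}}}
Widened = widen_tuple_union(IsType)

# The widened type covers the original, but not the other way around:
# it also admits the mixed (Int, UnitRange{Int}) tuples.
covers = IsType <: Widened
is_strictly_wider = !(Widened <: IsType)
```

For the type of `is` in `bar`, `Widened` is the same `Tuple{Union...}` type that inference assigns to `is` in `foo`, which is consistent with both functions ending up with the same inferred return type.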

I’ll try to improve this further, helpful input still welcome.

#4

That sounds like a good summary. It sounds like you’ve found the right spot in the code too. For the `getindex(precise_container_types)` case, the wider type will give the same answer. For the more general case, I think you could split the union at a higher level and express this as:

``````
abstract_apply(t::Union) =
    tmerge( abstract_apply(t.a), abstract_apply(t.b) )
``````
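The same divide-and-conquer idea can be mimicked from user code to see what it would buy for `bar`. This is a rough analogue rather than how inference is implemented: `split_return_type` is a hypothetical helper, plain `Union{...}` stands in for `tmerge`, `Base.uniontypes` is an internal function, and the reported types depend on what `Base.return_types` infers on the Julia version at hand.

```julia
# Hypothetical helper: infer the return type of `f(prefix..., splatted...)`
# separately for each member of a Union of Tuple types, then merge the results.
function split_return_type(f, prefix::Tuple, T::Type)
    results = Any[]
    for member in Base.uniontypes(T)            # one case per Tuple in the Union
        argtypes = Tuple{prefix..., fieldtypes(member)...}
        push!(results, Union{Base.return_types(f, argtypes)...})
    end
    return Union{results...}                    # stand-in for `tmerge`
end

IsT = Union{Tuple{Int,Int},Tuple{UnitRange{Int},UnitRange{Int}}}
merged = split_return_type(getindex, (Matrix{Float64},), IsT)
```

Only two cases are inferred here, one per member of the `Union`, instead of the four combinations considered for `Tuple{Union...}`.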

Glad to know I could be of some encouragement in your learning. Starting with specific changes like this one is how I learned my way around that code also.

If you’re interested in how our inference code works at a higher level, I made a short animation showing how it collects type information at each step: http://juliacomputing.com/blog/2016/04/04/inference-convergence.html#basic-algorithm. It’s not necessary to understand that part of the algorithm to make significant and beneficial contributions to the base cases, so don’t be overwhelmed by it. I didn’t write much about the tfunc primitives, since my primary goal was to write about how it computes the final answer.