While working on
https://github.com/JuliaLang/julia/pull/20154
I hit some peculiar behavior of inference I’d like to get some input about.
Let’s consider the following, which simplifies the situation a bit:
# Randomly return either the integer `1` or the range `1:1` (each with probability 1/2),
# so the result has type `Union{Int,UnitRange{Int}}`.
function int_or_range()
    if rand() < 0.5
        return 1
    else
        return 1:1
    end
end
"""
    foo()

Index a 3×3 matrix of ones with the splatted tuple `is = (i, i)`, where `i` comes from
`int_or_range()`. Each tuple element is inferred on its own as
`Union{Int64,UnitRange{Int64}}`, so inference has to consider four `getindex` signatures.
"""
function foo()
    x = ones(3, 3)
    i = int_or_range()
    is = (i, i)
    return x[is...]
end
Obviously, this will return either a `Float64` (if `i::Int`) or a `Matrix{Float64}` (if `i::UnitRange{Int}`). Can inference figure this out? Let’s see:
julia> @code_warntype foo()
Variables:
#self#::#foo
x::Array{Float64,2}
i::Union{Int64,UnitRange{Int64}}
is::Tuple{Union{Int64,UnitRange{Int64}},Union{Int64,UnitRange{Int64}}}
#temp#@_5::Core.MethodInstance
#temp#@_6::Union{Array{Float64,1},Array{Float64,2},Float64}
Body:
begin
$(Expr(:inbounds, false))
# meta: location array.jl ones 227
# meta: location array.jl ones 225
# meta: location array.jl ones 224
SSAValue(4) = 3
SSAValue(5) = 3
# meta: pop location
# meta: pop location
# meta: pop location
$(Expr(:inbounds, :pop))
x::Array{Float64,2} = $(Expr(:invoke, MethodInstance for fill!(::Array{Float64,2}, ::Float64), :(Base.fill!), :($(Expr(:foreigncall, :(:jl_alloc_array_2d), Array{Float64,2}, svec(Any,Int64,Int64), Array{Float64,2}, 0, SSAValue(4), 0, SSAValue(5), 0))), :((Base.sitofp)(Float64,1)::Float64))) # line 3:
i::Union{Int64,UnitRange{Int64}} = $(Expr(:invoke, MethodInstance for int_or_range(), :(Main.int_or_range))) # line 4:
SSAValue(6) = i::Union{Int64,UnitRange{Int64}}
SSAValue(7) = i::Union{Int64,UnitRange{Int64}} # line 5:
unless (SSAValue(7) isa Int64)::Any goto 28
unless (SSAValue(6) isa Int64)::Any goto 22
#temp#@_5::Core.MethodInstance = MethodInstance for getindex(::Array{Float64,2}, ::Int64, ::Int64)
goto 44
22:
unless (SSAValue(6) isa UnitRange{Int64})::Any goto 26
#temp#@_5::Core.MethodInstance = MethodInstance for getindex(::Array{Float64,2}, ::UnitRange{Int64}, ::Int64)
goto 44
26:
goto 41
28:
unless (SSAValue(7) isa UnitRange{Int64})::Any goto 39
unless (SSAValue(6) isa Int64)::Any goto 33
#temp#@_5::Core.MethodInstance = MethodInstance for getindex(::Array{Float64,2}, ::Int64, ::UnitRange{Int64})
goto 44
33:
unless (SSAValue(6) isa UnitRange{Int64})::Any goto 37
#temp#@_5::Core.MethodInstance = MethodInstance for getindex(::Array{Float64,2}, ::UnitRange{Int64}, ::UnitRange{Int64})
goto 44
37:
goto 41
39:
goto 41
41:
#temp#@_6::Union{Array{Float64,1},Array{Float64,2},Float64} = (Main.getindex)(x::Array{Float64,2},SSAValue(6),SSAValue(7))::Union{Array{Float64,1},Array{Float64,2},Float64}
goto 46
44:
#temp#@_6::Union{Array{Float64,1},Array{Float64,2},Float64} = $(Expr(:invoke, :(#temp#@_5), :(Main.getindex), :(x), SSAValue(6), SSAValue(7)))
46:
return #temp#@_6::Union{Array{Float64,1},Array{Float64,2},Float64}
end::Union{Array{Float64,1},Array{Float64,2},Float64}
Unfortunately, inference loses the information that `typeof(is[1]) == typeof(is[2])`. But it then actually considers all four cases to come up with a return type that is pretty good!
But maybe we can help inference by restricting the type of `is` to the two possible cases, like so:
"""
    bar()

Like `foo`, but type-asserts that the index tuple `is` is either `Tuple{Int,Int}` or
`Tuple{UnitRange{Int},UnitRange{Int}}`. A `TypeError` is thrown if it is neither.
The assertion narrows the inferred type of `is` to a union of two tuple types.
"""
function bar()
    x = ones(3, 3)
    i = int_or_range()
    is = (i, i)::Union{Tuple{Int,Int},Tuple{UnitRange{Int},UnitRange{Int}}}
    return x[is...]
end
Let’s see…
julia> @code_warntype bar()
Variables:
#self#::#bar
x::Array{Float64,2}
i::Union{Int64,UnitRange{Int64}}
is::Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}}
Body:
begin
$(Expr(:inbounds, false))
# meta: location array.jl ones 227
# meta: location array.jl ones 225
# meta: location array.jl ones 224
SSAValue(4) = 3
SSAValue(5) = 3
# meta: pop location
# meta: pop location
# meta: pop location
$(Expr(:inbounds, :pop))
x::Array{Float64,2} = $(Expr(:invoke, MethodInstance for fill!(::Array{Float64,2}, ::Float64), :(Base.fill!), :($(Expr(:foreigncall, :(:jl_alloc_array_2d), Array{Float64,2}, svec(Any,Int64,Int64), Array{Float64,2}, 0, SSAValue(4), 0, SSAValue(5), 0))), :((Base.sitofp)(Float64,1)::Float64))) # line 3:
i::Union{Int64,UnitRange{Int64}} = $(Expr(:invoke, MethodInstance for int_or_range(), :(Main.int_or_range))) # line 4:
is::Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}} = (Core.typeassert)((Core.tuple)(i::Union{Int64,UnitRange{Int64}},i::Union{Int64,UnitRange{Int64}})::Tuple{Union{Int64,UnitRange{Int64}},Union{Int64,UnitRange{Int64}}},Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}})::Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}} # line 5:
return (Core._apply)(Main.getindex,(Core.tuple)(x::Array{Float64,2})::Tuple{Array{Float64,2}},is::Union{Tuple{Int64,Int64},Tuple{UnitRange{Int64},UnitRange{Int64}}})::Any
end::Any
Now instead of considering four cases, two cases would suffice. Unfortunately, that does not happen. Note the counter-intuitive behavior: although the bound on the type of `is` is tighter, the bound on the type of `getindex(x, is...)` is looser. I don’t claim to understand anything about how inference works, but this surprised me. Is there some (more or less) easy tweak to inference that could improve this case? Or is there another way I could write this code so that inference finds the return type `Union{Array{Float64,2},Float64}`? (Note that in the actual use case, the element types of the tuple are not all the same. Still, the type of the length-`n` tuple is one of `n+1` possibilities (`Union{Tuple...}`), while inference can only limit this to `2^n` possibilities (`Tuple{Union...}`).)
Any feedback to explain the situation and/or improve the code appreciated!