I am trying to understand type inference better. My understanding is very basic: knowing the output types from the input types is good, not knowing is bad. I don’t really understand what causes the slowdown.
In any case, I’ve seen people on the internet claim that “small Unions are fine”, i.e. something like
julia> inner(x) = x > 0 ? 3.0 : 3
inner (generic function with 1 method)
julia> @code_warntype inner(2.)
Variables
#self#::Core.Const(inner)
x::Float64
Body::Union{Float64, Int64}
1 ─ %1 = (x > 0)::Bool
└── goto #3 if not %1
2 ─ return 3.0
3 ─ return 3
is to be avoided, but not at all costs. The justification is that Julia only has to specialise on two possible options, i.e. one for each type in the Union.
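As a minimal sketch of what that claim looks like in practice: below, the small Union coming out of inner is passed on to a single-method function, and inference keeps a small Union for the result. The names double and chain are hypothetical and exist only for this illustration; Base.return_types is an unexported introspection helper, used here just to query the inferred type.

```julia
# Same definition as above.
inner(x) = x > 0 ? 3.0 : 3

# Hypothetical consumer with a single generic method (illustration only).
double(y) = 2y

# Hypothetical caller that feeds the small Union into `double`.
chain(x) = double(inner(x))

# Ask inference what chain(::Float64) returns.
inferred = only(Base.return_types(chain, (Float64,)))
println(inferred)  # Union{Float64, Int64}
```

Each of the two element types is handled separately, so the result stays a two-element Union rather than widening.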
However, the following example seems to contradict this:
julia> using LinearAlgebra # needed for Diagonal
julia> wrapper(x::Int) = x > 0 ? rand(2, 3) : Diagonal(rand(2))
wrapper (generic function with 1 method)
julia> wrapper(x::Float64) = x > 0 ? "hello" : :world
wrapper (generic function with 2 methods)
julia> together(x) = wrapper(inner(x))
together (generic function with 1 method)
julia> @code_warntype together(2.0)
Variables
#self#::Core.Const(together)
x::Float64
Body::Any
1 ─ %1 = Main.inner(x)::Union{Float64, Int64}
│ %2 = Main.wrapper(%1)::Any
└── return %2
Here the together function infers to Any, not to the Union of four types that one would expect from that story. Why is that?
Your wrapper has two methods, and at compile time it’s unknown which of the two will be selected (because the return type of inner depends on a runtime value; note the two separate MethodInstances):
julia> code_warntype(wrapper, (Union{Float64,Int64},))
MethodInstance for wrapper(::Int64)
from wrapper(x::Int64) in Main at REPL[2]:1
Arguments
#self#::Core.Const(wrapper)
x::Int64
Body::Union{Diagonal{Float64, Vector{Float64}}, Matrix{Float64}}
1 ─ %1 = (x > 0)::Bool
└── goto #3 if not %1
2 ─ %3 = Main.rand(2, 3)::Matrix{Float64}
└── return %3
3 ─ %5 = Main.rand(2)::Vector{Float64}
│ %6 = Main.Diagonal(%5)::Diagonal{Float64, Vector{Float64}}
└── return %6
MethodInstance for wrapper(::Float64)
from wrapper(x::Float64) in Main at REPL[14]:1
Arguments
#self#::Core.Const(wrapper)
x::Float64
Body::Union{String, Symbol}
1 ─ %1 = (x > 0)::Bool
└── goto #3 if not %1
2 ─ return "hello"
3 ─ return :world
Julia could of course just use a naive Union of all results, but this would only slow down inference and compilation down the road (every function that later receives a larger Union has to be checked against each of its elements, after all), so a tradeoff is made and Julia tries to find a minimal type that still covers all behaviours. In some cases, when there’s no common ancestor other than Any, it’ll have to fall back to that:
julia> typejoin(Union{String, Symbol}, Union{Diagonal{Float64, Vector{Float64}}, Matrix{Float64}})
Any
The reason this is ok is that it makes no difference at runtime: a dynamic dispatch is inserted either way, but falling back to Any makes it easier on the compiler because it doesn’t have to keep track of as many things, preventing a combinatorial explosion of possible types. “Classical” inference in static languages would throw compile-time errors here instead.
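To get a feeling for when a common ancestor is useful and when it collapses to Any, typejoin can be tried on the individual pieces. This is only a sketch for building intuition: it calls typejoin directly, as in the example above, and the variable names are just for illustration.

```julia
using LinearAlgebra  # provides Diagonal

# Pairs that share a meaningful abstract supertype:
j_num = typejoin(Int64, Float64)
j_mat = typejoin(Matrix{Float64}, Diagonal{Float64, Vector{Float64}})

# A pair whose only common supertype is Any:
j_txt = typejoin(String, Symbol)

println(j_num)  # Real
println(j_mat)  # an AbstractMatrix type
println(j_txt)  # Any
```

The two matrix types still meet at an abstract matrix type, but once String and Symbol are thrown into the mix, nothing narrower than Any covers all four.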
Also, if you absolutely can’t get rid of a type instability (e.g. because it’s a library function and hasn’t been patched yet), you can help inference along with an appropriate conversion function, convert, or a type assertion:
julia> together(x) = wrapper(float(inner(x)))
together (generic function with 1 method)
julia> @code_warntype together(2.0)
MethodInstance for together(::Float64)
from together(x) in Main at REPL[27]:1
Arguments
#self#::Core.Const(together)
x::Float64
Body::Union{String, Symbol}
1 ─ %1 = Main.inner(x)::Union{Float64, Int64}
│ %2 = Main.float(%1)::Float64
│ %3 = Main.wrapper(%2)::Union{String, Symbol}
└── return %3
The explicit conversion with float allows inference to assume that the argument to wrapper will be a Float64, determining the correct method for wrapper at compile time.
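For comparison, here is a sketch of the other two options, convert and a type assertion, using the same inner and wrapper definitions. The names together_convert and together_assert are hypothetical and only serve this example. Note that a type assertion does not convert anything: it narrows the type for inference but throws if the value turns out to be of another type, so it is only appropriate when you know which branch you will get.

```julia
using LinearAlgebra  # needed for Diagonal in wrapper(::Int)

inner(x) = x > 0 ? 3.0 : 3
wrapper(x::Int) = x > 0 ? rand(2, 3) : Diagonal(rand(2))
wrapper(x::Float64) = x > 0 ? "hello" : :world

# Variant 1: convert yields a Float64 for either branch of inner.
together_convert(x) = wrapper(convert(Float64, inner(x)))

# Variant 2: the assertion narrows the type but does not convert;
# it throws a TypeError whenever inner returns its Int branch.
together_assert(x) = wrapper(inner(x)::Float64)

# Query the inferred return type of the convert variant.
println(only(Base.return_types(together_convert, (Float64,))))

println(together_convert(-1.0))  # inner gives 3, converted to 3.0
println(together_assert(2.0))    # inner gives 3.0, assertion passes

try
    together_assert(-1.0)        # inner gives 3::Int, assertion fails
catch err
    println(typeof(err))         # TypeError
end
```

With convert, both branches of inner end up at wrapper(::Float64), just like with float. With the assertion, the Int branch becomes a runtime error instead of a dynamic dispatch.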