# Type inference: are small Unions ok?

I am trying to understand type inference better. My understanding is very basic: being able to tell the output types from the input types is good, and not being able to is bad. I don’t really understand what causes the slowdown.

In any case, I’ve seen people on the internet claim that “small Unions are fine”, i.e. something like

```julia
julia> inner(x) = x > 0 ? 3.0 : 3
inner (generic function with 1 method)

julia> @code_warntype inner(2.)
Variables
#self#::Core.Const(inner)
x::Float64

Body::Union{Float64, Int64}
1 ─ %1 = (x > 0)::Bool
└──      goto #3 if not %1
2 ─      return 3.0
3 ─      return 3
```

is to be avoided, but not at all costs. The justification is that Julia only has to specialise on two possible options, i.e. one for each type.
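As a rough illustration of what “one specialisation per type” means, here is a minimal sketch of the transformation commonly called union splitting, written out by hand. The names `use` and `use_split` are hypothetical and made up for this example; the compiler performs this kind of split itself for small unions, so there is normally no need to write it manually.

```julia
# Hand-written sketch of union splitting (the compiler does this automatically).
inner(x) = x > 0 ? 3.0 : 3   # inferred as Union{Float64, Int}

# Hypothetical consumer with one method per member of the union.
use(y::Float64) = 2y
use(y::Int) = y + 1

# Roughly what `use(inner(x))` can be compiled into: one type check per
# union member, and in each branch a call whose method is known statically.
function use_split(x)
    y = inner(x)
    if y isa Float64
        return use(y)   # y is known to be a Float64 here
    else
        return use(y)   # inference narrows y to Int in this branch
    end
end

println(use_split(2.0))   # inner returns 3.0, so this calls use(::Float64)
println(use_split(-1.0))  # inner returns 3, so this calls use(::Int)
```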

However, the following example seems to contradict this:

```julia
julia> using LinearAlgebra  # provides Diagonal
julia> wrapper(x::Int) = x > 0 ? rand(2, 3) : Diagonal(rand(2))
wrapper (generic function with 1 method)
julia> wrapper(x::Float64) = x > 0 ? "hello" : :world
wrapper (generic function with 2 methods)
julia> together(x) = wrapper(inner(x))
together (generic function with 1 method)
julia> @code_warntype together(2.0)
Variables
#self#::Core.Const(together)
x::Float64
Body::Any
1 ─ %1 = Main.inner(x)::Union{Float64, Int64}
│   %2 = Main.wrapper(%1)::Any
└──      return %2
```

It looks like the `together` function actually infers to `Any`, rather than to the `Union` of the four types one would expect from that story. Why is that?


Not an answer, but my experience is also that inference of small unions with more than two types can be fragile, and I try to avoid it.

Your `wrapper` has two methods, and at compile time it’s unknown which of them will be selected, because the return type of `inner` depends on a runtime value (note the two separate `MethodInstance`s):

```julia
julia> code_warntype(wrapper, (Union{Float64,Int64},))
MethodInstance for wrapper(::Int64)
from wrapper(x::Int64) in Main at REPL[2]:1
Arguments
#self#::Core.Const(wrapper)
x::Int64
Body::Union{Diagonal{Float64, Vector{Float64}}, Matrix{Float64}}
1 ─ %1 = (x > 0)::Bool
└──      goto #3 if not %1
2 ─ %3 = Main.rand(2, 3)::Matrix{Float64}
└──      return %3
3 ─ %5 = Main.rand(2)::Vector{Float64}
│   %6 = Main.Diagonal(%5)::Diagonal{Float64, Vector{Float64}}
└──      return %6

MethodInstance for wrapper(::Float64)
from wrapper(x::Float64) in Main at REPL[14]:1
Arguments
#self#::Core.Const(wrapper)
x::Float64
Body::Union{String, Symbol}
1 ─ %1 = (x > 0)::Bool
└──      goto #3 if not %1
2 ─      return "hello"
3 ─      return :world
```
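The same per-method split can also be inspected programmatically. A small sketch using `Base.return_types`, a reflection utility that returns one inferred return type per method that can match the given argument types; the exact results (and their order) may vary between Julia versions, so treat it as a diagnostic aid rather than a stable API.

```julia
using LinearAlgebra  # provides Diagonal

wrapper(x::Int) = x > 0 ? rand(2, 3) : Diagonal(rand(2))
wrapper(x::Float64) = x > 0 ? "hello" : :world

# One inferred return type per method that can match a Union{Float64, Int} argument.
rts = Base.return_types(wrapper, (Union{Float64, Int},))

for rt in rts
    println(rt)
end
```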

Julia could, of course, just use a naive `Union` of all results, but that would slow down inference and compilation further down the road (after all, every function that receives a larger union has to be checked for each of its members). So a tradeoff is made: Julia tries to find a minimal type that still covers all behaviours. When there’s no common ancestor other than `Any`, it has to fall back to that:

```julia
julia> typejoin(Union{String, Symbol}, Union{Diagonal{Float64, Vector{Float64}}, Matrix{Float64}})
Any
```
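To see how much information such widening throws away, a small sketch comparing `typejoin` with an exact `Union` on a few simpler pairs; the values in the comments are the common supertypes one expects from the standard type hierarchy.

```julia
using LinearAlgebra  # provides Diagonal

# typejoin returns the closest common supertype, so detail is lost.
j1 = typejoin(Float64, Int64)                                       # Real
j2 = typejoin(Matrix{Float64}, Diagonal{Float64, Vector{Float64}})  # AbstractMatrix{Float64}
j3 = typejoin(String, Symbol)                                       # Any: no closer common ancestor

# A Union keeps its members exactly, but every extra member is one more
# case that code receiving the value may have to handle.
u = Union{Float64, Int64}
println(Float64 <: u, " ", Int64 <: u, " ", Real <: u)              # true true false
```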

Falling back to `Any` is OK because it makes no difference at runtime: a dynamic dispatch is inserted either way. It does, however, make things easier on the compiler, which doesn’t have to keep track of as many things, and it prevents a combinatorial explosion of possible types. “Classical” inference in statically typed languages would throw compile-time errors here instead.


Also, if you absolutely can’t get rid of a type instability (e.g. because it’s in a library function that hasn’t been patched yet), you can help inference along with an appropriate conversion function such as `float`, an explicit `convert`, or a type assertion:

```julia
julia> together(x) = wrapper(float(inner(x)))
together (generic function with 1 method)

julia> @code_warntype together(2.0)
MethodInstance for together(::Float64)
from together(x) in Main at REPL[27]:1
Arguments
#self#::Core.Const(together)
x::Float64
Body::Union{String, Symbol}
1 ─ %1 = Main.inner(x)::Union{Float64, Int64}
│   %2 = Main.float(%1)::Float64
│   %3 = Main.wrapper(%2)::Union{String, Symbol}
└──      return %3
```

The explicit conversion with `float` lets inference conclude that the argument to `wrapper` is always a `Float64`, so the correct `wrapper` method is determined at compile time.
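The three options mentioned in this answer can be compared side by side. This is a sketch under the same definitions as in the question; the names `together_float`, `together_convert` and `together_assert` are hypothetical, and `Base.return_types` is only used to inspect what inference concludes (its output may differ across Julia versions).

```julia
using LinearAlgebra  # provides Diagonal

inner(x) = x > 0 ? 3.0 : 3
wrapper(x::Int) = x > 0 ? rand(2, 3) : Diagonal(rand(2))
wrapper(x::Float64) = x > 0 ? "hello" : :world

# 1. A conversion function, as in the example above.
together_float(x) = wrapper(float(inner(x)))

# 2. An explicit `convert` to the target type.
together_convert(x) = wrapper(convert(Float64, inner(x)))

# 3. A type assertion: narrows the inferred type without converting,
#    and throws a TypeError if the value is not actually a Float64.
together_assert(x) = wrapper(inner(x)::Float64)

# Inspect what inference concludes for a Float64 argument.
for f in (together_float, together_convert, together_assert)
    println(f, " => ", only(Base.return_types(f, (Float64,))))
end
```

The type assertion is only appropriate when you know the value really is a `Float64`: here `inner(-1.0)` returns the `Int` `3`, so `together_assert(-1.0)` throws a `TypeError`, whereas the two conversions turn the value into `3.0` instead.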