Type inference: are small Unions ok?

I am trying to understand type inference better. My understanding is very basic: knowing the output types from the input types is good, not knowing is bad. I don’t really understand what causes the slowdown.

In any case, I’ve seen people on the internet claim that “small Unions are fine”, i.e. something like

julia> inner(x) = x > 0 ? 3.0 : 3
inner (generic function with 1 method)

julia> @code_warntype inner(2.)
Variables
  #self#::Core.Const(inner)
  x::Float64

Body::Union{Float64, Int64}
1 ─ %1 = (x > 0)::Bool
└──      goto #3 if not %1
2 ─      return 3.0
3 ─      return 3

is to be avoided, but not at all costs. The justification is that Julia only has to specialise for two possible options, i.e. one for each type.
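As a small sketch of how to check this without reading the full `@code_warntype` output, the inferred return type can be queried directly. It assumes the same `inner` as above and uses `Base.return_types`, which is not exported and returns one entry per matching method:

```julia
inner(x) = x > 0 ? 3.0 : 3

# Inferred return type of inner for a Float64 argument.
# `only` unwraps the single-element vector returned by Base.return_types.
rt = only(Base.return_types(inner, (Float64,)))
println(rt)

# A small Union of concrete types is neither a concrete type nor Any.
println(isconcretetype(rt))   # false
println(rt === Any)           # false
```

The result sits between the two extremes from my basic understanding: not fully known, but narrowed down to two concrete candidates.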

However, the following example seems to contradict this:

julia> wrapper(x::Int) = x > 0 ? rand(2, 3) : Diagonal(rand(2))
wrapper (generic function with 1 method)
julia> wrapper(x::Float64) = x > 0 ? "hello" : :world
wrapper (generic function with 2 methods)
julia> together(x) = wrapper(inner(x))
together (generic function with 1 method)
julia> @code_warntype together(2.0)
Variables
  #self#::Core.Const(together)
  x::Float64
Body::Any
1 ─ %1 = Main.inner(x)::Union{Float64, Int64}
│   %2 = Main.wrapper(%1)::Any
└──      return %2

Here the together function infers to Any, not to the Union of four types one would expect from that story. Why is that?


Not an answer, but my experience is also that inference of small unions with more than two types can be fragile, and I try to avoid it.

Your wrapper has two methods, and at compile time it’s unknown which of the two will be selected (because the return type of inner depends on a runtime value; note the two separate MethodInstances):

julia> code_warntype(wrapper, (Union{Float64,Int64},))                                            
MethodInstance for wrapper(::Int64)                                                               
  from wrapper(x::Int64) in Main at REPL[2]:1                                                     
Arguments                                                                                         
  #self#::Core.Const(wrapper)                                                                     
  x::Int64                                                                                        
Body::Union{Diagonal{Float64, Vector{Float64}}, Matrix{Float64}}                                  
1 ─ %1 = (x > 0)::Bool                                                                            
└──      goto #3 if not %1                                                                        
2 ─ %3 = Main.rand(2, 3)::Matrix{Float64}                                                         
└──      return %3                                                                                
3 ─ %5 = Main.rand(2)::Vector{Float64}                                                            
│   %6 = Main.Diagonal(%5)::Diagonal{Float64, Vector{Float64}}                                    
└──      return %6                                                                                
                                                                                                  
MethodInstance for wrapper(::Float64)                                                             
  from wrapper(x::Float64) in Main at REPL[14]:1                                                  
Arguments                                                                                         
  #self#::Core.Const(wrapper)                                                                     
  x::Float64                                                                                      
Body::Union{String, Symbol}                                                                       
1 ─ %1 = (x > 0)::Bool                                                                            
└──      goto #3 if not %1                                                                        
2 ─      return "hello"                                                                           
3 ─      return :world                                                                            

Julia could of course just use a naive Union of all results, but this would only slow down inference and compilation down the road (larger unions force every downstream function to be checked for each element, after all), so a tradeoff is made: Julia tries to find a minimal type that still covers all behaviours. When there’s no common supertype other than Any, it has to fall back to that:

julia> typejoin(Union{String, Symbol}, Union{Diagonal{Float64, Vector{Float64}}, Matrix{Float64}})
Any                                                                                               

This is ok because it makes no difference at runtime: a dynamic dispatch is inserted either way, but falling back to Any makes things easier for the compiler, which doesn’t have to keep track of as many things, preventing a combinatorial explosion of possible types. “Classical” inference in static languages would throw a compile-time error here instead.
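To illustrate the runtime side of this, here is a minimal sketch reusing the definitions from the question. It only relies on the inferred type being non-concrete, which holds whether inference reports Any or a Union; the exact inferred type may differ between Julia versions:

```julia
using LinearAlgebra  # provides Diagonal

inner(x) = x > 0 ? 3.0 : 3
wrapper(x::Int) = x > 0 ? rand(2, 3) : Diagonal(rand(2))
wrapper(x::Float64) = x > 0 ? "hello" : :world
together(x) = wrapper(inner(x))

# Inference only has an abstract answer for the composed call ...
inferred = only(Base.return_types(together, (Float64,)))
println(isconcretetype(inferred))   # false

# ... but every actual call still returns a concretely typed value;
# the method of wrapper is picked by dynamic dispatch at runtime.
println(typeof(together(2.0)))    # String
println(typeof(together(-1.0)))   # Matrix{Float64}
```

So the cost is a dynamic dispatch at the call to wrapper, not a loss of correctness.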


Also, if you absolutely can’t get rid of a type instability (e.g. because it’s in a library function that hasn’t been patched yet), you can help inference along with an appropriate conversion function, convert, or a type assertion:

julia> together(x) = wrapper(float(inner(x)))   
together (generic function with 1 method)       
                                                
julia> @code_warntype together(2.0)             
MethodInstance for together(::Float64)          
  from together(x) in Main at REPL[27]:1        
Arguments                                       
  #self#::Core.Const(together)                  
  x::Float64                                    
Body::Union{String, Symbol}                     
1 ─ %1 = Main.inner(x)::Union{Float64, Int64}   
│   %2 = Main.float(%1)::Float64                
│   %3 = Main.wrapper(%2)::Union{String, Symbol}
└──      return %3                              

The explicit conversion with float lets inference conclude that the argument to wrapper is a Float64, so the correct method of wrapper is determined at compile time.
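The other two options mentioned above, convert and a type assertion, could look roughly like this. The names together_convert and together_assert are made up for this illustration; everything else is taken from the question:

```julia
using LinearAlgebra  # provides Diagonal

inner(x) = x > 0 ? 3.0 : 3
wrapper(x::Int) = x > 0 ? rand(2, 3) : Diagonal(rand(2))
wrapper(x::Float64) = x > 0 ? "hello" : :world

# Variant 1 (hypothetical name): convert the Union result to Float64.
together_convert(x) = wrapper(convert(Float64, inner(x)))

# Variant 2 (hypothetical name): assert the type instead of converting.
together_assert(x) = wrapper(inner(x)::Float64)

# In both cases inference only has to consider wrapper(::Float64).
println(only(Base.return_types(together_convert, (Float64,))))
println(only(Base.return_types(together_assert, (Float64,))))

# The two variants differ when inner returns an Int:
println(together_convert(-1.0))   # 3 is converted to 3.0, prints hello
try
    together_assert(-1.0)         # 3 is not a Float64
catch err
    println(typeof(err))          # TypeError
end
```

Note that both variants change behaviour compared to the original together: the Int64 method of wrapper is never reached, so this only makes sense if that is actually what you want.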