Why is my Union return type inferred as Any?

I have a question about the following code. By inspection, the return type of ttest must be Union{Nothing,ClassA,ClassB,ClassC}, and I even declare the type of the local variable c that is returned. But @code_warntype still reports Any.

  1. Why does this happen? I expected the return type to be inferred as Union{Nothing,ClassA,ClassB,ClassC} rather than Any.
  2. How can I fix this type instability? (JET's @report_opt does not report any instability, though.)

Thanks for your time.
Here is the code:

using JET
using Test

abstract type ClassABC end

struct ClassA <: ClassABC
    a::NTuple{2,Int}
end
ClassA(id::Int) = ClassA((0, id))

struct ClassB <: ClassABC
    a::NTuple{2,Int}
end
ClassB(id::Int) = ClassB((0, id))
function findclassbindex(a::Vector{Int}, b::Int)
    idx = findfirst(==(b), a)  # first index whose element equals b
    if isnothing(idx)
        return nothing
    else
        return ClassB(idx)
    end
end

struct ClassC <: ClassABC
    a::NTuple{2,Int}
end
ClassC(id::Int) = ClassC((0, id))
function findclasscindex(a::Vector{Int}, b::Int)
    idx = findfirst(==(b), a)  # first index whose element equals b
    if isnothing(idx)
        return nothing
    else
        return ClassC(idx)
    end
end

function ttest(a::Vector{Int}, b::Int)
    c::Union{Nothing,ClassA,ClassB,ClassC} = nothing
    idx = findfirst(==(b), a)
    if !isnothing(idx)
        if idx ≤ 8
            c = ClassA(idx)
        else
            c = findclassbindex(a[9:12], b)
            if isnothing(c)
                c = findclasscindex(a[13:15], b)
            end
        end
    end
    return c
end

a = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
b = 5
ttest(a,b)
b = 10
ttest(a,b)
b = 14
ttest(a,b)

@code_warntype ttest(a,b)   # Body::Any

@report_opt ttest(a,b)  # Passed
@report_call ttest(a,b) # Passed
@inferred ttest(a,b)        # Error: return type ClassC does not match inferred return type Any

It seems that once the Union would have more than three member types, the return type is inferred as Any. So we should keep the number of types in the Union small:

# Four possible return types (Float64, Int32, Int16, Nothing)
function anytype(a)
    if a < 1
        return Float64(5)
    elseif a > 5
        return Int32(5)
    else
        if a > 3
            return Int16(2)
        else
            return nothing
        end
    end
end
@code_warntype anytype(2)   # Body::Any

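# Three possible return types (Float64, Int32, Nothing)
function uniontype(a)
    if a < 1
        return Float64(5)
    elseif a > 5
        return Int32(5)
    else
        return nothing
    end
end
@code_warntype uniontype(2)   # Body is a Union of Float64, Int32 and Nothing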
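A quick way to check this rule without reading @code_warntype output is Base.return_types, which returns the inferred return types for given argument types. The block below (equivalent copies of the two functions above, plus a hypothetical anytype_annotated) also sketches one workaround, assuming current inference behavior: annotating the method's return type, unlike annotating a local variable as ttest does, makes the method end in a convert plus typeassert, and the typeassert narrows the widened Any back to the declared union. This relies on a single return point; with several return statements, each one would be narrowed separately and then merged (and widened) again. For ttest, which has a single `return c`, the analogous change would be `function ttest(a::Vector{Int}, b::Int)::Union{Nothing,ClassA,ClassB,ClassC}`.

```julia
# Equivalent copies of the two functions above, so this block runs on its own.
function anytype(a)
    if a < 1
        return Float64(5)
    elseif a > 5
        return Int32(5)
    elseif a > 3
        return Int16(2)
    else
        return nothing
    end
end

function uniontype(a)
    if a < 1
        return Float64(5)
    elseif a > 5
        return Int32(5)
    else
        return nothing
    end
end

# Hypothetical wrapper: the return-type annotation lowers the return to
# `typeassert(convert(T, r), T)`, so inference reports T for the method
# even though `anytype(a)` itself is inferred as a widened type.
function anytype_annotated(a)::Union{Nothing,Float64,Int32,Int16}
    r = anytype(a)
    return r  # single return point, so only one narrowed type is merged
end

# Print the inferred return type of each function for an Int argument.
for f in (uniontype, anytype, anytype_annotated)
    println(f, " => ", only(Base.return_types(f, (Int,))))
end
```

Printing the three inferred types side by side makes the contrast between the three-member Union, the widened four-member case and the annotated version easy to see.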

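Yes, 4 is the magic number. When inference merges the types returned by different branches, it only keeps a Union of up to three types and widens anything larger (here to Any); union splitting at call sites is likewise limited to small unions. If you want efficient code for larger unions, consider using DynamicalSumTypes.jl 🙂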
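Another option, sketched here under my own assumptions, is to avoid the large Union altogether. Since ClassA, ClassB and ClassC have the same field layout, they can be folded into one concrete struct that carries a tag, which is the basic idea of a sum type. This is a hand-rolled version, not DynamicalSumTypes.jl's API, and ClassKind, TaggedClass, ttest_tagged and kindname are hypothetical names. The result type is the two-member Union{Nothing,TaggedClass}, which inference keeps without widening.

```julia
# Hypothetical tag recording which of the three original classes a value represents.
@enum ClassKind KindA KindB KindC

# One concrete type replaces ClassA, ClassB and ClassC (same field layout).
struct TaggedClass
    kind::ClassKind
    a::NTuple{2,Int}
end
TaggedClass(kind::ClassKind, id::Int) = TaggedClass(kind, (0, id))

# Same search logic as `ttest`, but every non-nothing result is a TaggedClass,
# so the inferred return type is Union{Nothing,TaggedClass}.
function ttest_tagged(a::Vector{Int}, b::Int)
    idx = findfirst(==(b), a)
    if isnothing(idx)
        return nothing
    elseif idx ≤ 8
        return TaggedClass(KindA, idx)
    end
    ib = findfirst(==(b), a[9:12])   # index within the "B" range
    if !isnothing(ib)
        return TaggedClass(KindB, ib)
    end
    ic = findfirst(==(b), a[13:15])  # index within the "C" range
    return isnothing(ic) ? nothing : TaggedClass(KindC, ic)
end

# Usage: branch on the tag where the original code would dispatch on the type.
kindname(c::TaggedClass) = c.kind == KindA ? "A" : c.kind == KindB ? "B" : "C"
kindname(::Nothing) = "none"
```

The trade-off is that you branch on `kind` instead of dispatching on three separate types, so this fits best when the three classes really differ only in a label.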
