How to get type stability across methods?

I have a program where I want to dynamically choose a function from a list and have the result be type stable. All the methods involved return the same type, but past a certain number of methods, type inference no longer seems to recognize this. For example:

using InteractiveUtils

abstract type S end

struct A <: S end
foo(::A) = true
struct B <: S end
foo(::B) = true
struct C <: S end
foo(::C) = true

function bar(l::Vector{S})
    foo(rand(l))
end

@code_warntype bar([A(), B()])

struct D <: S end
foo(::D) = true

@code_warntype bar([A(), B()])

prints

% julia types.jl
MethodInstance for bar(::Vector{S})
  from bar(l::Vector{S}) @ Main ~/src/nvm/types.jl:12
Arguments
  #self#::Core.Const(Main.bar)
  l::Vector{S}
Body::Bool
1 ─ %1 = Main.foo::Core.Const(Main.foo)
│   %2 = Main.rand::Core.Const(rand)
│   %3 = (%2)(l)::S
│   %4 = (%1)(%3)::Core.Const(true)
└──      return %4

MethodInstance for bar(::Vector{S})
  from bar(l::Vector{S}) @ Main ~/src/nvm/types.jl:12
Arguments
  #self#::Core.Const(Main.bar)
  l::Vector{S}
Body::Any
1 ─ %1 = Main.foo::Core.Const(Main.foo)
│   %2 = Main.rand::Core.Const(rand)
│   %3 = (%2)(l)::S
│   %4 = (%1)(%3)::Any
└──      return %4

So adding a fourth method breaks the type stability. I included an abstract supertype in case there’s a way to leverage that. I also tried adding a type annotation at the call site, but the lowered code still has either a call to typeassert or, in some cases, a branch conditional on the type.

I just saw the thread “Type-stability with a vector of FunctionWrappers”, which suggests using FunctionWrappers.jl or FunctionWrappersWrappers.jl, so I will look into that. But I’m wondering what the right way to handle this situation is, and whether there’s a good way to fix this in base Julia.

The optimization that saves you when there are at most 3 matching methods is called world splitting, which basically says “we don’t know the type of the argument to foo, but since there are only 3 foo methods, we’ll just run type inference with each method in turn and take the union of the outcomes”. There has to be a limit on how many cases will be considered, and currently that limit defaults to 3, so your observations agree with the expected behavior.
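
For what it’s worth, the cutoff can also be checked programmatically instead of reading @code_warntype output. This is only a sketch that reuses the names from the original post and assumes the default limit of 3; the variables three and four are made up for illustration, and the inferred types may differ between Julia versions.

```julia
abstract type S end
struct A <: S end
struct B <: S end
struct C <: S end

foo(::A) = true
foo(::B) = true
foo(::C) = true

bar(l::Vector{S}) = foo(rand(l))

# Ask inference for the return type of bar(::Vector{S}) with three foo methods.
three = only(Base.return_types(bar, (Vector{S},)))

# Adding a fourth method goes past the default limit of 3.
struct D <: S end
foo(::D) = true

# Same query again; the new method invalidates the earlier inference result.
four = only(Base.return_types(bar, (Vector{S},)))

println("three methods: ", three)
println("four methods:  ", four)
```

The two printed types should correspond to the Body::Bool and Body::Any lines in the @code_warntype output above.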

In your example, you’re dynamically choosing between values of different types and calling a single function on them. However, your post talks about choosing between different functions. The possible resolutions may look different in each case, so can you explain a bit more what your actual use case looks like?

That said, yes, if you’re dynamically choosing between different functions with the same signature and return type, FunctionWrappers.jl or FunctionWrappersWrappers.jl is the way to go, unless you’re ready for a slightly bigger refactoring that moves this branching out of the type domain entirely. (I’d probably recommend the latter if you’re up for it, but it depends a bit on what your situation actually looks like.)

To elaborate, I’m talking about changing this:

function pickandrun(functions, x)
    f = rand(functions)
    return f(x)
end

pickandrun([foo, bar, baz], 2.0)

to something like this:

function pickandrun(fnames, x)
    fname = rand(fnames)
    return run(fname, x)
end

function run(fname, x)
    if fname == :foo
        return foo(x)
    elseif fname == :bar
        return bar(x)
    elseif fname == :baz
        return baz(x)
    else
        error("No function $fname")
    end
end

pickandrun([:foo, :bar, :baz], 2.0)

It’s a bit more verbose and perhaps less elegant, depending on your taste, but it’s free of type stability pitfalls and doesn’t rely on the dirty tricks used in FunctionWrappers.jl.

So, first I tried having a list of structs containing functions:

struct S
    run::Function
end

function bar(l::Vector{S})
    rand(l).run()
end

before refactoring to this type-based dispatch. I did think about writing one big if statement. It feels a little immoral, but it sounds like that’s one of the better options.

In other languages there are tools for specifying types like this, e.g. traits in Rust. I wonder if we could add this to Julia. For instance, we could have an abstract method that has to be implemented for every subtype:

abstract function foo(::S)::Bool end
foo(::A) = true
foo(::B) = 1 # ERROR

If you use

@code_warntype bar(Union{A,B}[A(), B()])

this is type stable in both your cases, so the main question is: do you know the types you will apply the function to? Maybe not, because you defined D afterwards, but I wanted to clarify.

edit: ah, @code_warntype actually returns nothing for this because the call errors. I guess you also need to change the signatures to Vector{Union{...}}.

Maybe what you really want is something like the following?

foo(s::S) = foo_inner(s)::Bool  # type assertion
foo_inner(s::A) = true
foo_inner(s::B) = false
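
As a rough check that the assertion does what is intended once there are more than three methods, here is a sketch that extends the snippet above to the four subtypes from the original post. The methods for C and D and the variable rt are my additions for illustration.

```julia
abstract type S end
struct A <: S end
struct B <: S end
struct C <: S end
struct D <: S end

# Four methods, so a call on an abstract S is past the default limit of 3.
foo_inner(::A) = true
foo_inner(::B) = false
foo_inner(::C) = true
foo_inner(::D) = false

# Single entry point. The assertion tells inference the result is a Bool
# and throws a TypeError at runtime if a method returns anything else.
foo(s::S) = foo_inner(s)::Bool

bar(l::Vector{S}) = foo(rand(l))

rt = only(Base.return_types(bar, (Vector{S},)))
println(rt)
println(bar(S[A(), B(), C(), D()]) isa Bool)
```

Note that this only repairs the inferred return type. The call to foo_inner is still resolved by dynamic dispatch at runtime.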

One key point here is that your code must necessarily branch at runtime. This isn’t something that can be resolved at compile time. That’s why writing out the big if may be one of the better options: you’re explicitly putting the decision where it belongs, rather than relying on compiler optimizations to turn a dynamic method table lookup (slow and breaks downstream type inference) into the same big if-elseif-else you would have written yourself (fast and type inference friendly).
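
For the type-based version in the original post, the explicit branch could look roughly like the sketch below. The name foo_branch is hypothetical, and the list of subtypes has to be kept in sync by hand, which is the price of this approach.

```julia
abstract type S end
struct A <: S end
struct B <: S end
struct C <: S end
struct D <: S end

foo(::A) = true
foo(::B) = false
foo(::C) = true
foo(::D) = false

# Explicit runtime branch. Inside each arm the concrete type of s is known,
# so each call to foo can be resolved at compile time.
function foo_branch(s::S)
    if s isa A
        return foo(s)
    elseif s isa B
        return foo(s)
    elseif s isa C
        return foo(s)
    elseif s isa D
        return foo(s)
    else
        error("No foo branch for $(typeof(s))")
    end
end

bar(l::Vector{S}) = foo_branch(rand(l))

rt = only(Base.return_types(bar, (Vector{S},)))
println(rt)
```

Since foo_branch has a single method, the number of foo methods no longer matters for inferring bar.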

Another trick is to set

Base.Experimental.@max_methods 4

or more specifically (this is max 4)

Base.Experimental.@max_methods 4 function foo end

but this is an unexported, experimental macro, so it may change or disappear between Julia releases.
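
A sketch of how the per-function form might be applied to the original example, assuming a Julia version that ships this macro. The forward declaration comes before the methods are defined; rt is a name I made up, and I have not checked the result across Julia versions.

```julia
abstract type S end
struct A <: S end
struct B <: S end
struct C <: S end
struct D <: S end

# Forward-declare foo with a raised per-function limit (experimental API).
Base.Experimental.@max_methods 4 function foo end

foo(::A) = true
foo(::B) = true
foo(::C) = true
foo(::D) = true

bar(l::Vector{S}) = foo(rand(l))

# Query the inferred return type with four foo methods defined.
rt = only(Base.return_types(bar, (Vector{S},)))
println(rt)
```

A fifth method would put you back where you started, so this only moves the cliff rather than removing it.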

I don’t know if raising max_methods has any benefit over the type-asserting wrapper above.

  1. The proposed abstract function foo(::S)::Bool is specifying that methods more specific than foo(::S) must output Bool. Specificity is a much harder condition to control or detect than subtyping, because the types in multiple argument positions (a Union of types) don’t guarantee a neat specificity hierarchy even if the individual non-Union types form a neat subtyping hierarchy.
  2. In general, it won’t be feasible for a method definition to throw this error, because methods don’t generally have fixed return types, or even return types related to the input types via parameters. You could take a page from static typing and force more specific methods to annotate ::Bool or whatever matches the pattern, but then you run into the issue with (1) again.

Return type restriction is a feature of statically typed, single-dispatched virtual functions for good reason.

EDIT: I’d probably go with danielwe’s neater approach because you do have the one named function with runtime-varying input types. I was thinking of the other way around: fixing a return type in a 1-method wrapper (recall that 1 <= max_methods) for arbitrary callables over fixed input types. Unlike the similar FunctionWrappers package, improving return type inference this way doesn’t elide the heap allocations associated with runtime dispatch, which is what I made a separate topic to ask about.