Multiple dispatch with more than 4 methods blocks type inference

Consider the following code:

using InteractiveUtils

struct A end
struct B end

g(r::A, s::A, t::A, u::A) = 1
g(r::B, s::A, t::A, u::A) = 2
g(r::B, s::A, t::B, u::A) = 3
g(r::B, s::B, t::A, u::A) = 4
g(r::B, s::B, t::B, u::A) = 5
g(r::B, s::B, t::B, u::B) = 6

g(r, s, t, u) = 0

function main()
    a = A()
    b = B()
    println()

    for (r, s, t, u) in Iterators.product((a, b), (a, b), (a, b), (a, b))
        res = g(r, s, t, u)
        println(res)
    end
end

main()
@code_warntype main()

As written, res in the main function is inferred as Any. However, when specializations of g are commented out so there are at most 4 methods left, res is inferred as Int64 since each implementation of g returns an Int64.

The issue is that similar code is part of a performance-critical section of some quantum chemistry stuff I have been working on, where res being inferred as Any is problematic.

Therefore my question is: Why does Julia not infer the return type of g when it has more than 4 methods? Can I do something about that? Multiple dispatch is a great feature of Julia, without it, I would have to write this dispatch by hand, but I think it is unfortunate that it cannot feasibly be used in performance-critical sections.

I believe the discussion in this post can answer your question:

the idea is that

won’t have 4+ methods to look up, or at least the trade-off is the same at that point.


Ahh, faster than me by less than a minute, @jling.


I cannot see how that discussion can answer this question. My understanding of that discussion is that using “runtime constructed types” will lead to a similar effect (although it is not clear to me why, and exactly where the limit is. Even the idea of runtime constructed type is strange, because - at least in my understanding - every parametrized type is constructed at runtime).

But there is no runtime generated type in this example, only two empty structs.

+ has 184 methods in a fresh session and it still can be inferred (edit: I have checked, and no, I was wrong, + works similarly, sorry for that)


It’s a deliberate choice, and it’s going to be 3 on Julia 1.6: set default max_methods to 3 by JeffBezanson · Pull Request #36208 · JuliaLang/julia · GitHub

I hope you realize that there are two inference failures: one to know which method to call, and one regarding the return type. You’re focusing on the second, but the origin of the second is the first: if Julia knew the exact types that each call would be made with, then this issue would never come up because the number of methods that could be called will always be exactly 1.
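To make the first point concrete, here is a minimal sketch that reuses the definitions from the original post. It only asks what Julia can work out when all four argument types are concrete; `Base.return_types` and `which` are standard reflection helpers.

```julia
struct A end
struct B end

g(r::A, s::A, t::A, u::A) = 1
g(r::B, s::A, t::A, u::A) = 2
g(r::B, s::A, t::B, u::A) = 3
g(r::B, s::B, t::A, u::A) = 4
g(r::B, s::B, t::B, u::A) = 5
g(r::B, s::B, t::B, u::B) = 6
g(r, s, t, u) = 0

# With concrete argument types, exactly one method is selected ...
println(which(g, (B, A, B, A)))

# ... and the return type follows from that single method.
println(Base.return_types(g, (B, A, B, A)))
```

So the number of methods of `g` only matters at call sites where the argument types are not known exactly.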

You can circumvent the downstream consequence of failure-to-infer the return type just by annotating it:

res = g(r, s, t, u)::Int
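As a rough check of what the annotation buys you, the sketch below wraps the call in two hypothetical helpers (`unannotated` and `annotated` are names made up for this example) and asks inference for their return types when nothing is known about the arguments. What the unannotated version reports depends on your Julia version, so it is only printed here.

```julia
struct A end
struct B end

g(r::A, s::A, t::A, u::A) = 1
g(r::B, s::A, t::A, u::A) = 2
g(r::B, s::A, t::B, u::A) = 3
g(r::B, s::B, t::A, u::A) = 4
g(r::B, s::B, t::B, u::A) = 5
g(r::B, s::B, t::B, u::B) = 6
g(r, s, t, u) = 0

# Hypothetical wrappers, used only to compare the two call styles.
unannotated(r, s, t, u) = g(r, s, t, u)
annotated(r, s, t, u) = g(r, s, t, u)::Int   # type assertion on the result

# Arguments typed as Any, i.e. nothing is known at the call site.
@show Base.return_types(unannotated, (Any, Any, Any, Any))
@show Base.return_types(annotated, (Any, Any, Any, Any))
```

The assertion is checked at run time, so it throws a `TypeError` if some method ever returns something other than an `Int`. The method lookup itself still happens at run time.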

But it’s harder to make the call inferrable. If the call really is in a performance-critical section, then

Multiple dispatch is a great feature of Julia, without it, I would have to write this dispatch by hand

you still might want to consider doing that. Method lookup is pretty slow, if it has to be done at runtime. You probably understand this, but if not, think about your loop: on different iterations, different methods will apply. You have a single site where you are calling one of several methods, and hence Julia has to pause, ask what types all the objects are, ask which method applies, and then call it. Since the number of possibilities is 2^4=16, that’s above the union-splitting limit so Julia won’t “unroll” this for you. But if you handle the dispatch manually you can force that. See https://julialang.zulipchat.com/#narrow/stream/225542-helpdesk/topic/Dispatching.20over.20an.20abstract.20iterable.3F/near/209624071 if you need a model.


https://github.com/JuliaCI/BaseBenchmarkReports/blob/0d8398eb91f2fe15b7845efc97b83de87b8fca08/00f1a7a_vs_8ca6e8d/report.md

I’m curious why this benchmark was read as a “good improvement”. I feel like it just says the results are about the same? (Of course the Plots.jl time was evident, but was that the decisive benchmark?)

I feel like it just says this is about the… of course the Plots.jl time was evident, but was that the decisive benchmark

Yes. This is really about compile-time performance/latency, so in this case the main role for BaseBenchmarks (which tests runtime performance) is to ask whether anything got a lot worse. If the answer is “no,” then the gains in compile time seem to be coming without major cost to the runtime performance, and so it’s a win.


Here’s one way of writing the dispatch by hand so that res is inferred correctly:

function main2()
    a = A()
    b = B()
    println()

    for (r, s, t, u) in Iterators.product((a, b), (a, b), (a, b), (a, b))
        res = if r isa A
            if s isa A 
                if t isa A
                    if u isa A 
                        g(r, s, t, u)
                    else 
                        g(r, s, t, u)
                    end 
                else 
                    if u isa A 
                        g(r, s, t, u)
                    else 
                        g(r, s, t, u)
                    end 
                end
            else # s isa B 
                if t isa A
                    if u isa A 
                        g(r, s, t, u)
                    else 
                        g(r, s, t, u)
                    end 
                else 
                    if u isa A 
                        g(r, s, t, u)
                    else 
                        g(r, s, t, u)
                    end 
                end
            end
        else # r isa B
            if s isa A 
                if t isa A
                    if u isa A 
                        g(r, s, t, u)
                    else 
                        g(r, s, t, u)
                    end 
                else 
                    if u isa A 
                        g(r, s, t, u)
                    else 
                        g(r, s, t, u)
                    end 
                end
            else # s isa B 
                if t isa A
                    if u isa A 
                        g(r, s, t, u)
                    else 
                        g(r, s, t, u)
                    end 
                else 
                    if u isa A 
                        g(r, s, t, u)
                    else 
                        g(r, s, t, u)
                    end 
                end
            end
        end
        println(res)
    end
end

The code does become quite repetitive and lengthy. But it seems it would be quite easy to write a macro that does it all. (I bet somebody has done that already.)


You can actually make the first version much shorter by not doing full manual dispatch. Since Julia automatically expands call sites with 3 or fewer possible methods, you just need to get every call site down to that. Thus you could ditch the outermost if/else for all of those branches.
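Here is a sketch of one way to read that suggestion: drop the outermost test on `r` and branch by hand on `s`, `t` and `u` only. `g_split` is a hypothetical name, and whether inference really resolves every call site this way depends on your Julia version, so it is worth confirming with `@code_warntype`.

```julia
struct A end
struct B end

g(r::A, s::A, t::A, u::A) = 1
g(r::B, s::A, t::A, u::A) = 2
g(r::B, s::A, t::B, u::A) = 3
g(r::B, s::B, t::A, u::A) = 4
g(r::B, s::B, t::B, u::A) = 5
g(r::B, s::B, t::B, u::B) = 6
g(r, s, t, u) = 0

# Hypothetical helper: branch by hand on s, t and u only. Inside each branch
# those three types are known, so only the choice of `r` is left to Julia.
function g_split(r, s, t, u)
    if s isa A
        if t isa A
            u isa A ? g(r, s, t, u) : g(r, s, t, u)
        else
            u isa A ? g(r, s, t, u) : g(r, s, t, u)
        end
    else
        if t isa A
            u isa A ? g(r, s, t, u) : g(r, s, t, u)
        else
            u isa A ? g(r, s, t, u) : g(r, s, t, u)
        end
    end
end

a, b = A(), B()
for (r, s, t, u) in Iterators.product((a, b), (a, b), (a, b), (a, b))
    println(g_split(r, s, t, u))
end
```

With the seven methods above, each of the eight call sites has at most two candidates once `s`, `t` and `u` are fixed (one specialization plus the fallback), which is under the limit.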


Just for the fun of it, and if you really wanted to strong-arm inference to unroll the iteration you can use Unrolled.jl. Unrolling is almost crucial here since you are calling a different method on almost every iteration (as Tim described).

You do need to re-engineer the function a bit (tuples are needed because their length is encoded in their type, unlike vectors):

using Unrolled

a, b = A(), B()
IP = Iterators.product((a, b), (a, b), (a, b), (a, b)) |> collect
IPt = tuple(IP...)  # a tuple carries its length and element types in its type

@unroll function main_unroll(IP)
    @unroll for i in IP
        res = g(i...)
    end
end
julia> @code_warntype main_unroll(IPt)
Variables
  #self#::Core.Compiler.Const(main_unroll, false)
  IP::Core.Compiler.Const(((A(), A(), A(), A()), (B(), A(), A(), A()), (A(), B(), A(), A()), (B(), B(), A(), A()), (A(), A(), B(), A()), (B(), A(), B(), A()), (A(), B(), B(), A()), (B(), B(), B(), A()), (A(), A(), A(), B()), (B(), A(), A(), B()), (A(), B(), A(), B()), (B(), B(), A(), B()), (A(), A(), B(), B()), (B(), A(), B(), B()), (A(), B(), B(), B()), (B(), B(), B(), B())), false)
  res@_3::Int64
  i@_4::NTuple{4,A}
  res@_5::Int64
  i@_6::Tuple{B,A,A,A}
  res@_7::Int64
  i@_8::Tuple{A,B,A,A}
  res@_9::Int64
  i@_10::Tuple{B,B,A,A}
  res@_11::Int64
  i@_12::Tuple{A,A,B,A}
  res@_13::Int64
  i@_14::Tuple{B,A,B,A}
  res@_15::Int64
  i@_16::Tuple{A,B,B,A}
  res@_17::Int64
  i@_18::Tuple{B,B,B,A}
  res@_19::Int64
  i@_20::Tuple{A,A,A,B}
  res@_21::Int64
  i@_22::Tuple{B,A,A,B}
  res@_23::Int64
  i@_24::Tuple{A,B,A,B}
  res@_25::Int64
  i@_26::Tuple{B,B,A,B}
  res@_27::Int64
  i@_28::Tuple{A,A,B,B}
  res@_29::Int64
  i@_30::Tuple{B,A,B,B}
  res@_31::Int64
  i@_32::Tuple{A,B,B,B}
  res@_33::Int64
  i@_34::NTuple{4,B}

Body::Nothing
1 ─       (i@_4 = Base.getindex(IP, 1))
│         (res@_3 = Core._apply_iterate(Base.iterate, Main.g, i@_4))
│         (i@_6 = Base.getindex(IP, 2))
│         (res@_5 = Core._apply_iterate(Base.iterate, Main.g, i@_6))
│         (i@_8 = Base.getindex(IP, 3))
│         (res@_7 = Core._apply_iterate(Base.iterate, Main.g, i@_8))
│         (i@_10 = Base.getindex(IP, 4))
│         (res@_9 = Core._apply_iterate(Base.iterate, Main.g, i@_10))
│         (i@_12 = Base.getindex(IP, 5))
│         (res@_11 = Core._apply_iterate(Base.iterate, Main.g, i@_12))
│         (i@_14 = Base.getindex(IP, 6))
│         (res@_13 = Core._apply_iterate(Base.iterate, Main.g, i@_14))
│         (i@_16 = Base.getindex(IP, 7))
│         (res@_15 = Core._apply_iterate(Base.iterate, Main.g, i@_16))
│         (i@_18 = Base.getindex(IP, 8))
│         (res@_17 = Core._apply_iterate(Base.iterate, Main.g, i@_18))
│         (i@_20 = Base.getindex(IP, 9))
│         (res@_19 = Core._apply_iterate(Base.iterate, Main.g, i@_20))
│         (i@_22 = Base.getindex(IP, 10))
│         (res@_21 = Core._apply_iterate(Base.iterate, Main.g, i@_22))
│         (i@_24 = Base.getindex(IP, 11))
│         (res@_23 = Core._apply_iterate(Base.iterate, Main.g, i@_24))
│         (i@_26 = Base.getindex(IP, 12))
│         (res@_25 = Core._apply_iterate(Base.iterate, Main.g, i@_26))
│         (i@_28 = Base.getindex(IP, 13))
│         (res@_27 = Core._apply_iterate(Base.iterate, Main.g, i@_28))
│         (i@_30 = Base.getindex(IP, 14))
│         (res@_29 = Core._apply_iterate(Base.iterate, Main.g, i@_30))
│         (i@_32 = Base.getindex(IP, 15))
│         (res@_31 = Core._apply_iterate(Base.iterate, Main.g, i@_32))
│         (i@_34 = Base.getindex(IP, 16))
│         (res@_33 = Core._apply_iterate(Base.iterate, Main.g, i@_34))
│   %33 = Main.nothing::Core.Compiler.Const(nothing, false)
└──       return %33

And you see that it does infer res and the calls. Of course, this puts all the effort into compilation time. If you benchmark, you can see that the computation was mostly done at compilation time (not runtime). General performance sensitive code does not normally follow this pattern, hence the “default” choices made in the language. But as you see, there are constructs that can remediate corner cases.

Hope this helps…


Unfortunately, this is not applicable to the actual code (in contrast to the MWE I posted above), since the number of iterations and the order in which the methods will be called are not known at compile time.

Thank you for your suggestion, dispatching between 7 methods by hand is manageable, but I was also looking to increase the number of types to choose from to 3 which would increase the number of methods to 23. That’s why I had the idea of relying on Julia’s multiple dispatch :smiley:

I don’t know if it helps in your case but using a function barrier may also work:

function main()
  IPt=calculate_IP_tuple_at_runtime()
  main_unroll(IPt)
end

just in case…
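For reference, a self-contained sketch of the function-barrier idea using the types from the original post. `pick`, `kernel` and `outer` are hypothetical names. It assumes the same combination of types is reused for many iterations, which may not match the real code.

```julia
struct A end
struct B end

g(r::A, s::A, t::A, u::A) = 1
g(r::B, s::A, t::A, u::A) = 2
g(r::B, s::A, t::B, u::A) = 3
g(r::B, s::B, t::A, u::A) = 4
g(r::B, s::B, t::B, u::A) = 5
g(r::B, s::B, t::B, u::B) = 6
g(r, s, t, u) = 0

# Hypothetical helper: pick a type from a run-time flag, so the caller
# cannot know the concrete type ahead of time.
pick(flag::Bool) = flag ? A() : B()

# Kernel behind the barrier: compiled once per concrete (r, s, t, u)
# combination, so the call to `g` inside the loop is resolved statically.
function kernel(r, s, t, u, n)
    total = 0
    for _ in 1:n
        total += g(r, s, t, u)
    end
    return total
end

# Caller: the four types are only known at run time, so the call to `kernel`
# is dispatched dynamically, but only once rather than on every iteration.
function outer(flags, n)
    r, s, t, u = pick.(flags)
    return kernel(r, s, t, u, n)
end

println(outer((true, true, true, true), 10))      # prints 10
println(outer((false, false, false, false), 3))   # prints 18
```

If the types change on every iteration, the barrier alone does not remove the per-iteration lookup.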


Well thank you all for your replies, I am really surprised to see this post getting so many well thought-out replies :smiley:

This is a MWE, in the (more complicated) real-world code I have my code wrapped in a main function.

I am aware that the runtime method lookup does not come for free, but getting the type inferred correctly would be a start. In my real-world code, I was seeing tons of memory allocations and up to 75 % GC time, which I believe is caused by res being inferred as Any.

Thank you, this post and the blog entry linked therein certainly helped me to understand the issue.

Now I wonder if there is any way to have runtime polymorphism in Julia without hurting performance too much. In C++, this is typically achieved through virtual functions. While these are less powerful than multiple dispatch, carry some overhead compared to ‘regular’ functions, and prevent inlining, they are certainly better than having no polymorphism at all unless you are willing to destroy your performance.

You can achieve the same results as virtual functions using FunctionWrappers.jl. I hacked together a simple interfaces package using it here: https://github.com/rdeits/ConcreteInterfaces.jl

But note that I’ve often gone down this road but never actually ended up using that package in any serious code. Function barriers and https://github.com/tkoolen/TypeSortedCollections.jl have turned out to be more practically useful in my work.


I would profile your real-world case. If the “hot spots” come after that call (and if they are red, i.e., runtime dispatch themselves), then annotating the return type might fix your problem and let you use multiple dispatch to your heart’s content. In contrast, if the hot spot is the call that produces res, then reducing the reliance on runtime dispatch will be your best strategy.
