Dispatch on type tuples at runtime?

struct MyStruct end  # placeholder type so the example runs

getvalue(::typeof(sin), ::Float64) = 1
getvalue(::typeof(sin), ::Number) = 2
getvalue(::MyStruct, args...) = 3

Julia has an awesome type system that lets us very flexibly dispatch on function signatures - all the following examples will match at least one method of getvalue() from above:

getvalue(sin, 2.0)   # ==> 1
getvalue(sin, 2)     # ==> 2
getvalue(MyStruct(), "a", "b", "c")  # ==> 3

This works well if I know all input type tuples in advance. But what if I only get them at runtime? Of course, I can dynamically add new methods to getvalue, but they will only be available in the next world age. Is there a way to achieve the same behavior at runtime?

My current attempt is to build a trie of types and match them one by one, checking each type itself and all its supertypes. However, I keep running into issues and new use cases, so a more robust approach would be appreciated.

But is there a clear rule about what the function must do depending on the types given? It is not clear to me what you want to achieve in the end.

Let me give you some more context.

I’m working on a code tracer (in fact, I’m rewriting Yota’s tracer) which treats all functions as either primitives (meaning “stop tracing here”) or non-primitives (“trace through this function”).

Users of the tracer can mark a function as a primitive by providing its signature type tuple, e.g. (typeof(sin), Number). I then take the types of a call, e.g. (typeof(sin), Float64), and match them against each available tuple of types. If there’s a match, the function (e.g. sin(::Float64)) is a primitive; otherwise, it is a non-primitive.
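As a point of reference, the yes/no matching described above maps directly onto Julia’s own subtyping check: a call matches a signature when Tuple{...} of the call’s types is a subtype of Tuple{...} of the signature. A minimal sketch, assuming primitives are kept in a plain vector of type tuples (user_primitives and isprimitive are hypothetical names, not part of Yota):

```julia
# Hypothetical list of user-provided primitive signatures
user_primitives = [(typeof(sin), Number), (typeof(cos), Real)]

# A call is a primitive if its type tuple is a subtype of any entry;
# Tuple{...} turns a tuple of types into a tuple *type*, so <: can be used
isprimitive(sig::Tuple) = any(p -> Tuple{sig...} <: Tuple{p...}, user_primitives)

isprimitive((typeof(sin), Float64))     # true:  Float64 <: Number
isprimitive((typeof(cos), ComplexF64))  # false: ComplexF64 is not a Real
isprimitive((typeof(tan), Float64))     # false: no entry for tan
```

This only tells you whether some entry matches, not which one is the most specific, which starts to matter once entries map to different values.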


Couldn’t you directly use which to do the matching for you? (the function, not the macro)

julia> fun(x::Number) = 1
fun (generic function with 1 method)

julia> fun(y) = 2
fun (generic function with 2 methods)

julia> which(fun, (Float64,))
fun(x::Number) in Main at REPL[1]:1

Two thoughts:

  1. Base.invokelatest - we’ve discussed this.
  2. There’s nothing stopping you from doing your own dispatch within a method. You could create a dictionary with tuple types as keys for example.

Thanks, which() looks like a step forward, but not a complete solution yet: there’s a set of available primitives and possibly a wider set of available methods. If, for example, we have 2 functions:

fun(x::Number) = ...
fun(x::Real) = ...

Then which(fun, (Float64,)) will point to the 2nd one, while we may only have a primitive definition for the 1st one. Maybe there’s a way to get a sorted list of all matching methods for a type signature?
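A small sketch of the mismatch just described, assuming the user marked only fun(::Number) as a primitive (primitive_sigs is a hypothetical name). which() returns the method that dispatch would pick, whereas methods() with an argument-type tuple lists every method that may apply:

```julia
fun(x::Number) = 1
fun(x::Real) = 2

# Hypothetical: the user only marked fun(::Number) as a primitive
primitive_sigs = [Tuple{typeof(fun), Number}]

m = which(fun, (Float64,))   # the method dispatch would actually call
m.sig                        # Tuple{typeof(fun), Real}
m.sig in primitive_sigs      # false, even though Float64 <: Number

# methods() with argument types returns all applicable methods instead
length(methods(fun, (Float64,)))   # 2
```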

  1. Yeah, using Base.invokelatest() isn’t very user-friendly: the intended use is trace(f, args...; primitives=user_provided_primitives), and, honestly, I wouldn’t expect this call to change the world age. But more importantly, whenever we define new methods, we can’t un-define them, i.e. remove something from the list of primitives.
  2. This will only work with concrete types, but if you have arguments of types (typeof(sin), Float64) and in the dict you have only (typeof(sin), Number), the lookup will fail.

I am not sure if this will help you at all, but I am a big fan of functors, and maybe they can help you if you define a function type containing both the function and its type (here, 1 for primitive and 0 otherwise):

julia> struct FunType{T<:Function}
         f::T
         type::Int
       end

julia> (f::FunType)(x) = f.f(x)

julia> isprimitive(f::FunType) = f.type == 1
isprimitive (generic function with 1 method)

The user can define the functions and store them, together with their types, in an array:

julia> functions = FunType[ FunType(sin,1), FunType(cos,0) ]
2-element Vector{FunType}:
 FunType{typeof(sin)}(sin, 1)
 FunType{typeof(cos)}(cos, 0)

julia> functions[1](π)
1.2246467991473532e-16

julia> functions[2](π)
-1.0

julia> isprimitive.(functions)
2-element BitVector:
 1
 0


Could users be expected to provide the list of primitives as a function? (i.e. encoding the primitiveness in the methods list)

Something like:

f(x,         y) = 1
f(x::Number, y) = 2
f(x::Real,   y) = 3

function trace(fun, args...; primitives)
    args_str = join(repr.(args), ", ")
    # primitives(...) returns true when the call must NOT be traced
    if primitives(fun, typeof.(args)...)
        println("Not tracing $fun($args_str) because it is a primitive")
    else
        println("Tracing $fun($args_str)")
    end
end

A first tracing call, in which method #2 (f(x::Number, y)) is considered a primitive:

julia> primitives1(args...) = false;
julia> primitives1(::typeof(f), ::Type{<:Number}, ::Type) = true
primitives1 (generic function with 2 methods)

julia> trace(f, 1, 2; primitives=primitives1)
Not tracing f(1, 2) because it is a primitive

julia> trace(f, "foo", 2; primitives=primitives1)
Tracing f("foo", 2)

And a second one in which method #3 is primitive instead:

julia> primitives2(args...) = false;
julia> primitives2(::typeof(f), ::Type{<:Real}, ::Type) = true
primitives2 (generic function with 2 methods)

julia> trace(f, im, 2; primitives=primitives2)
Tracing f(im, 2)

julia> trace(f, "foo", 2; primitives=primitives2)
Tracing f("foo", 2)

Can this be extended to include both the type of the function and the types of its arguments? The function type itself isn’t really a problem, since I can store it e.g. in a dict; the problem is matching the types of the function arguments according to Julia’s method resolution rules.

I’d like to avoid it because the functions would need to be defined at compile time, i.e. before runtime. But you have my respect: your solution is very similar to the one in Mjolnir.jl!

You can do something like:

julia> struct MyFunc{F,T}
         f::F
         t::DataType
         MyFunc(f,t) = new{typeof(f),t}(f,t)
       end

julia> (f::MyFunc{F,T})(x::T) where {F,T} = f.f(x)

julia> mysin = MyFunc(sin,Float64)
MyFunc{typeof(sin), Float64}(sin, Float64)

julia> mysin.t # retrieve argument type
Float64

julia> mysin(Float64(π)) # only works for Float64
1.2246467991473532e-16

julia> mysin(1)
ERROR: MethodError: no method matching (::MyFunc{typeof(sin), Float64})(::Int64)

@dfdx I edited this answer so that its syntax is correct. I don’t know if it helps at all, but at least it works as it is.

With this, you can dispatch directly on the type of the defined function:

julia> getvalue(f::MyFunc{typeof(sin),Float64}) = 1
getvalue (generic function with 1 method)

julia> getvalue(f::MyFunc{typeof(sin),Int}) = 2
getvalue (generic function with 2 methods)

julia> mysin1 = MyFunc(sin,Float64);

julia> mysin2 = MyFunc(sin,Int);

julia> getvalue(mysin1)
1

julia> getvalue(mysin2)
2


Well, I’d say Julia makes the line between compile time and runtime quite blurry:

function trace(fun, args...; primitives)
    result = fun(args...)

    args_str = join(repr.(args), ", ")
    print("$fun($args_str) = $result  --  ")
    if primitives(fun, typeof.(args)...)
        println("Not tracing because it is a primitive.")
    else
        println("Tracing call.")
    end
end

function defprimitives(list)
    m = Module()

    for (fun, argtypes...) in list
        args = [:(::Type{<:$t}) for t in argtypes]
        @eval m is_primitive(::$(typeof(fun)), $(args...)) = true
    end

    @eval m is_primitive(args...) = false
end

Building anonymous functions like this (or named functions in anonymous modules) can be done at runtime:

julia> f(x)         = 0 ;
julia> f(x::Number) = 1 ;
julia> f(x::Real)   = 2
f (generic function with 3 methods)

julia> for _ in 1:10
           typ = rand((String, Number))
           print("typ=$typ: ")
           trace(f, 1; primitives=defprimitives([(f, typ)]))
       end
typ=String: f(1) = 2  --  Tracing call.
typ=String: f(1) = 2  --  Tracing call.
typ=Number: f(1) = 2  --  Not tracing because it is a primitive.
typ=Number: f(1) = 2  --  Not tracing because it is a primitive.
typ=String: f(1) = 2  --  Tracing call.
typ=Number: f(1) = 2  --  Not tracing because it is a primitive.
typ=Number: f(1) = 2  --  Not tracing because it is a primitive.
typ=Number: f(1) = 2  --  Not tracing because it is a primitive.
typ=Number: f(1) = 2  --  Not tracing because it is a primitive.
typ=String: f(1) = 2  --  Tracing call.

But maybe I missed something? Or would you like to avoid the extra compilation introduced at runtime by such techniques?

  1. Maybe you could combine Base.invokelatest with @lmiq’s FunType for an easier interface.
  2. If lookup fails, you could iterate over the keys in the Dict and check:
julia> Tuple{typeof(sin),Float64} <: Tuple{typeof(sin),Number}
true
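A minimal sketch of that Dict-with-fallback idea, assuming signatures are stored as tuples of types (table and lookup_sig are hypothetical names): try an exact lookup first, then scan the keys with <:.

```julia
# Hypothetical table from a signature (tuple of types) to a value
table = Dict{Any,Symbol}(
    (typeof(sin), Float64) => :Float64,
    (typeof(sin), Number)  => :Number,
)

function lookup_sig(table::Dict, sig::Tuple)
    # fast path: exact match on the concrete type tuple
    haskey(table, sig) && return table[sig]
    # slow path: return the value of the first key the call is a subtype of
    T = Tuple{sig...}
    for (key, val) in table
        T <: Tuple{key...} && return val
    end
    return nothing
end

lookup_sig(table, (typeof(sin), Float64))  # :Float64 (exact hit)
lookup_sig(table, (typeof(sin), Int))      # :Number  (found by scanning)
lookup_sig(table, (typeof(cos), Int))      # nothing
```

Note that if more than one key matches during the scan, which value comes back depends on the Dict’s iteration order rather than on specificity.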

There are tricky cases like subtyping, multiple arguments, varargs, etc., but in general the idea of encoding the type tuple into a struct’s type parameters and the value into a field looks promising. I’ll play around with it, thanks!

The problem with compilation at runtime is that it increases the world age. You don’t see it in your example because the REPL always runs in the latest world age, but if you try to compile and use a new function in the same compilation unit, it will fail:

julia> function foo()
           @eval bar() = println(42)
           bar()
       end
foo (generic function with 1 method)

julia> foo()
ERROR: MethodError: no method matching bar()
The applicable method may be too new: running in world age 29617, while current world is 29618.
Closest candidates are:
  bar() at REPL[3]:2 (method too new to be called from this world context.)
Stacktrace:
 [1] foo()
   @ Main ./REPL[3]:3
 [2] top-level scope
   @ REPL[4]:1

Base.invokelatest() can be used to invoke methods from the future, but it often leads to unexpected results which you (or even worse - your users) learn about at the late stages of development.
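For reference, a minimal sketch of how Base.invokelatest gets around the error above. Declaring the generic function up front (function bar end) is an extra step not in the original example; it makes sure the name bar already exists when foo is compiled, so only the method is new:

```julia
# Create the (still empty) generic function up front so the binding `bar` exists
function bar end

function foo()
    @eval bar() = 42                # adds a method in a newer world age
    return Base.invokelatest(bar)   # dispatches in the latest world, so the method is found
end

foo()  # 42
```

The catch is that every call that may hit a freshly defined method has to go through invokelatest, and that call is always dynamically dispatched.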


The problem here is the order of methods to be checked. Imagine our dict looks like this:

(typeof(sin), Number) => 1
(typeof(sin), Real) => 2

Signature (typeof(sin), Float64) will be matched to whichever type tuple we encounter first, not to the most specific one. Is there any way to sort type signatures in the order of method resolution?
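One way to sidestep a global ordering is to collect all matching entries and keep the one that no other match is strictly more specific than. A minimal sketch, assuming a Dict keyed by tuples of types as in the question (most_specific is a hypothetical helper; it uses plain <: rather than Julia’s full method-specificity rules, and reports incomparable best matches as ambiguous instead of resolving them):

```julia
# Value of the most specific key matching `sig`; `nothing` if no key matches;
# throws if the best matches are mutually incomparable.
function most_specific(table::Dict, sig::Tuple)
    T = Tuple{sig...}
    matches = [Tuple{k...} => v for (k, v) in table if T <: Tuple{k...}]
    isempty(matches) && return nothing
    # a is strictly more specific than b: a subtype, but not an equivalent type
    stricter(a, b) = a <: b && !(b <: a)
    best = [m for m in matches if !any(o -> stricter(o.first, m.first), matches)]
    length(best) == 1 || error("ambiguous match for $T: $(first.(best))")
    return best[1].second
end

table = Dict{Any,Int}((typeof(sin), Number) => 1, (typeof(sin), Real) => 2)
most_specific(table, (typeof(sin), Float64))     # 2: Real is more specific than Number
most_specific(table, (typeof(sin), ComplexF64))  # 1: only Number matches
```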


Let me take a minute to express gratitude to all the commentators - even if the proposed solutions don’t fully match my criteria, they still bring a lot of ideas and move me further. Thank you!


Maybe there’s a way to get a sorted list of all matching methods for a type signature?

Something like:

julia> methods(+, Tuple{Int, Any})
# 14 methods for generic function "+":
[1] +(x::T, y::T) where T<:Union{Int128, Int16, Int32, Int64, Int8, UInt128, UInt16, UInt32, UInt64, UInt8} in Base at int.jl:87
[2] +(c::Union{Int16, Int32, Int64, Int8}, x::BigInt) in Base.GMP at gmp.jl:534
[3] +(c::Union{Int16, Int32, Int64, Int8}, x::BigFloat) in Base.MPFR at mpfr.jl:384
[4] +(x::Number, y::Base.TwicePrecision) in Base at twiceprecision.jl:271
[5] +(::Number, ::Missing) in Base at missing.jl:117
[6] +(x::Number, J::LinearAlgebra.UniformScaling) in LinearAlgebra at /Applications/Julia-1.6.app/Contents/Resources/julia/share/julia/stdlib/v1.6/LinearAlgebra/src/uniformscaling.jl:146
[7] +(x::Real, z::Complex{Bool}) in Base at complex.jl:300
[8] +(x::Real, z::Complex) in Base at complex.jl:312
[9] +(a::Integer, b::Integer) in Base at int.jl:919
[10] +(x::Integer, y::Ptr) in Base at pointer.jl:161
[11] +(y::Integer, x::Rational) in Base at rational.jl:295
[12] +(x::T, y::T) where T<:Number in Base at promotion.jl:396
[13] +(x::Number, y::Number) in Base at promotion.jl:321
[14] +(x::Integer, y::AbstractChar) in Base at char.jl:224

I’m not sure if this will be ordered. It seems ordered, but it’s easy enough to sort the methods:

signatures = [m.sig for m in methods(+, Tuple{Int, Any})]
sort(signatures, lt=(<:))  # lt, not by: <: compares two signatures

Of course, you can only get a partial order. You could turn this into a directed graph, then traverse that to find methods.

Some code to do that

Forgive the style, I’d been coding for maybe two years when I wrote this.

using LightGraphs
using Combinatorics

"""
Given a type graph, removes all redundant edges (e.g. a path between nodes
already exists.)
"""
function prune_tree(g)
  pruning_order = sortperm(map(length, g.fadjlist), rev=true)
  new_tree = DiGraph(nv(g))
  for node in pruning_order
    all_children = g.fadjlist[node]
    if isempty(all_children)
        new_children = copy(all_children)
    else
        new_children = setdiff(all_children, union(g.fadjlist[all_children]...))
    end
    for child in new_children
      add_edge!(new_tree, node, child)
    end
  end
  return new_tree
end

"""
Creates type lattice
Args:
    types: Array of types to make a graph out of
Returns a DiGraph whose edges denote subtype relationships.
"""
function type_graph(types::AbstractArray)
  idxmapping = IdDict()
  for (idx, i) in enumerate(types)
    idxmapping[i] = idx
  end
  typeedges = Iterators.filter(x->(x[1]!=x[2])&&(x[1]<:x[2]), permutations(types, 2)) |> collect
  edges = map(x->map(k->idxmapping[k]::Int, x), typeedges)
  g = DiGraph(length(types))
  map(x->add_edge!(g, x[2], x[1]), edges)
  g = prune_tree(g)
  return g
end
signatures = getfield.(methods(+).ms, :sig)
lattice = type_graph(signatures)


You could then insert nodes on this graph as you discover more types.


I think I was trying something a little like this way back in the day (I pulled the code for this from an old repo, but most of it won’t work anymore).


@ivirshup sorting methods with <: was a brilliant idea, thank you! Eventually I ended up with the following custom collection which supports type hierarchies as well as Vararg function signatures:

"""
Dict-like data structure which maps function signature to a value.
Unlike a real Dict, getindex(rsv, sig) returns the value for either the exact
match or the closest matching function signature (nothing if none matches). Example:

    rsv = FunctionResolver{Symbol}()
    rsv[(typeof(sin), Float64)] = :Float64
    rsv[(typeof(sin), Real)] = :Real
    rsv[(typeof(sin), Number)] = :Number
    order!(rsv)                      # Important: sort methods before usage

    rsv[(typeof(sin), Float64)]   # ==> :Float64
    rsv[(typeof(sin), Float32)]   # ==> :Real
"""
struct FunctionResolver{T}
    signatures::Dict{Any, Vector{Pair{DataType, T}}}
    FunctionResolver{T}() where T = new{T}(Dict())
end

function FunctionResolver{T}(pairs::AbstractVector{<:Pair{<:Tuple, T}}) where T
    rsv = FunctionResolver{T}()
    for (sig, val) in pairs
        rsv[sig] = val
    end
    order!(rsv)
    return rsv
end

Base.show(io::IO, rsv::FunctionResolver) = print(io, "FunctionResolver($(length(rsv.signatures)))")


function Base.setindex!(rsv::FunctionResolver{T}, val::T, sig::Tuple) where T
    fn_typ = sig[1]
    tuple_sig = Tuple{sig...}
    if !haskey(rsv.signatures, fn_typ)
        rsv.signatures[fn_typ] = Pair[]
    end
    push!(rsv.signatures[fn_typ], tuple_sig => val)
    return val
end

function Base.getindex(rsv::FunctionResolver{T}, sig::Tuple) where T
    fn_typ = sig[1]
    if haskey(rsv.signatures, fn_typ)
        tuple_sig = Tuple{sig...}
        for (TT, val) in rsv.signatures[fn_typ]
            if tuple_sig <: TT
                return val
            end
        end
    end
    return nothing
end

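function order!(rsv::FunctionResolver)
    # <: is only a partial order, so a comparison sort (sort! with lt=<:) can
    # leave a signature after a less specific one when unrelated signatures
    # sit between them. Instead, repeatedly take a signature that has no
    # strictly more specific signature among the remaining ones.
    strictly_sub(a, b) = a <: b && !(b <: a)
    for (fn_typ, sigs) in rsv.signatures
        remaining = copy(sigs)
        empty!(sigs)
        while !isempty(remaining)
            i = findfirst(p -> !any(q -> strictly_sub(q[1], p[1]), remaining), remaining)
            push!(sigs, popat!(remaining, i))
        end
    end
    return rsv
end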

Base.haskey(rsv::FunctionResolver, sig::Tuple) = (rsv[sig] !== nothing)
Base.in(sig::Tuple, rsv::FunctionResolver{Bool}) = haskey(rsv, sig)
Base.empty!(rsv::FunctionResolver) = empty!(rsv.signatures)


using Test

function test_it()
    rsv = FunctionResolver{Symbol}()
    rsv[(typeof(sin), Vararg)] = :Vararg
    rsv[(typeof(sin), Float64)] = :Float64
    rsv[(typeof(sin), Real)] = :Real
    rsv[(typeof(sin), Number)] = :Number
    order!(rsv)

    @test rsv[(typeof(sin), Float64)] == :Float64
    @test rsv[(typeof(sin), Float32)] == :Real
    @test rsv[(typeof(sin), Float64, Float64)] == :Vararg

    # non-matching signature
    rsv[(typeof(cos), Number)] = :CosineNumber
    @test rsv[(typeof(cos), String)] === nothing
end