Dispatch based on symbol parameter

Is there any way to dispatch directly on a symbol function parameter? I have a function which takes a symbol and returns a different type depending on which symbol is passed into it. For this function to have a type-stable output (whether explicitly declared or inferred), it would need to dispatch based on the symbol parameter. I realize this can be done with Val, but I was hoping to avoid Val. I was hoping for something that looks like this:

f(t::NamedTuple{T,U}, s::Symbol) where {T,U} = t[s]
@code_warntype f((a=1,b=1.0), :a)

which returns a union. The undesirable but type stable Val version looks like this:

f(t::NamedTuple{T,U}, ::Val{V}) where {T,U,V} = t[V]
@code_warntype f((a=1, b=1.0), Val(:a))
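A minimal sketch comparing the two approaches with `Base.return_types`, which reports what inference concludes from the argument types alone (roughly what `@code_warntype` shows). The names `f_sym` and `f_val` are hypothetical, chosen so both methods can sit side by side.

```julia
# Hypothetical names: f_sym takes the field name as a run-time Symbol,
# f_val takes it as a type parameter via Val.
f_sym(t::NamedTuple, s::Symbol) = t[s]
f_val(t::NamedTuple, ::Val{S}) where {S} = t[S]

nt = (a = 1, b = 1.0)

# Infer from argument types only; no constant values are known here.
sym_rt = only(Base.return_types(f_sym, (typeof(nt), Symbol)))
val_rt = only(Base.return_types(f_val, (typeof(nt), Val{:a})))

println(sym_rt)  # not Int64: s could name either field
println(val_rt)  # Int64, since :a is encoded in the type Val{:a}
```

The price of the Val version is that callers must construct `Val(:a)` themselves; if they build it from a symbol that is not a compile-time constant, the instability simply moves to the call site.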

Is this what you intend?

function dispatch(t::NamedTuple{(:a, :b)}, s::Symbol)
    t[s]
end

nt = (a = 1, b = 2)
dispatch(nt, :a) # 1
dispatch(nt, :b) # 2
dispatch(nt, :c) # error

No. The issue with that is that it returns a Union, which makes code that calls it significantly less efficient.

Are you sure that it is significantly less efficient?
For example,

nt = (a = 1, b = 2.0)
@code_warntype nt.a  # or getfield(nt, :a)

also shows the same Union{Int64,Float64}. My understanding is that this might get optimized away in real code, where the symbol is a literal constant, due to constant propagation.

Unions of two or three bitstypes are very performant, and they have been for a while. Unions of two concrete types are reasonable, too.
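A small sketch of why a two-member Union is usually cheap: code that consumes it is compiled with one branch per member (union splitting), so the result can still be concretely typed. Converting to Float64 is only an illustrative assumption, and `scaled` is a hypothetical name.

```julia
nt = (a = 1, b = 1.0)

# nt[s] is inferred as a Union of Int64 and Float64 because s is only known to be a Symbol.
# The compiler splits the Float64(...) call into an Int64 branch and a Float64 branch;
# both branches return Float64, so the whole function is inferred as Float64.
scaled(nt, s::Symbol) = Float64(nt[s]) * 2.0

rt = only(Base.return_types(scaled, (typeof(nt), Symbol)))
println(rt)  # Float64
```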

Well, I’m not sure because this is highly simplified sample code. In the actual code that I’m using, it ends up returning Any and so is significantly less efficient.

Returning Any is not a good approach. What is causing that?

A much larger named tuple.

How many distinct types of values occur in that named tuple?

You might use a few structs that hold values of several types rather than many different types.
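One possible reading of this suggestion, sketched under assumptions: group the values into a few homogeneous containers so that a lookup by Symbol inside each group has a single concrete type. `Params`, `getint` and `getfloat` are hypothetical names, not part of the original code.

```julia
# Hypothetical container: one NamedTuple per value type instead of one
# large NamedTuple mixing many types.
struct Params{I<:NamedTuple, F<:NamedTuple}
    ints::I     # all values Int64
    floats::F   # all values Float64
end

# Even with a run-time Symbol, every field in a group has the same type,
# so inference returns a concrete type rather than a Union or Any.
getint(p::Params, s::Symbol) = p.ints[s]
getfloat(p::Params, s::Symbol) = p.floats[s]

p = Params((a = 1, e = 2), (b = 1.0, c = 2.5))

println(only(Base.return_types(getint, (typeof(p), Symbol))))    # Int64
println(only(Base.return_types(getfloat, (typeof(p), Symbol))))  # Float64
```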

I think the confusion here is how @code_warntype works.

@code_warntype looks at the function that is called and infers what happens based on the input types. That means it does not propagate the information that :a is a constant at compile time; it only assumes the input is some Symbol. This is why it shows the Union even for something like nt.a, which is type-stable in real code.

My brain is too slow to come up with a proper setting to demonstrate it, but essentially, wrap the example in a function and it is fine, like:

f(nt, s) = nt[s]
g(nt, x) = f(nt, :a) * x

@code_warntype g(nt, 2)
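A runnable version of this, checked with `Base.return_types` instead of reading `@code_warntype` output. It assumes `nt = (a = 1, b = 2.0)` from earlier in the thread, since the snippet relies on an `nt` defined elsewhere.

```julia
nt = (a = 1, b = 2.0)

f(nt, s) = nt[s]
g(nt, x) = f(nt, :a) * x   # :a is a literal, i.e. a compile-time constant

# f on its own: only the argument types are known, so the result is not concrete.
rt_f = only(Base.return_types(f, (typeof(nt), Symbol)))
# g: inference propagates the constant :a into f and narrows nt[:a] to Int64.
rt_g = only(Base.return_types(g, (typeof(nt), Int)))

println(rt_f)
println(rt_g)  # Int64
```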

Thanks for that. Unfortunately, in my real code, it is not even returning a Union but I can dig further into why that is.

It seems that when I go to 4 types, it returns Any:

f( t::NamedTuple{U,V}, s::Symbol ) where {U,V} = t[s]
@code_warntype f((a=1,b=1.0,c=1.0f0,d=0x1), :a)

I can use Val in the worst case, though I’ll have to put Val almost everywhere in my code. If that’s the only way, thanks to both of you for the help. I hadn’t known about constant propagation in Julia.

I’m very sorry, but you are still debugging your code with the same misconception about @code_warntype. The point is that @code_warntype basically reports which type comes out if you call f(::@NamedTuple{a::Int64, b::Float64, c::Float32, d::UInt8}, ::Symbol).

This operation is of course not type-stable, and once the Union contains too many types the compiler stops tracking it precisely. Using Val here fixes that issue, as f(::..., ::Val{:a}) carries the symbol in its type and makes the warning disappear.

However, in the setting when you want to use the function, there is the major advantage that you would write something like f(nt, :a), which means that Julia has more information whenever it comes across this line. It is exactly the same reason why nt.a is type-stable in the first place!

Consider this example:

nt = (a=1,b=1.0,c=1.0f0,d=0x1, e=2)

f(nt, s) = nt[s]

function something(nt)
      return f(nt, :a) + f(nt, :e)
end

@code_warntype something(nt)

As you can see, it now knows the correct return types for the calls to f:

1 ─ %1 = Main.f(nt, :a)::Int64
│   %2 = Main.f(nt, :e)::Int64

That is the situation you will be in when actually using the function.
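To see the contrast directly, a sketch using `Base.return_types` on two hypothetical callers: one passes literal symbols, as `something` does, and the other forwards a Symbol that is only known at run time.

```julia
nt = (a = 1, b = 1.0, c = 1.0f0, d = 0x1, e = 2)

f(nt, s) = nt[s]

# Literal symbols at the call site: constants propagate into f.
with_literals(nt) = f(nt, :a) + f(nt, :e)
# The symbol is a run-time value: inference only knows s::Symbol.
with_runtime(nt, s::Symbol) = f(nt, s)

println(only(Base.return_types(with_literals, (typeof(nt),))))        # Int64
println(only(Base.return_types(with_runtime, (typeof(nt), Symbol))))  # a wide type, not Int64
```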

Of course, something like

map(s -> nt[s], (:a, :e))

is type-unstable and would only be type-stable if you use Val types. But that is hopefully not the situation you are in…
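A sketch of the Val-based alternative when several fields are selected in a loop. The helper names are hypothetical, and the tuple of symbols or Val instances is passed in as an argument so that its contents are not compile-time constants inside the function.

```julia
nt = (a = 1, b = 1.0, c = 1.0f0, d = 0x1, e = 2)

getv(nt, ::Val{S}) where {S} = nt[S]

# Symbols as values: inside the closure, s is just some Symbol.
pick_syms(nt, syms) = map(s -> nt[s], syms)
# Val instances: each tuple element's type names the field it selects.
pick_vals(nt, vals) = map(v -> getv(nt, v), vals)

syms = (:a, :c)
vals = (Val(:a), Val(:c))

println(only(Base.return_types(pick_syms, (typeof(nt), typeof(syms)))))
println(only(Base.return_types(pick_vals, (typeof(nt), typeof(vals)))))  # Tuple{Int64, Float32}
```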


That’s interesting. Thanks. I had assumed that once Julia compiles a function for a given concrete type, it never recompiles or reoptimizes that function. I guess that’s a bad assumption. Anyway, in my actual use case, which is much more complex, it does not work out the types.

Note that constant-propagation is the only reason my_tuple[1] and my_struct.a (which call getindex and getproperty respectively) are type-stable. If the index or property name is not a constant in the code, then it will not be type-stable unless the content types form a narrow union.
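A minimal sketch of this point for tuples and struct fields. `Point`, `second`, `nth` and `get_a` are hypothetical names used only for illustration.

```julia
struct Point   # hypothetical example struct
    a::Int
    b::Float64
end

t = (1, 2.0)   # Tuple{Int64, Float64}

second(t) = t[2]        # literal index: a compile-time constant
nth(t, i::Int) = t[i]   # index only known at run time
get_a(p::Point) = p.a   # getproperty(p, :a) with a constant name

println(only(Base.return_types(second, (typeof(t),))))  # Float64
println(only(Base.return_types(nth, (typeof(t), Int))))
println(only(Base.return_types(get_a, (Point,))))        # Int64
```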

If constant propagation isn’t happening, you might nudge the compiler with the Base.@constprop macro. For example:

Base.@constprop :aggressive function my_func(a, b)
    #= ... =#
end
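A sketch applying the annotation to a Symbol lookup like the one in this thread; `lookup` and `caller` are hypothetical names. This function is simple enough that the compiler would probably propagate the constant even without the macro, so the example shows the syntax and a way to check the result rather than a case where the macro is strictly required.

```julia
# Requires Julia 1.7 or later, where Base.@constprop is available.
Base.@constprop :aggressive function lookup(nt::NamedTuple, s::Symbol)
    return nt[s]
end

# Caller with a literal symbol; inference can propagate :d into lookup.
caller(nt) = lookup(nt, :d) + 0x1

nt = (a = 1, b = 1.0, c = 1.0f0, d = 0x1)
println(only(Base.return_types(caller, (typeof(nt),))))  # UInt8
```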