Would it be possible to infer `Val(::Bool)` as `Union{Val{true}, Val{false}}`?

Consider the function

julia> f(x::Bool) = Val(x)
f (generic function with 1 method)

julia> @code_typed f(true)
CodeInfo(
1 ─ %1 = Core.apply_type(Base.Val, x)::Type{Val{_A}} where _A
│   %2 = %new(%1)::Val
└──      return %2
) => Val

However, we know that the return type can only be `Val{true}` or `Val{false}`. Splitting the result manually (or using ValSplit.jl) seems to achieve the intended result:

julia> g(x::Bool) = x ? Val(true) : Val(false)
g (generic function with 1 method)

julia> @code_typed g(true)
CodeInfo(
1 ─     goto #3 if not x
2 ─     return $(QuoteNode(Val{true}()))
3 ─     return $(QuoteNode(Val{false}()))
) => Union{Val{true}, Val{false}}

Since `Bool` is known to have only two values, I wonder whether such an inference would be worth it to avoid dynamic dispatch?
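As a sanity check, one way to see what inference concludes for each definition is `Base.return_types` (a sketch; the exact printing varies across Julia versions):

```julia
# Compare the inferred return types of the two definitions above.
f(x::Bool) = Val(x)                      # inferred as abstract Val
g(x::Bool) = x ? Val(true) : Val(false)  # inferred as a small Union

only(Base.return_types(f, (Bool,)))  # abstract Val
only(Base.return_types(g, (Bool,)))  # a Union of Val{true} and Val{false}
```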


It appears this would be as simple as defining a method

Val(x::Bool) = x ? Val{true}() : Val{false}()

This exceptional case is simple to define, and I wouldn’t expect negative repercussions elsewhere. But it would be the first specialization of `Val`, so it should get some debate.
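With such a method in place, callers become inferable. A minimal sketch using a standalone function (named `valbool` here, an illustrative name chosen to avoid piracy on `Base.Val` while experimenting):

```julia
# Standalone stand-in for the proposed Val(::Bool) method;
# the name `valbool` is illustrative, not part of any API.
valbool(x::Bool) = x ? Val{true}() : Val{false}()

h(x::Bool) = valbool(x)

# Inference now reports the two-element Union rather than abstract Val,
# so downstream calls can be union-split statically.
only(Base.return_types(h, (Bool,)))
```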

I was originally skeptical of what difference this might make, since Vals are mostly used for static branching (via dispatch) and branching with 2 outcomes can often be inferred well (don’t we do some sort of union-splitting on functions with few methods?). However, I made this benchmark and here are some (somewhat idealized, since the target functions are trivial) possible results of this proposal:

# Julia v1.8.0

ValBool(x::Bool) = x ? Val{true}() : Val{false}()
getvaluetypestable(::Val{true}) = 1
getvaluetypestable(::Val{false}) = 0
getvaluetypeunstable(::Val{true}) = Val{1}()
getvaluetypeunstable(::Val{false}) = Val{0}()

using BenchmarkTools

arg = 1:1000;
@btime map(_->getvaluetypestable(Val(true)),$arg); # fully inferred baseline
#  571.186 ns (1 allocation: 7.94 KiB)
@btime map(_->getvaluetypestable(Val(rand(Bool))),$arg);
#  158.000 μs (1 allocation: 7.94 KiB)
@btime map(_->getvaluetypeunstable(Val(rand(Bool))),$arg);
#  158.900 μs (2 allocations: 7.98 KiB)
@btime map(_->getvaluetypestable(ValBool(true)),$arg); # fully inferred baseline
#  574.713 ns (1 allocation: 7.94 KiB)
@btime map(_->getvaluetypestable(ValBool(rand(Bool))),$arg);
#  1.170 μs (1 allocation: 7.94 KiB)
@btime map(_->getvaluetypeunstable(ValBool(rand(Bool))),$arg);
#  1.940 μs (2 allocations: 7.98 KiB)

Looks like this could be a nice win for unstable Val(::Bool). Obviously, one should strive for static inference where possible, but in some cases it isn’t achievable.


Sorry for the screenshot, but I’m on my phone and couldn’t copy/paste from Tryjulia

So I figured I’d try and see if I could make this work with @enums. Short answer: yes… and no (see the edit at the bottom).

Code which creates specialized methods of Val for type-stable(ish) dynamic dispatch on enums:

# Build a nested if/elseif chain over all instances of the enum, so that
# each branch returns a concrete Val{instance}() and inference sees a
# small Union. (Note: adding a method to Val for Base types is type piracy.)
@generated Val(var"#"::T) where T<:Base.Enums.Enum = let ifex(enum)=:(if var"#"≡$enum; Val{$enum}() end),
        enums = instances(T), ret = ifex(enums[1]), expr = ret
    # chain each subsequent `if` into the previous one's else-branch
    for enum = enums[2:end];  expr = last(push!(expr.args, ifex(enum)))  end
    # the last instance needs no check: replace the innermost `if` with its body
    expr.head, expr.args = expr.args[end].head, expr.args[end].args
    ret
end
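To make the metaprogramming easier to follow, here is a hand-expanded sketch of what the generated body works out to for a three-member enum (the enum and function names are illustrative; the real method builds this nest programmatically):

```julia
@enum Fruit apple banana cherry

# Hand-written equivalent of the generated body for Fruit: a chain of
# identity checks where every branch returns a concrete Val, and the
# last instance needs no check at all.
valfruit(x::Fruit) =
    if x ≡ apple
        Val{apple}()
    elseif x ≡ banana
        Val{banana}()
    else
        Val{cherry}()
    end
```

Because every branch is concrete, inference sees `Union{Val{apple}, Val{banana}, Val{cherry}}` and can union-split at the call site instead of doing fully dynamic dispatch.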

Sample code which utilizes this:

julia> using BenchmarkTools
       @enum Foo a b c
       foo(x) = foo(Val(Foo(x)), x)
       foo(::Val{a}, x) = x+1
       foo(::Val{b}, x) = x
       foo(::Val{c}, x) = x-1
       println("Before:")
       @btime foo(rand(0:2))
       @generated Val(var"#"::T) where T<:Base.Enums.Enum = let ifex(enum)=:(if var"#"≡$enum; Val{$enum}() end),
               enums = instances(T), ret = ifex(enums[1]), expr = ret
           for enum = enums[2:end];  expr = last(push!(expr.args, ifex(enum)))  end
           expr.head, expr.args = expr.args[end].head, expr.args[end].args
           ret
       end
       println("After:")
       @btime foo(rand(0:2));
Before:
  80.372 ns (1 allocation: 16 bytes)
After:
  11.712 ns (0 allocations: 0 bytes)

Here’s a plot showing how the performance of this relates to how many elements are in the enumeration:

Code used to generate this:

Note: I had to close and re-open the Julia process each time.

using BenchmarkTools
begin local globcnt = 0
testenum(n) = let Foo=Symbol(:Foo, globcnt+=1), ids=ntuple(n->Symbol(:foo, globcnt+=1), n); quote
    @enum $Foo $(ids...)
    foo(x) = foo(Val($Foo(x)), x)
    $((:(foo(::Val{$(id)}, x) = x+1) for id in ids)...)
    println("Timing before/after specializing Val (n = $($n))")
    @btime foo(rand(0:$(length(ids)-1)))
    @generated Val(var"#"::T) where T<:Base.Enums.Enum = let ifex(enum)=:(if var"#"≡$enum; Val{$enum}() end),
            enums = instances(T), ret = ifex(enums[1]), expr = ret
        for enum = enums[2:end];  expr = last(push!(expr.args, ifex(enum)))  end
        expr.head, expr.args = expr.args[end].head, expr.args[end].args
        ret
    end
    @btime foo(rand(0:$(length(ids)-1)))
end end |> eval
end
testenum(4)

Edit:

Looks like I was benefiting from a lucky execution order: if you define the `Val` specialization before the `@enum` declaration, you run into world-age problems, since the generated function’s body cannot see methods (such as the `instances` method for the new enum type) that are defined in a later world.