Function specialization based on values, not just types

I’m working on Markov chains, where a lot of recursive functions have a parametric mathematical equation with a base case, simplified to:

function f(xs, 1)
    @info "Base case"
end

function f(xs, val)
    @info "General case val > 1"
end

To my knowledge, I can’t specialize on a value, only on a type. I want to avoid this:

function f(xs, x)
    if x <= 1
        @info "handle base case"
    else
        @info "handle generic case"
    end
end

While the above seems trivial, in practice the actual functions quickly become combinatorial, with base cases on multiple parameters, ranges, etc.

What I tried:

  • I can define an enum with the base-case values:
@enum basecase B=1

function f(xs, x::basecase)
    @info "basecase"
end

function f(xs, x) # Generic
    @info "Not base"
end

I think I can do @specialize with ranges, e.g. 1:1 would be the base case range, and 2:? would be the general case.

My key objective is to keep the code as clean and elegant as possible, and as close as possible to the mathematical function specifications, while leveraging the type system for safety (and performance).

For example, the enum approach fails because f(1, 1) will not call the enum specialization, and f(1, 0) is invalid but not caught.
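A quick (hypothetical) REPL session with the enum definitions above makes that concrete:

julia> f(1, B)   # works only when the caller explicitly passes the enum instance
[ Info: basecase

julia> f(1, 1)   # the Int 1 is not of type basecase, so it hits the generic method
[ Info: Not base

julia> f(1, 0)   # invalid value, silently handled by the generic method
[ Info: Not base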

This is related to C++’s constant-value template specialization (see “C++ Template Specialization with Constant Value” on Stack Overflow).

Any help/pointers/suggestions are much appreciated.

This is related to keyword specialization, and from reading I believe Julia already specializes in the background, but this would not allow me (I think?) to define methods by specific argument values (the base case)?


See Val, which brings a value into the type domain:

help?> Val
search: Val Real eval real

  Val(c)

  Return Val{c}(), which contains no run-time data. Types like this can be used to pass the information between functions through the value c, which must be an isbits value or a Symbol. The intent of this
  construct is to be able to dispatch on constants directly (at compile time) without having to test the value of the constant at run time.

  Examples
  ≡≡≡≡≡≡≡≡

  julia> f(::Val{true}) = "Good"
  f (generic function with 1 method)
  
  julia> f(::Val{false}) = "Bad"
  f (generic function with 2 methods)
  
  julia> f(Val(true))
  "Good"

Example:

julia> fib(::Val{0}) = 1
fib (generic function with 1 method)

julia> fib(::Val{1}) = 1
fib (generic function with 2 methods)

julia> fib(::Val{V}) where V = fib(Val(V-1)) + fib(Val(V-2))
fib (generic function with 3 methods)

julia> fib(n) = fib(Val(n))
fib (generic function with 4 methods)

julia> fib.(0:10)
11-element Vector{Int64}:
  1
  1
  2
  3
  5
  8
 13
 21
 34
 55
 89

Thanks!

Note however that if the value isn’t known at compile time, using Val(n) will be a run-time dispatch and will often perform worse than just using a branch within your function.
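A minimal sketch of that trade-off (hypothetical functions g, g_runtime, g_branch):

g(::Val{1}) = "base"
g(::Val{N}) where {N} = "general"

# If n is only known at run time, the type of Val(n) depends on the value,
# so this call is resolved by dynamic dispatch:
g_runtime(n::Int) = g(Val(n))

# A plain branch avoids that cost when n is not a compile-time constant:
g_branch(n::Int) = n == 1 ? "base" : "general"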


NB: there’s nothing special about Val, it’s just a struct with a single type parameter but no fields:
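# Roughly how Val is defined in Base:
struct Val{x} end

Val(x) = Val{x}()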

The special thing about Val is that this struct is meant for this purpose, so using it signals the intent of the programmer.


Sure. In the same way, it could make more sense to use my package TypeDomainNaturalNumbers.jl, rather than use Val.

Sure, or Static.jl (GitHub - SciML/Static.jl: static types useful for dispatch and generated functions).


@specialize is not relevant here. All it does is force specialization (on the type), but that’s the default anyway.

My package TypeDomainNaturalNumbers.jl provides exactly what you’re asking for, more so than Val or the Static.jl package. In particular, it provides types like PositiveInteger, or IntegerGreaterThanOne, so it’s possible to distinguish between the general case (1 < x) and the invalid case (x < 1) purely via dispatch. The example code above could look, for example, like this:

using TypeDomainNaturalNumbers

function f(xs, n::Integer)
    m = TypeDomainInteger(n)
    f(xs, m)
end

function f(xs, ::typeof(TypeDomainInteger(1)))
    @info "Base case"
end

function f(xs, ::IntegerGreaterThanOne)
    @info "General case val > 1"
end

# Instead of using `TypeDomainInteger` here, it's also possible to be more explicit and use `Union{typeof(TypeDomainInteger(0)), NegativeInteger}`
function f(xs, ::TypeDomainInteger)
    @info "Invalid case"
end

Results in:

julia> f(:unused, 0)
[ Info: Invalid case

julia> f(:unused, 1)
[ Info: Base case

julia> f(:unused, 2)
[ Info: General case val > 1

julia> @time f([], 10000)
[ Info: General case val > 1

271.034517 seconds (5.32 M allocations: 252.547 MiB, 0.02% gc time, 64.56% compilation time)

This just seems like such a footgun to me relative to stuff like Val.


That’s really cool, as it would get me closer to the equations; I’ll check it out!


I would have to try this, but in principle, if the cost is at compile time, I can pay for it with precompilation. Thanks for spotting this, though; I will test.


That’s good to know. I can force compile-time base cases and use precompile workloads; paying that cost at run time would be heavy.
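For example, something along these lines with PrecompileTools.jl might work (a hypothetical workload; it assumes f and the relevant values live in your package):

using PrecompileTools

@compile_workload begin
    # Exercise the values the Markov chain code actually uses, so the
    # type domain numbers and the corresponding methods get compiled
    # during package precompilation rather than at run time.
    for n in 1:20
        f([], n)
    end
end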

See also this section of the Julia documentation: Types with values-as-parameters.


Some notes:

  1. To clarify, what you show amounts to the cost of constructing a type domain number with a certain huge value for the first time. The intended way is for this cost to be paid at compilation time, not at run time.
  2. This takes just 1.3 seconds on nightly Julia for me, as opposed to your timing of 271 seconds. Some relevant PRs are already marked for backporting.
  3. This huge cost doesn’t have to be paid each time a different type domain number gets constructed; a lot is saved between such calls:
    julia> @time f([], 10000)
    [ Info: General case val > 1
      1.290434 seconds (4.68 M allocations: 224.334 MiB, 14.58% gc time, 90.77% compilation time: 34% of which was recompilation)
    
    julia> @time f([], 10001)
    [ Info: General case val > 1
      0.026583 seconds (9.70 k allocations: 502.594 KiB, 61.49% compilation time)
    
  4. To be clear, the design of the types in TypeDomainNaturalNumbers.jl is such that larger numbers have more complicated types, causing extra work for the compiler. This is a necessary trade-off for being able to represent number sets properly in the type system, so that the subsetting and subtyping relations match. So using huge numbers is not really intended usage for TypeDomainNaturalNumbers.jl. What constitutes huge will vary between versions of Julia; presumably it will keep improving together with the Julia compiler.

So, your package implements natural numbers in the type domain via the Peano construction.

I do understand why this construction is necessary in languages like Scala, Haskell, Idris, Coq, Lean, etc., where static typing plays the role of a proof system and you must stay within its carefully set bounds.

I do not understand why this construction is helpful in Julia, where you could simply do something like added(::Val{N}, ::Val{M}) where {N,M} = Val{N+M}(), and even have generated functions available, i.e. you are permitted to “shell out” in order to determine return types.
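For concreteness, a (hypothetical) REPL session with that Val-based approach:

julia> added(::Val{N}, ::Val{M}) where {N,M} = Val{N + M}()
added (generic function with 1 method)

julia> added(Val(2), Val(3))
Val{5}()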

Could you explain that design choice?


Have you checked out the thread here? That’s already discussed above, including in the message you’re replying to. In particular:

Additional advantages of the inductive design:

  • The uniqueness of the representation allows implementing == as simply ===. A Static.jl-like design must instead forward to == on the wrapped value, or, alternatively, assume that the wrapped value is of type Int. This ties in to other issues of a Static.jl-like design, stemming from the lack of type safety.
  • The original motivation for having a type domain implementation of the natural numbers with an inductive design was operating on heterogeneous collections of constant length, specifically sorting tuples. This now happens to be implemented in this Julia PR: “support sorting tuples” by nsajko, Pull Request #56425, JuliaLang/julia on GitHub. Try reimplementing it using a Static.jl-like design for the type domain integers, and you’ll see much worse compiler performance. Alternatively, try out my old package TupleSorting.jl. The issue is presumably that doing any operation on a Static.jl-like design requires moving values out of and then back into the type domain, which is certainly very taxing on the compiler when it happens at each step of recursion. In contrast, incrementing or decrementing a type domain natural number (with natural_successor, natural_predecessor) is a trivial operation (see the toy sketch below).
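A toy sketch of the inductive idea (illustrative only; these are not the package’s actual types):

struct TDZero end          # represents 0
struct TDSucc{P} end       # represents the successor of the number whose type is P

natural_successor(n) = TDSucc{typeof(n)}()
natural_predecessor(::TDSucc{P}) where {P} = P()

# 2 is TDSucc{TDSucc{TDZero}}(); incrementing or decrementing only adds or
# removes one layer of type nesting, so values never have to leave the type domain.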

Does this help?
