Compile-time type filtering from a tuple - is it possible?

Hi all,

I am developing a library where I would like to split some structs into, say, two categories based on their type; this is better shown as a minimal example:

abstract type Vowel end
abstract type Consonant end

struct A <: Vowel end
struct B <: Consonant end
struct C <: Consonant end

function sieve(letters...)
    vowels = tuple((s for s in letters if s isa Vowel)...)
    consonants = tuple((s for s in letters if s isa Consonant)...)

    vowels, consonants
end

sieve(A(), B(), C(), A()) |> typeof

#output
Tuple{Tuple{A, A}, Tuple{B, C}}

The function above works; if I call it with only consonants, it is type-stable and everything is done at compile-time:

julia> @code_warntype sieve(B(), C())
MethodInstance for sieve(::B, ::C)
  from sieve(letters...) @ Main REPL[6]:1
Arguments
  #self#::Core.Const(sieve)
  letters::Core.Const((B(), C()))
Locals
  #4::var"#4#6"
  #3::var"#3#5"
  consonants::Tuple{B, C}
  vowels::Tuple{}
Body::Tuple{Tuple{}, Tuple{B, C}}
1 ─ %1  = Main.tuple::Core.Const(tuple)
│         (#3 = %new(Main.:(var"#3#5")))
│   %3  = #3::Core.Const(var"#3#5"())
│   %4  = Base.Filter(%3, letters)::Core.Const(Base.Iterators.Filter{var"#3#5", Tuple{B, C}}(var"#3#5"(), (B(), C())))
│   %5  = Base.Generator(Base.identity, %4)::Core.Const(Base.Generator{Base.Iterators.Filter{var"#3#5", Tuple{B, C}}, typeof(identity)}(identity, Base.Iterators.Filter{var"#3#5", Tuple{B, C}}(var"#3#5"(), (B(), C()))))
│         (vowels = Core._apply_iterate(Base.iterate, %1, %5))
│   %7  = Main.tuple::Core.Const(tuple)
│         (#4 = %new(Main.:(var"#4#6")))
│   %9  = #4::Core.Const(var"#4#6"())
│   %10 = Base.Filter(%9, letters)::Core.Const(Base.Iterators.Filter{var"#4#6", Tuple{B, C}}(var"#4#6"(), (B(), C())))
│   %11 = Base.Generator(Base.identity, %10)::Core.Const(Base.Generator{Base.Iterators.Filter{var"#4#6", Tuple{B, C}}, typeof(identity)}(identity, Base.Iterators.Filter{var"#4#6", Tuple{B, C}}(var"#4#6"(), (B(), C()))))
│         (consonants = Core._apply_iterate(Base.iterate, %7, %11))
│   %13 = Core.tuple(vowels, consonants)::Core.Const(((), (B(), C())))
└──       return %13

The return type is Core.Const(((), (B(), C()))), which is quite impressive already! All the necessary information is available at compile time, since the function arguments are collected into the tuple letters::Core.Const((B(), C())). However, once a vowel is introduced, it all becomes type-unstable:

julia> @code_warntype sieve(A(), B(), C())
MethodInstance for sieve(::A, ::B, ::C)
  from sieve(letters...) @ Main REPL[6]:1
Arguments
  #self#::Core.Const(sieve)
  letters::Core.Const((A(), B(), C()))
Locals
  #4::var"#4#6"
  #3::var"#3#5"
  consonants::Tuple{Vararg{Union{A, B, C}}}
  vowels::Tuple{A, Vararg{Union{A, B, C}}}
Body::Tuple{Tuple{A, Vararg{Union{A, B, C}}}, Tuple{Vararg{Union{A, B, C}}}}
1 ─ %1  = Main.tuple::Core.Const(tuple)
│         (#3 = %new(Main.:(var"#3#5")))
│   %3  = #3::Core.Const(var"#3#5"())
│   %4  = Base.Filter(%3, letters)::Core.Const(Base.Iterators.Filter{var"#3#5", Tuple{A, B, C}}(var"#3#5"(), (A(), B(), C())))
│   %5  = Base.Generator(Base.identity, %4)::Core.Const(Base.Generator{Base.Iterators.Filter{var"#3#5", Tuple{A, B, C}}, typeof(identity)}(identity, Base.Iterators.Filter{var"#3#5", Tuple{A, B, C}}(var"#3#5"(), (A(), B(), C()))))
│         (vowels = Core._apply_iterate(Base.iterate, %1, %5))
│   %7  = Main.tuple::Core.Const(tuple)
│         (#4 = %new(Main.:(var"#4#6")))
│   %9  = #4::Core.Const(var"#4#6"())
│   %10 = Base.Filter(%9, letters)::Core.Const(Base.Iterators.Filter{var"#4#6", Tuple{A, B, C}}(var"#4#6"(), (A(), B(), C())))
│   %11 = Base.Generator(Base.identity, %10)::Core.Const(Base.Generator{Base.Iterators.Filter{var"#4#6", Tuple{A, B, C}}, typeof(identity)}(identity, Base.Iterators.Filter{var"#4#6", Tuple{A, B, C}}(var"#4#6"(), (A(), B(), C()))))
│         (consonants = Core._apply_iterate(Base.iterate, %7, %11))
│   %13 = Core.tuple(vowels, consonants)::Tuple{Tuple{A, Vararg{Union{A, B, C}}}, Tuple{Vararg{Union{A, B, C}}}}
└──       return %13

Now all the type filtering happens at runtime, and the inferred return type is the non-concrete Tuple{Tuple{A, Vararg{Union{A, B, C}}}, Tuple{Vararg{Union{A, B, C}}}}.

Is there any way to do this filtering at compile-time? In the actual application this would help a lot with type-stability around the function call.

Thanks,
Leonard

This seems to work as you request:

module Letters
  abstract type Letter end

  # Length zero
  filtered(::Type{L}, ::Tuple{}) where {L <: Letter} = ()

  # Length one
  filtered(::Type{L}, letters::T) where {L <: Letter, T <: Tuple{L}} = letters
  filtered(::Type{L}, letters::T) where {L <: Letter, T <: Tuple{Letter}} = ()

  # Length two or more
  filtered(
    ::Type{L}, letters::T,
  ) where {L <: Letter, T <: Tuple{Letter, Letter, Vararg{Letter}}} =
    (
      filtered(L, (first(letters),))...,
      filtered(L, Base.tail(letters))...,
    )
end

const Letter = Letters.Letter
const filtered = Letters.filtered

abstract type Vowel <: Letter end
abstract type Consonant <: Letter end
struct A <: Vowel end
struct B <: Consonant end
struct C <: Consonant end

sieve(letters...) = (
  filtered(Vowel, letters),
  filtered(Consonant, letters),
)

const some_letters = (A(), B(), C(), A(), B(), C(), C(), B(), A())

import Test, JET

# Try out `Letters.filtered`
filtered(Vowel, some_letters)      # (A(), A(), A())
filtered(Consonant, some_letters)  # (B(), C(), B(), C(), C(), B())

# Try out `sieve`
sieve(some_letters...)             # ((A(), A(), A()), (B(), C(), B(), C(), C(), B()))

# Test for some coding mistakes, there should be no warning for our code
Test.detect_unbound_args(Main, recursive = true)
Test.detect_ambiguities(Main, recursive = true)

# Test return type inference, there should be no error
Test.@inferred sieve(some_letters...)

# Test, there should be no warnings
JET.@report_opt filtered(Vowel, some_letters)
JET.@report_opt filtered(Consonant, some_letters)
JET.@report_opt sieve(some_letters...)

Recursion is often used to implement type-stable functions that manipulate tuples, as in Letters.filtered. In simpler cases than this one, the ntuple function can be helpful.
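As a rough illustration of the ntuple remark, here is a minimal sketch for a case where the output has the same length as the input, so no filtering is needed. The helper name vowel_mask is hypothetical and not part of the code above; passing the length as Val(N) is what should let the compiler know the result size.

```julia
abstract type Vowel end
abstract type Consonant end
struct A <: Vowel end
struct B <: Consonant end

# Hypothetical helper: one Bool per element, so the output length
# equals the input length and is known from the tuple type alone.
vowel_mask(letters::NTuple{N, Any}) where {N} =
    ntuple(i -> letters[i] isa Vowel, Val(N))

vowel_mask((A(), B(), A()))  # (true, false, true)
```

This does not replace filtered, because the length of a filtered tuple depends on the element types rather than only on N.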

I’m curious about why you need such a function. Perhaps your design is suboptimal or convoluted? Can you give more context?


Actually I think it’s nicer like this, without the unnecessary extra abstract type (Letter):

# Length zero
filtered(::Type{L}, ::Tuple{}) where {L} = ()

# Length one
filtered(::Type{L}, letters::T) where {L, T <: Tuple{L}} = letters
filtered(::Type{L}, letters::T) where {L, T <: Tuple{Any}} = ()

# Length two or more
filtered(::Type{L}, letters::T) where {L, T <: Tuple{Any, Any, Vararg{Any}}} = (
  filtered(L, (first(letters),))...,
  filtered(L, Base.tail(letters))...,
)

abstract type Vowel end
abstract type Consonant end
struct A <: Vowel end
struct B <: Consonant end
struct C <: Consonant end

sieve(letters...) = (
  filtered(Vowel, letters),
  filtered(Consonant, letters),
)

const some_letters = (A(), B(), C(), A(), B(), C(), C(), B(), A())

import Test, JET

# Try out `filtered`
filtered(Vowel, some_letters)      # (A(), A(), A())
filtered(Consonant, some_letters)  # (B(), C(), B(), C(), C(), B())

# Try out `sieve`
sieve(some_letters...)             # ((A(), A(), A()), (B(), C(), B(), C(), C(), B()))

# Test for some coding mistakes, there should be no warning for our code
Test.detect_unbound_args(Main, recursive = true)
Test.detect_ambiguities(Main, recursive = true)

# Test return type inference, there should be no error
Test.@inferred sieve(some_letters...)

# Test, there should be no warnings
JET.@report_opt sieve(some_letters...)
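
To see how the four methods cooperate, it may help to call them on very small tuples. The sketch below repeats the definitions so that it runs on its own; the comments describe which method each call is expected to reach.

```julia
abstract type Vowel end
abstract type Consonant end
struct A <: Vowel end
struct B <: Consonant end

# Same four methods as in the post above
filtered(::Type{L}, ::Tuple{}) where {L} = ()
filtered(::Type{L}, letters::T) where {L, T <: Tuple{L}} = letters
filtered(::Type{L}, letters::T) where {L, T <: Tuple{Any}} = ()
filtered(::Type{L}, letters::T) where {L, T <: Tuple{Any, Any, Vararg{Any}}} = (
  filtered(L, (first(letters),))...,
  filtered(L, Base.tail(letters))...,
)

# Length one: dispatch alone decides whether the element is kept
filtered(Vowel, (A(),))            # (A(),)  matches Tuple{L}, kept
filtered(Vowel, (B(),))            # ()      falls back to Tuple{Any}, dropped

# Length two or more: head and tail are filtered separately, then spliced
filtered(Vowel, (A(), B(), A()))   # (A(), A())
```

Since every decision is made by method dispatch on the tuple type, no isa check is needed in the function bodies.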

Wow, that was incredibly quick and helpful, thank you! Very elegant solution indeed; I wasn’t using Test and JET before, instead relying on manual @code_warntype or Cthulhu.@descend - another thank you for including the tests, I will be using them from now on.

I’m writing a library where the user may specify multiple types of interactions (each being a struct, defining fairly complex computations) that I want to collect into a tuple, such that they can be unrolled and inlined when evaluating them thousands of times per second. For user convenience, they can be specified in any order, and then sorted internally into the correct context - e.g. Force and Torque are both Interaction subtypes, but must be handled differently.

I couldn’t think of a more convenient and performant design solution until now - though I’m very grateful for any general tips if something springs to mind based on the above (limited, but not too domain-specific) description.
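
For concreteness, the design described above might be sketched roughly as follows. Everything except Interaction, Force, Torque and filtered is an assumption made for illustration: the concrete structs Spring, Drag and Motor, the System container, and the evaluate, net_force and net_torque functions are hypothetical, and Force and Torque are assumed to be abstract subtypes.

```julia
abstract type Interaction end
abstract type Force <: Interaction end
abstract type Torque <: Interaction end

# Hypothetical concrete interactions, each with its own computation
struct Spring <: Force; k::Float64 end
struct Drag <: Force; c::Float64 end
struct Motor <: Torque; tau::Float64 end

evaluate(s::Spring, x) = -s.k * x
evaluate(d::Drag, x) = -d.c * x
evaluate(m::Motor, x) = m.tau

# `filtered` as defined earlier in the thread
filtered(::Type{L}, ::Tuple{}) where {L} = ()
filtered(::Type{L}, t::T) where {L, T <: Tuple{L}} = t
filtered(::Type{L}, t::T) where {L, T <: Tuple{Any}} = ()
filtered(::Type{L}, t::T) where {L, T <: Tuple{Any, Any, Vararg{Any}}} =
    (filtered(L, (first(t),))..., filtered(L, Base.tail(t))...)

# Hypothetical container: interactions are sorted once, at construction
struct System{F <: Tuple, T <: Tuple}
    forces::F
    torques::T
end
System(interactions::Interaction...) =
    System(filtered(Force, interactions), filtered(Torque, interactions))

# `map` over a tuple returns a tuple, one evaluated entry per interaction
net_force(s::System, x) = sum(map(f -> evaluate(f, x), s.forces); init = 0.0)
net_torque(s::System, x) = sum(map(t -> evaluate(t, x), s.torques); init = 0.0)

# Interactions can be given in any order
sys = System(Drag(0.5), Motor(2.0), Spring(10.0))
net_force(sys, 1.0)   # -10.5
net_torque(sys, 1.0)  # 2.0
```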


I can’t tell for sure without looking at the broader design, of course, but I think that usually better performance is achieved with a single concrete struct type that encompasses both Force and Torque (for more type stability). For example, maybe you could make Interaction a struct and also have a field that’s an enumeration type that represents either force or torque.

There would most probably be some redundancy among the fields of struct Interaction, because it would need to contain all of the fields required by both force and torque, so some fields would be left unused. However, this shouldn’t be a problem.

Just something to think about, in case you haven’t already.
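
A minimal sketch of what I mean, under made-up assumptions: the enum InteractionKind, the field names stiffness and tau, and the helper constructors force and torque are all hypothetical, chosen only to show the shape of the idea.

```julia
# Hypothetical tag distinguishing the two kinds of interaction
@enum InteractionKind ForceKind TorqueKind

# One concrete type holding the fields of both kinds
struct Interaction
    kind::InteractionKind
    stiffness::Float64   # used only when kind == ForceKind
    tau::Float64         # used only when kind == TorqueKind
end

# Convenience constructors that fill the unused field with zero
force(k) = Interaction(ForceKind, k, 0.0)
torque(tau) = Interaction(TorqueKind, 0.0, tau)

# A runtime branch replaces dispatch; both branches return a Float64
evaluate(i::Interaction, x) = i.kind == ForceKind ? -i.stiffness * x : i.tau

# The element type is the single concrete type Interaction
interactions = [force(10.0), torque(2.0), force(0.5)]

forces = filter(i -> i.kind == ForceKind, interactions)
net_force = sum(i -> evaluate(i, 1.0), forces; init = 0.0)  # -10.5
```

The trade-off is a branch on kind at runtime in exchange for a collection with one concrete element type.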

If you want to go the route of an enumeration field, SumTypes.jl is 🔥🔥

It automatically does the enumeration stuff and then its unwrapping macro checks to make sure you’ve handled all the cases. Can’t recommend it enough.
