Function with unordered parameter list


#1

First, a bit of background.

I am currently trying to implement a function which takes as parameters a Particle object, as well as three ParticleType objects, and attempts to decay the initial particle into three particles of the correct type. The prototype of this function is the following:

# Simulate the decay P0 → P1 P2 P3
function decay(P0::Particle{<:ParticleType}, X1::ParticleType, X2::ParticleType, X3::ParticleType)
    # QFT wizardry
    # Generate the actual daughter particles P1, P2, P3
    P1, P2, P3
end

Since the physics of the decay depend on the exact nature of all four particles, this function is implemented using a set of methods taking as parameters more specific objects (subtypes of Particle and ParticleType), e.g.:

function decay(P0::Particle{<:Hadron}, X1::Hadron, X2::ChargedLepton, X3::AntiNeutrino)
    # Actual generation of the daughter particles
    P1, P2, P3
end

Now the actual question.

During a simulation, the function decay() may be called with the last three parameters (the types of the daughter particles) given in any order. The physics remain the same, so it is unnecessary to reimplement the body of the method for each permutation of its arguments. However, the returned Particle objects must be permuted accordingly.

Of course, this could be done by brute force, by manually defining a method for each permutation that calls the actual implementation, e.g.:

function decay(P0::Particle{<:Hadron}, X2::ChargedLepton, X1::Hadron, X3::AntiNeutrino)
    P1, P2, P3 = decay(P0, X1, X2, X3)
    P2, P1, P3
end

# Do this for each possible permutation...

This quickly becomes tedious, so I am looking for a generic way to do this automatically.
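The brute-force wrappers do not have to be typed by hand: an ordinary `@eval` loop (plain metaprogramming, not something proposed in this thread) can stamp out one wrapper per ordering. This is a minimal sketch for Julia 1.x. It assumes the canonical order Hadron, ChargedLepton, AntiNeutrino, and uses a placeholder canonical method that returns symbols instead of real daughter particles:

```julia
abstract type ParticleType end
struct Hadron <: ParticleType end
struct ChargedLepton <: ParticleType end
struct AntiNeutrino <: ParticleType end

# Placeholder canonical method; the real one would contain the physics.
decay(p0, x1::Hadron, x2::ChargedLepton, x3::AntiNeutrino) = (:hadron, :lepton, :antineutrino)

# Assumed canonical order of the daughter types.
const CANONICAL = (Hadron, ChargedLepton, AntiNeutrino)

# The five non-identity orderings; perm[i] = canonical slot of the i-th argument.
for perm in ((1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1))
    argnames = (:a, :b, :c)
    # Signature: the i-th argument has type CANONICAL[perm[i]].
    sig = [:($(argnames[i])::$(CANONICAL[perm[i]])) for i in 1:3]
    # slots[k] = which argument belongs in canonical slot k.
    slots = invperm(collect(perm))
    call_args = [argnames[slots[k]] for k in 1:3]
    # The i-th returned daughter is the one matching the i-th argument's type.
    rets = [:(ys[$(perm[i])]) for i in 1:3]
    @eval function decay(p0, $(sig...))
        ys = decay(p0, $(call_args...))
        ($(rets...),)
    end
end
```

For example, `decay(:p0, ChargedLepton(), Hadron(), AntiNeutrino())` then returns `(:lepton, :hadron, :antineutrino)`. The wrappers are tiny, but all n! signatures still exist as separate methods, which is the code-size concern raised later in the thread.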

To sum up, the question is: how can one implement in Julia a function F such that, for every permutation Π, F(X) = Y ⟹ F(Π[X]) = Π[Y]? In other words, how can one implement a function that commutes with an arbitrary permutation of its arguments?
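To make the property concrete, here is a small brute-force checker for Julia 1.x. The helper names are made up for illustration. It reads Π[X] as "the i-th entry of Π[X] is X[π[i]]" and tests F(Π[X]) == Π[F(X)] over every permutation of one given argument tuple:

```julia
# Π[X]: the i-th entry is X[p[i]].
permute_tuple(t::Tuple, p::AbstractVector{<:Integer}) = ntuple(i -> t[p[i]], length(t))

# All permutations of 1:n, built by inserting n into every permutation of 1:n-1.
function all_perms(n::Integer)
    n <= 1 && return [collect(1:n)]
    out = Vector{Vector{Int}}()
    for p in all_perms(n - 1), pos in 1:n
        push!(out, insert!(copy(p), pos, n))
    end
    return out
end

# True if F(Π[X]) == Π[F(X)] for every permutation Π of the arguments X.
function commutes_with_permutations(F, xs::Tuple)
    ys = F(xs...)
    return all(F(permute_tuple(xs, p)...) == permute_tuple(ys, p) for p in all_perms(length(xs)))
end
```

An element-wise map such as `(a, b, c) -> (2a, 2b, 2c)` passes for `(1, 2, 3)`, while `(a, b, c) -> (c, b, a)` fails, because reversing the outputs does not follow a reordering of the inputs.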

And please tell me if this is a complete abuse of the type system :slight_smile:


#2

You could use a generated function to handle all the permutations. Write _decay for one fixed argument order, then have a generated decay call it; the generated function works out the permutation from the argument types and rearranges both the arguments and the results.
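A minimal sketch of this idea for Julia 1.x. The canonical order, the helper names `canonical_rank` and `canonical_perm`, and the placeholder `_decay` body are all assumptions. The Julia manual states that a generator may only call functions defined before the generated function and cannot contain closures or comprehensions, which is why the generator body uses plain loops:

```julia
abstract type ParticleType end
struct Hadron <: ParticleType end
struct ChargedLepton <: ParticleType end
struct AntiNeutrino <: ParticleType end

# Assumed canonical order of the daughter types.
const CANONICAL_ORDER = (Hadron, ChargedLepton, AntiNeutrino)

# Placeholder for the real physics, written for the canonical order only.
# @noinline keeps this body from being inlined into each generated wrapper.
@noinline _decay(p0, ::Hadron, ::ChargedLepton, ::AntiNeutrino) = (:hadron, :lepton, :antineutrino)

# Position of type T in CANONICAL_ORDER.
function canonical_rank(T::Type)
    for (i, S) in enumerate(CANONICAL_ORDER)
        T <: S && return i
    end
    error("no canonical slot for $T")
end

# perm[k] = index of the argument that belongs in canonical slot k.
function canonical_perm(types::Tuple)
    ranks = Int[]
    for T in types
        push!(ranks, canonical_rank(T))
    end
    return sortperm(ranks)
end

@generated function decay(p0, xs::Vararg{ParticleType,3})
    # In the generator, `xs` is a tuple of argument *types*,
    # so this code runs once per ordering at compile time.
    perm = canonical_perm(xs)
    slots = invperm(perm)   # slots[i] = canonical slot of argument i
    sorted_args = Any[]
    for k in perm
        push!(sorted_args, :(xs[$k]))
    end
    outputs = Any[]
    for i in slots
        push!(outputs, :(ys[$i]))
    end
    return quote
        ys = _decay(p0, $(sorted_args...))
        ($(outputs...),)
    end
end
```

Only the orderings that are actually called get a (small) wrapper. The physics stays in `_decay`, and marking it `@noinline` follows the suggestion in #4.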


#3

I think you should try to find a non-combinatorial solution. Julia will specialize the functions on all these permutations and generate quite a lot of code.

What is the content of each particle type? Are they really so heterogeneous that they all have to be different types, or are you using different types mostly as a convenient way to exploit dispatch? Dispatch can indeed be convenient, but abusing it can lead to overspecialization (and slower code if it is type-unstable).


#4

With the approach I suggested above, couldn’t one use @noinline _decay for the inner function?


#5

Thanks for your suggestions!

I will look at generated functions and also see if I can come up with a simpler design, e.g. by sorting the arguments before calling the function.
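Sorting at run time could look roughly like this minimal Julia 1.x sketch. The `rank` methods and the placeholder `_decay` are hypothetical. Because `perm` is only known at run time, the call to `_decay` is likely dispatched dynamically and its return type not inferred, which is the type-instability cost mentioned in #3:

```julia
abstract type ParticleType end
struct Hadron <: ParticleType end
struct ChargedLepton <: ParticleType end
struct AntiNeutrino <: ParticleType end

# Hypothetical rank of each type in the canonical order.
rank(::Hadron) = 1
rank(::ChargedLepton) = 2
rank(::AntiNeutrino) = 3

# Placeholder canonical implementation; returns labels instead of particles.
_decay(p0, ::Hadron, ::ChargedLepton, ::AntiNeutrino) = (:hadron, :lepton, :antineutrino)

function decay(p0, xs::ParticleType...)
    n = length(xs)
    perm = sortperm([rank(x) for x in xs])  # perm[k] = argument for canonical slot k
    slots = invperm(perm)                    # slots[i] = canonical slot of argument i
    ys = _decay(p0, ntuple(k -> xs[perm[k]], n)...)
    return ntuple(i -> ys[slots[i]], n)      # daughters back in the caller's order
end
```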

A ParticleType is essentially just a tag (singleton), and a Particle is a ParticleType plus a 4-vector describing its momentum.
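For concreteness, that description could look roughly like this in Julia 1.x. The field names and the (E, px, py, pz) momentum layout are assumptions. Julia's type parameters are invariant, so `Particle{Hadron}` is a subtype of `Particle{<:ParticleType}` but not of `Particle{ParticleType}`, which matters when writing method signatures for `decay`:

```julia
abstract type ParticleType end
struct Hadron <: ParticleType end
struct ChargedLepton <: ParticleType end
struct AntiNeutrino <: ParticleType end

# Hypothetical layout: the singleton tag plus a four-momentum,
# assumed here to be stored as (E, px, py, pz).
struct Particle{T<:ParticleType}
    kind::T
    momentum::NTuple{4,Float64}
end

# The type parameter is inferred from the tag: this is a Particle{Hadron}.
p = Particle(Hadron(), (1.0, 0.0, 0.0, 0.0))
```

Since the tag is a singleton, it takes no storage; a `Particle{Hadron}` is just its 32 bytes of momentum.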

I will try to limit the number of particle types to a minimum by pooling those which behave similarly into a common type. But the fact is that the physics of the decay may be very different depending on the types of the particles involved (because of their different interactions), so it seems to me that this situation is better handled by defining a method for each distinct physical process.

Using @noinline on _decay definitely seems like a good idea, since the cost of one extra function call should usually be small compared to the work done by the method.

I guess I will try your different suggestions and see what works best in my case.


#6

I was kind of intrigued by this problem, and was curious to see whether the sorting could be done completely at compile time. I'm certainly not saying that this is the right solution (and the Julia 0.6.1 compiler can almost, but not quite, handle it), but I hope my attempt to torture the compiler a bit will at least amuse some people and maybe spark some discussion.

First, some setup:

using Base.Test  # Julia 0.6; from Julia 0.7 on this is `using Test`
abstract type ParticleType end
struct ChargedLepton <: ParticleType end
struct Hadron <: ParticleType end
struct AntiNeutrino <: ParticleType end

Next, I defined a function that determines the first argument of a given type (this is inspired by StaticArrays’ first_static). It returns a tuple containing the first match (or nothing if there is no match), as well as all non-matches (discards):

# start with an empty `discards` tuple:
@inline function first_type_match(::Type{T}, ts...) where T
    first_type_match(T, (), ts...)
end

# not a match; add first argument to discards and recurse:
@inline function first_type_match(::Type{T}, discards::Tuple, t1, ts...) where T
    first_type_match(T, tuple(discards..., t1), ts...)
end

# match: return
@inline function first_type_match(::Type{T}, discards::Tuple, t1::T, ts...) where T
    (t1, tuple(discards..., ts...))
end

# no match and done processing all arguments: return
@inline function first_type_match(::Type{T}, discards::Tuple) where T
    (nothing, discards)
end

Note: the @inlines might not be completely necessary. Here is this function in action:

@test begin
    match, discards = first_type_match(Hadron, Hadron(),  ChargedLepton())
    match == Hadron() && discards == (ChargedLepton(),)
end

@test begin
    match, discards = first_type_match(Hadron, ChargedLepton(), Hadron(), AntiNeutrino())
    match == Hadron() && discards == (ChargedLepton(), AntiNeutrino())
end

@test begin
    match, discards = first_type_match(ChargedLepton, AntiNeutrino(), Hadron())
    match == nothing && discards == (AntiNeutrino(), Hadron())
end

@test begin
    match, discards = first_type_match(ChargedLepton)
    match == nothing && discards == ()
end

Note that all the work is done at compile time:

julia> @code_warntype first_type_match(Hadron, ChargedLepton(), Hadron(), AntiNeutrino())

Variables:
  #self# <optimized out>
  #unused# <optimized out>
  ts <optimized out>

Body:
  begin 
      return (Hadron(), (ChargedLepton(), AntiNeutrino()))
  end::Tuple{Hadron,Tuple{ChargedLepton,AntiNeutrino}}

I will use this function as a building block for sorting by type. To do so, we first need to specify a particle type ordering:

first_particle_type() = Hadron
next_particle_type(::Type{Hadron}) = ChargedLepton
next_particle_type(::Type{ChargedLepton}) = AntiNeutrino

Now we can define our sort_particle_types function:

# start by looking for the first particle type with an empty results tuple
sort_particle_types(ts::ParticleType...) = _sort_particle_types(first_particle_type(), (), ts...)

const ParticleTypeTuple = Tuple{Vararg{<:ParticleType,N} where N}

# recursively build the results tuple:
@inline function _sort_particle_types(::Type{T}, result::ParticleTypeTuple, ts::ParticleType...) where T<:ParticleType
    match, discards = first_type_match(T, ts...)
    _sort_particle_types(handle_match(T, result, match)..., discards...)
end

# return the results tuple once all arguments have been processed:
@inline _sort_particle_types(::Type{<:ParticleType}, result::ParticleTypeTuple) = result

where handle_match determines the next ParticleType subtype to look for and updates the results:

@inline function handle_match(::Type{T}, result::ParticleTypeTuple, match::Void) where T<:ParticleType
    # no match (`Void` is called `Nothing` from Julia 0.7 on), try the next particle type
    (next_particle_type(T), result)
end

@inline function handle_match(::Type{T}, result::ParticleTypeTuple, match::ParticleType) where T<:ParticleType
    # reset to first particle type, append match to results
    (first_particle_type(), tuple(result..., match))
end

Now, this almost works:

julia> @code_warntype sort_particle_types(ChargedLepton(), Hadron(), AntiNeutrino())

Variables:
  #self# <optimized out>
  ts <optimized out>

Body:
  begin 
      return (Hadron(), ChargedLepton(), AntiNeutrino())
  end::Tuple{Hadron,ChargedLepton,AntiNeutrino}

so the compiler has indeed already figured out the answer. Surprisingly, though, you can't actually call the function on 0.6.1, because

sort_particle_types(ChargedLepton(), Hadron(), AntiNeutrino())

somehow results in a StackOverflowError in inference (despite the answer already being in the code_warntype above): https://gist.github.com/tkoolen/8c3eb36aba92da194e84725c2eab18f6.

I tried this on the latest Julia master as well, and the function returns, but the code_warntype shows that inference has given up:

julia> @code_warntype sort_particle_types(ChargedLepton(), Hadron(), AntiNeutrino())
Variables:
  ts<optimized out>

Body:
  begin
      return $(Expr(:invoke, MethodInstance for _sort_particle_types(::Type{Hadron}, ::Tuple{Hadron,ChargedLepton}, ::AntiNeutrino, ::Vararg{AntiNeutrino,N} where N), :(Main._sort_particle_types), Hadron, (Hadron(), ChargedLepton()), :($(QuoteNode(AntiNeutrino())))))::Tuple{Vararg{ParticleType,N} where N}
  end::Tuple{Vararg{ParticleType,N} where N}

#8

@tkoolen This is simply brilliant! I am still trying to wrap my head around your code, and I doubt that I will ever use it in production (I don't want to be killed by my colleagues!), but your solution made my day :smile:

Actually, part of the reason I posted this question in the first place, instead of just refactoring my code, was to see what is possible with Julia's type system and compile-time logic. And I can say that I am not disappointed!


#9

Actually, I found that the StackOverflowError doesn’t appear when sort_particle_types is called from another function:

julia> foo() = sort_particle_types(ChargedLepton(), Hadron(), AntiNeutrino());

julia> foo()
(Hadron(), ChargedLepton(), AntiNeutrino())

#10

See https://github.com/JuliaLang/julia/issues/24860.


#11

Thanks for reporting this bug :bug:


#12

(From my GitHub issue comment:)
It looks like simply changing all of the @inlines to Base.@pures made everything inferable on nightly.

After the Base.@pure change, inference performance (the time it takes to run the @code_warntype line) also seems to be better than on 0.6.1.