Techniques for defining commutative functions

I stumbled into a curious problem that seems like it must surely have been solved by someone before, maybe even somewhere in Base that I’m just not familiar with. I figured I’d consult the brain trust here to see if anyone had any good ideas.

Problem

Say I’ve defined some set of types A, B, C, etc., and I want to define a binary (two-argument) function f that acts on these and is commutative (i.e., f(x, y) == f(y, x)). There is a large number of possible combinations of argument types, and many, but not all, of them will be implemented. I’m looking for an elegant way to implement this commutativity.

Ideas

Here’s some thoughts I’ve had, but I’m very open to new ideas.

Option 1

For every pair of argument types that I intend to implement, I could explicitly define two methods.

f(x::A, y::B) = ...
f(x::B, y::A) = f(y, x)

f(x::A, y::C) = ...
f(x::C, y::A) = f(y, x)

...

Pro: It works. Not overly complicated.
Con: It requires a second boilerplate method for every implemented method. I have to remember to add both every time I implement a new method. For a large number of possible types, this is a non-trivial number of additional method definitions.

Maybe this could be implemented with a macro?

@commutative f(x::A, y::B) = ...

@commutative f(x::A, y::C) = ...

...

Option 2

I could define a single generic two-argument fallback method that reverses argument order.

f(x::A, y::B) = ...

f(x::A, y::C) = ...

f(x, y) = f(y, x)

Pro: Requires less boilerplate code.
Con: Throws a StackOverflowError (infinite recursion) if f has not been implemented for this (x, y) pair in either order.

I considered adding a hasmethod(f, Tuple{typeof(y), typeof(x)}) or applicable(f, y, x) check to the fallback method, hoping it would prevent recursion for unimplemented pairs, but the fallback definition itself matches (::Any, ::Any), so the check would always succeed.
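A minimal sketch of the Option 2 problem, using hypothetical types A, B and D where only f(::A, ::B) is really implemented. Because the generic fallback matches every pair, both hasmethod and applicable report success for the unimplemented (A, D) pair, and the call itself keeps swapping arguments until the stack overflows.

```julia
struct A end
struct B end
struct D end   # hypothetical type with no real f implementation

f(x::A, y::B) = "AB"
f(x, y) = f(y, x)               # Option 2: generic fallback that swaps arguments

f(B(), A())                     # "AB", reached via the fallback

# Both checks find the (::Any, ::Any) fallback, so neither can tell that
# no real implementation exists for (A, D) in either order.
hasmethod(f, Tuple{A, D})       # true
applicable(f, D(), A())         # true

# Calling it bounces between f(A(), D()) and f(D(), A()) forever.
overflowed = try
    f(A(), D())
    false
catch err
    err isa StackOverflowError
end                             # true
```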

I guess the property you want is commutativity, i.e., f(x, y) == f(y, x).
Option 1 seems the most straightforward. As an alternative, you could define an ordering on your types and have the fallback sort its arguments:

ord(::A) = 1
ord(::B) = 2
ord(::C) = 3

f(x, y) = if ord(x) < ord(y); f(x, y) else f(y, x) end
f(x::A, y::B) = "AB"
f(x::A, y::C) = "AC"
...

This would still give a StackOverflowError if a combination is not defined, though.

To prevent this, you could dispatch to an internal method holding the actual implementation:

g(x, y) = if ord(x) < ord(y); _g(x, y) else _g(y, x) end
_g(x::A, y::B) = "AB"
_g(x::A, y::C) = "AC"
...
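A self-contained sketch of this rank-then-dispatch idea follows; the types A to D and the particular ord values are placeholders. Since the implementations live only on the internal _g, a missing combination fails fast with a MethodError instead of recursing.

```julia
struct A end
struct B end
struct C end
struct D end   # hypothetical type with a rank but no _g implementations

# Rank each type; the public entry point always passes the
# lower-ranked argument first.
ord(::A) = 1
ord(::B) = 2
ord(::C) = 3
ord(::D) = 4

g(x, y) = ord(x) < ord(y) ? _g(x, y) : _g(y, x)

# Each pair is implemented exactly once, in ascending rank order.
_g(x::A, y::B) = "AB"
_g(x::A, y::C) = "AC"

g(B(), A())   # "AB"
g(C(), A())   # "AC"
# g(B(), D()) throws a MethodError for _g(::B, ::D) instead of overflowing
```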

Overall, a macro is probably a good idea: it clearly communicates intent without any boilerplate and lets you easily change the implementation by simply redefining the macro.


🤦‍♂️ Yes. This is what engineers get for speaking mathematics.


To avoid a combinatorial explosion here, one strategy is to define promotion rules that promote arguments to a common type. This is what is done for the built-in arithmetic and numeric operators (although specialized methods are also defined in many cases for performance reasons, e.g., complex + real addition).
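A minimal sketch of the promotion approach, modeled on how Base handles + for Number types (a same-type method plus a generic fallback that calls promote). The names Length, Meters, Feet and total are hypothetical. promote_rule only has to be written for one argument order, because promote_type consults it both ways, so the mixed-type fallback is symmetric without a second method per pair.

```julia
abstract type Length end
struct Meters <: Length
    v::Float64
end
struct Feet <: Length
    v::Float64
end

# How to convert, and which type "wins" when the two meet.
Base.convert(::Type{Meters}, x::Feet) = Meters(0.3048 * x.v)
Base.promote_rule(::Type{Meters}, ::Type{Feet}) = Meters

# Same-type implementation: the only place where real work happens.
total(x::T, y::T) where {T<:Length} = T(x.v + y.v)

# Mixed-type fallback: promote both arguments to a common type, then retry.
total(x::Length, y::Length) = total(promote(x, y)...)

total(Feet(10), Meters(1)) == total(Meters(1), Feet(10))   # true
```

Note that the same-type method is what stops the recursion: without it, total(Feet(1), Feet(2)) would hit the fallback, promote would return the arguments unchanged, and the fallback would call itself forever.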


I tried to generalize the OP to keep the advice widely applicable, but in my current/specific use case the individual types are different enough that I don’t think promotion to common types will be possible. Essentially, every method being implemented is specific to how those two types interact in particular.

I think I’m leaning toward the macro version, but I don’t have a lot of experience with metaprogramming in Julia. Can anyone provide feedback on my first attempt at this?

macro commutative(expr)
    # Ensure that the expression is a function definition
    if expr.head != :function
        error("The @commutative macro only applies to function definitions")
    end

    # Get the function name and its arguments
    fname = expr.args[1].args[1]
    args = expr.args[1].args[2:end]

    # Check if the function has exactly two arguments
    if length(args) != 2
        error("A @commutative function must have exactly two arguments")
    end

    # Create a new function definition with swapped arguments
    expr_commuted = Expr(:function, Expr(:call, fname, args[2], args[1]), expr.args[2])

    # Define the commutative methods
    eval(expr)                # f(a::A, b::B)
    eval(expr_commuted)       # f(b::B, a::A)
end

It seems to work under basic testing.

@commutative function f(myint::Int, myfloat::Float64)
    return ( (typeof(myint),   myint),
             (typeof(myfloat), myfloat) )
end

@info f(1, 2.0)
# [ Info: ((Int64, 1), (Float64, 2.0))

@info f(1.0, 2)
# [ Info: ((Int64, 2), (Float64, 1.0))

Macros should not call eval but should instead return an expression. To return multiple definitions, combine them into a begin ... end block. Further, you need to escape the expressions so that the function definition is inserted as is, i.e., resolved in the scope where the macro is called rather than where it was defined:

macro commutative(expr)
    # Ensure that the expression is a function definition
    if expr.head != :function
        error("The @commutative macro only applies to function definitions")
    end

    # Get the function name and its arguments
    fname = expr.args[1].args[1]
    args = expr.args[1].args[2:end]

    # Check if the function has exactly two arguments
    if length(args) != 2
        error("A @commutative function must have exactly two arguments")
    end

    # Create a new function definition with swapped arguments
    expr_commuted = Expr(:function, Expr(:call, fname, args[2], args[1]), expr.args[2])

    # Define regular and commuted methods
    :(begin
         $(esc(expr))
         $(esc(expr_commuted))
      end)
end
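To make the eval point concrete, here is a minimal sketch with a hypothetical module Defs and hypothetical macros @define_eval and @define_expr. The eval call runs while the macro is being expanded and evaluates in the macro's own module, so its definition lands in Defs; the escaped, returned expression is instead evaluated where the macro is used.

```julia
module Defs
# Anti-pattern: eval runs during macro expansion, inside module Defs.
macro define_eval(name)
    eval(:($name() = "defined via eval"))
    return nothing
end

# Preferred: return an escaped expression for the caller to evaluate.
macro define_expr(name)
    return esc(:($name() = "defined via returned expression"))
end
end # module Defs

Defs.@define_eval hello_eval
Defs.@define_expr hello_expr

isdefined(Defs, :hello_eval), isdefined(@__MODULE__, :hello_eval)  # (true, false)
isdefined(@__MODULE__, :hello_expr)                                # true
```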

To check that everything works as expected, inspect the macro expansion:

julia> @macroexpand @commutative function f(myint::Int, myfloat::Float64)
           return ( (typeof(myint),   myint),
                    (typeof(myfloat), myfloat) )
       end
quote
    #= REPL[25]:21 =#
    function f(myint::Int, myfloat::Float64)
        #= REPL[26]:1 =#
        #= REPL[26]:2 =#
        return ((typeof(myint), myint), (typeof(myfloat), myfloat))
    end
    #= REPL[25]:22 =#
    function f(myfloat::Float64, myint::Int)
        #= REPL[26]:1 =#
        #= REPL[26]:2 =#
        return ((typeof(myint), myint), (typeof(myfloat), myfloat))
    end
end

MacroTools.jl also has some convenient utilities for working with function expressions (e.g., splitdef and combinedef), making it easier to support short as well as long function definitions.
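For comparison, here is a sketch that extends the @commutative macro from earlier in the thread to the short form f(x::A, y::B) = body using only plain Expr manipulation, so no extra package is needed. The helper _argtype is a hypothetical name; where clauses, return-type annotations and keyword arguments are not handled, and for those cases MacroTools' splitdef/combinedef would likely be a sturdier base.

```julia
# Annotated type of an argument expression (`x::A` -> :A, bare `x` -> :Any).
_argtype(arg) = arg isa Expr && arg.head == :(::) ? arg.args[end] : :Any

macro commutative(expr)
    (expr isa Expr && expr.head in (:function, :(=))) ||
        error("The @commutative macro only applies to function definitions")
    sig, body = expr.args[1], expr.args[2]
    (sig isa Expr && sig.head == :call) ||
        error("@commutative expects a plain signature such as f(x::A, y::B)")
    fname, args = sig.args[1], sig.args[2:end]
    length(args) == 2 ||
        error("A @commutative function must have exactly two arguments")
    # Identical annotations: the swapped method would only redefine the
    # original signature, so emit the original definition alone.
    _argtype(args[1]) == _argtype(args[2]) && return esc(expr)
    # Keep the same definition form (long or short), arguments swapped.
    swapped = Expr(expr.head, Expr(:call, fname, args[2], args[1]), body)
    return esc(Expr(:block, expr, swapped))
end

struct A end
struct B end
struct C end

@commutative f(x::A, y::B) = "AB"          # short form
@commutative function f(x::A, y::C)        # long form
    return "AC"
end

f(B(), A())   # "AB"
f(C(), A())   # "AC"
```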
