Cool idea: define a function conditionally. Even more powerful than traits

I cooked up a little macro that allows you to define a function conditionally. The condition can be totally arbitrary; the sky’s the limit. The only requirement is that the condition must return True or False. Unless I’m mistaken, this allows for much more flexibility than traits, but requires no changes to Base. Anyway, let me know what you think. I could turn it into a fully fledged package or submit it as a PR to Base, depending on what people think.

module DefineWhen

using MacroTools

import Base: !, &, |

abstract type TypedBool end
"""
    struct True
Typed `true`
"""
struct True <: TypedBool end
export True
"""
    struct False
Typed `false`
"""
struct False <: TypedBool end
export False

@inline !(::False) = True()
@inline !(::True) = False()
@inline (&)(::False, ::False) = False()
@inline (&)(::False, ::True) = False()
@inline (&)(::True, ::False) = False()
@inline (&)(::True, ::True) = True()
@inline (|)(::False, ::False) = False()
@inline (|)(::False, ::True) = True()
@inline (|)(::True, ::False) = True()
@inline (|)(::True, ::True) = True()

function define(source, if_code)
    if !@capture if_code (if condition_
        function afunction_(arguments__)
            lines__
        end
    else
        fallback_
    end)
        throw(ArgumentError("Requires an if statement and a function call"))
    end
    new_function_name = gensym(afunction)
    condition_variable = gensym("condition")
    Expr(:block,
        source,
        Expr(:function,
            Expr(:call, afunction, arguments...),
            Expr(:block, source,
                Expr(:call, new_function_name, condition,
                    QuoteNode(condition), arguments...
                )
            )
        ),
        source,
        Expr(:function,
            Expr(:call, new_function_name, :(::$True), condition_variable,
                arguments...
            ),
            Expr(:block, Expr(:meta, :inline), lines...)  # `lines` is the body captured by @capture
        ),
        source,
        Expr(:function,
            Expr(:call, new_function_name, :(::$False), condition_variable,
                arguments...
            ),
            Expr(:block,
                source,
                fallback
            )
        )
    )
end

"""
    @define if condition_
        function afunction_(arguments__)
            lines__
        end
    else
        fallback_
    end

Only define `afunction` as `lines` when `condition` is [`True`](@ref); if
`condition` is [`False`](@ref), evaluate `fallback` instead. `condition` should be
inferrable based only on the types of the `arguments`.

```jldoctest
julia> using DefineWhen; using Test: @inferred;

julia> can(::typeof(getindex), ::AbstractArray, ::Vararg{Int}) = True();

julia> can(::typeof(getindex), arguments...) = False();

julia> @define if can(getindex, something, index...)
            function my_getindex(something, index...)
                getindex(something, index...)
            end
        else
            error("Can't getindex \$something")
        end;

julia> @inferred my_getindex(1:2, 1)
1

julia> my_getindex(nothing, 1)
ERROR: Can't getindex nothing
Stacktrace:
[...]
```
"""
macro define(if_code)
    esc(define(__source__, if_code))
end

export @define

end
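The typed booleans above are just singleton types whose logic is resolved entirely by dispatch. A stand-alone sketch of combining such conditions (it does not need MacroTools; `is_array`, `has_int_eltype` and `int_array` are hypothetical conditions chosen for illustration):

```julia
# Singleton "typed booleans", as in the module above
abstract type TypedBool end
struct True <: TypedBool end
struct False <: TypedBool end

# Logic is resolved purely by dispatch on the singleton types
Base.:!(::True) = False()
Base.:!(::False) = True()
Base.:&(::True, ::True) = True()
Base.:&(::TypedBool, ::TypedBool) = False()   # any other combination
Base.:|(::False, ::False) = False()
Base.:|(::TypedBool, ::TypedBool) = True()    # any other combination

# Hypothetical conditions that depend only on the argument's type
is_array(::AbstractArray) = True()
is_array(::Any) = False()
has_int_eltype(::AbstractArray{<:Integer}) = True()
has_int_eltype(::Any) = False()

# The combined result is still a singleton whose type follows from typeof(x)
int_array(x) = is_array(x) & has_int_eltype(x)
```

For example, `int_array([1, 2])` returns `True()`, while `int_array([1.0])` and `int_array(1)` return `False()`.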

Errr, how is this different from a branch in the function? That seems to be exactly what the macro expands into, albeit slower.

Note that dispatch is not magic. It’s basically never faster than a branch.

Note that I’m leaving aside the part where traits and dispatch differ from a branch, namely that they allow different methods to be defined separately, so the function can be extended. This does not seem to have that property, since it’s just a branch.
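For contrast, a minimal sketch of the extensibility being described, using the Holy-trait pattern (the same style as Base’s `IndexStyle`). All names here (`AccessStyle`, `second`, `Squares`) are hypothetical: a new type opts into the fast path by adding one trait method, without editing `second` itself.

```julia
# A trait with a default, in the style of Base.IndexStyle
abstract type AccessStyle end
struct RandomAccess <: AccessStyle end
struct SequentialAccess <: AccessStyle end

AccessStyle(::Type) = SequentialAccess()                  # fallback for any type
AccessStyle(::Type{<:AbstractArray}) = RandomAccess()

# The generic function dispatches on the trait value
second(x) = second(AccessStyle(typeof(x)), x)
second(::RandomAccess, x) = x[firstindex(x) + 1]
function second(::SequentialAccess, x)
    _, state = iterate(x)             # skip the first element
    return first(iterate(x, state))   # iterate returns (item, state)
end

# Later, possibly in another package: a new type opts in with one extra trait method
struct Squares
    n::Int
end
Base.firstindex(::Squares) = 1
Base.getindex(::Squares, i::Int) = i^2
AccessStyle(::Type{Squares}) = RandomAccess()
```

The last line is the point: `second` gained a fast path for `Squares` without being edited.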

I’m sorry, I don’t understand… can you explain more? There’s no actual branch: the if disappears after @macroexpand.

This is what results from @macroexpand for the doctest example:

quote
    #= /home/brandon/Desktop/test.jl:94 =#
    function my_getindex(something, index...)
        #= /home/brandon/Desktop/test.jl:94 =#
        var"##my_getindex#432"(can(getindex, something, index...), $(QuoteNode(:(can(getindex, something, index...)))), something, index...)
    end
    #= /home/brandon/Desktop/test.jl:94 =#
    function var"##my_getindex#432"(::Main.DefineWhen.True, var"##condition#433", something, index...)
        $(Expr(:meta, :inline))
        getindex(something, index...)
    end
    #= /home/brandon/Desktop/test.jl:94 =#
    function var"##my_getindex#432"(::Main.DefineWhen.False, var"##condition#433", something, index...)
        #= /home/brandon/Desktop/test.jl:94 =#
        begin
            #= /home/brandon/Desktop/test.jl:99 =#
            error("Can't getindex $(something)")
        end
    end
end

Right, that’s why I asked how it is different from a branch, since it’s basically a branch but potentially slower. I didn’t say you have a branch (which would have been much better, see below).

What you have is basically the same as

f() = f2(cond ? True() : False())
f2(::True) = g1()
f2(::False) = g2()

and that’s strictly no faster than

f() = cond ? g1() : g2()

And no, it’s not about you having to have a branch in condition_. What you have is entirely equivalent to a branch on can(getindex, something, index...) === True().
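A runnable version of the `f`/`f2` comparison above. `g1`, `g2` and the stand-in condition `is_int` are hypothetical; the condition here is known from the argument type, which is the best case for the dispatch version.

```julia
abstract type TypedBool end
struct True <: TypedBool end
struct False <: TypedBool end

g1() = 1
g2() = 2.0
is_int(x) = x isa Integer          # stand-in condition, decided by the type of x

# Dispatch version: lift the condition into a singleton type, then dispatch on it
f_dispatch(x) = f2(is_int(x) ? True() : False())
f2(::True) = g1()
f2(::False) = g2()

# Branch version: same semantics, no intermediate dispatch
f_branch(x) = is_int(x) ? g1() : g2()
```

Both versions return the same values, and for an `Int` argument both infer `Int`; in this best case the dispatch version only matches the branch.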


Also,

There’s no actual branch

What exactly is bad about that? Again,

The point is that you’ve implemented a branch. But it is hidden (which is fine) and slower (which is not particularly good). It pushes things into dispatch, carrying dispatch’s potential performance penalty without taking any advantage of what dispatch offers.


But if the results of can(getindex, something, index...) can be inferred based only on the argument types, shouldn’t it be zero cost?

That’s why I’ve been very careful to say “strictly no faster” and “potential penalty”.

I mean, sure, if the condition is entirely inferable it would be as good as a branch. If the condition is not inferable it would be much worse than a branch. There’s no case where this has any benefit.

Note that, in general, a lot of the type-inference optimizations are basically about turning dispatch into branches. They make it easier to use type-unstable code when it’s unavoidable, but they are by no means a license to dump information into types unnecessarily and stress out the compiler, even if it is able to turn the code back into branches.


Ok, I think I understand. So it would be strictly faster just to write:

function my_getindex(something, index...)
    if can(getindex, something, index...) isa True
        getindex(something, index...)
    else
        error("Can't getindex $something")
    end
end

It’ll be strictly no slower to write that since that’s what the code should be turned into in the best case.


Ok, great. So I’ve got this pared down to just:

module ExpectMethod

import Base: !, &, |

abstract type TypedBool end
"""
    struct True
Typed `true`
"""
struct True <: TypedBool end
export True
"""
    struct False
Typed `false`
"""
struct False <: TypedBool end
export False

@inline !(::False) = True()
@inline !(::True) = False()
@inline (&)(::False, ::False) = False()
@inline (&)(::False, ::True) = False()
@inline (&)(::True, ::False) = False()
@inline (&)(::True, ::True) = True()
@inline (|)(::False, ::False) = False()
@inline (|)(::False, ::True) = True()
@inline (|)(::True, ::False) = True()
@inline (|)(::True, ::True) = True()

"""
    expect_method(::typeof(afunction), arguments...)

Return `True()` or `False()` depending on whether a method for `arguments`
is expected to exist as part of an abstract interface.
"""
@inline expect_method(::typeof(getindex), ::AbstractArray, ::Vararg{Int}) = True()
@inline expect_method(::typeof(setindex!), ::AbstractArray, ::Vararg{Int}) = True()
@inline expect_method(::typeof(size), ::AbstractArray) = True()
@inline expect_method(::typeof(eltype), ::AbstractArray) = True()
@inline expect_method(a_function, arguments...) = False()
export expect_method

# TODO: encode the interface for every Abstract type using expect_method

end

Then you can use if condition isa True to test for the existence of key methods.
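A minimal usage sketch of the `isa True` pattern, assuming only the `getindex` entry of the `expect_method` table above; `get_or_default` is a hypothetical helper, not part of the module.

```julia
abstract type TypedBool end
struct True <: TypedBool end
struct False <: TypedBool end

# Subset of the expect_method table above
expect_method(::typeof(getindex), ::AbstractArray, ::Vararg{Int}) = True()
expect_method(a_function, arguments...) = False()

# Hypothetical helper: index into x if indexing is expected, else return a default
function get_or_default(x, i::Int, default)
    if expect_method(getindex, x, i) isa True
        return x[i]
    else
        return default
    end
end
```

Note that `get_or_default("ab", 1, 0)` returns the default even though `getindex` works on strings: `expect_method` only reports what its table declares, not which methods actually exist (the question `hasmethod` answers).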

So how is this supposed to be used? AFAICT this just defines expect_method. Why is that useful, and how is it related to dispatch? What’s the difference between this and hasmethod? What problem does hasmethod have that this fixes?

Also, why do you need True and False when true and false work just as well? (i.e. why are you still pushing things into the type domain?)

hasmethod doesn’t infer. I’m pushing things into the type domain because, at least the last time I checked, constant propagation is disabled by recursion, so it’s not safe to use true and false to dispatch methods.

But why are you using/needing it? It’s basically not inferred because that isn’t really useful. Making it inferable shouldn’t be hard.

That’s not a reason for anything really. What matters is what your function returns. It’s always better to return a simple value rather than encoding it in the type. If you really need tricks to make constant propagation work through your function, that’s fine, but that’s your own implementation problem; it should not get anywhere close to the API.
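A sketch of the suggestion to return a plain value: the same `my_getindex` as earlier in the thread, but with a hypothetical `can_getindex` that returns an ordinary `Bool`. The method is still selected by the argument types, so its constant result should be visible to inference without going through `True`/`False`.

```julia
# Plain Bool "trait": which method runs depends on the argument types,
# and each method returns a constant
can_getindex(::AbstractArray, ::Vararg{Int}) = true
can_getindex(arguments...) = false

function my_getindex(something, index...)
    if can_getindex(something, index...)
        return getindex(something, index...)
    else
        error("Can't getindex $something")
    end
end
```

Usage: `my_getindex(1:2, 1)` returns `1`, and `my_getindex(nothing, 1)` throws an `ErrorException`.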

why are you using/needing it?

Because I want to be able to use different methods on an object depending on which methods are expected to exist. A key example is getindex(it::Generator), which should only be defined if getindex(it.iter) is expected to exist.
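For the `Generator` case specifically, a sketch of how ordinary dispatch on the iterator’s type can express “only when the underlying iterator is an array”. `gen_getindex` is a hypothetical name, and the sketch relies on `Base.Generator`’s fields `f` and `iter` and on its first type parameter being the iterator type; those are implementation details rather than public API.

```julia
# Hypothetical: defined only for generators whose underlying iterator is an AbstractArray;
# any other generator gets a MethodError
gen_getindex(g::Base.Generator{<:AbstractArray}, index::Int...) = g.f(g.iter[index...])
```

For example, `gen_getindex((x^2 for x in 1:3), 2)` returns `4`, while a generator over a `Set` has no matching method.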

That’s not a reason for anything really

reduce on tuples is currently defined using recursion, so if you wanted to write mapreduce(x -> hasmethod(zero, x), &, arguments) and have it infer, true and false won’t cut it.

Well, but your function is not doing that. You’ve never conditionally defined anything (i.e. there’s no dispatch defined). You’ve just defined a catch-all method and have a branch in it. That’s why I said I was leaving aside the comparison to dispatch, since this has none of the necessary properties.

AFAICT this is false. Unless hasmethod itself recurses, recursion in the caller has little to do with this.

You’ve just defined a catch-all method and have a branch in it.

Yes, but the catch-all method can be inferred based on the argument types, so the branch should disappear.

AFAICT this is false

test(arguments...) = 
    if mapreduce(x -> hasmethod(zero, x), &, arguments)
        1
    else
        1.0
    end
# infers as ::Union{}
@code_warntype test(1, 1.0, 1, 1.0)

Nope, definitely true.

Again, the branch isn’t the problem. The branch is good. The issue is that you’ve never defined any dispatch rule with it. If you are only ever going to define a single method for your function, then sure, you can do whatever you want. If you actually want any dispatch, even just two methods, this will not help you at all.

Remember you said yourself that hasmethod does not infer? (So the result should not be inferred ATM, whether or not recursion is a problem?)

Also remember ::Union{} inference means perfect inference that the function will always error?

That test is not actually showing anything about recursion.

And here’s a more complicated counterexample than the one I tested myself before posting.

julia> g(x, y, z) = mapreduce(x->x isa Integer, &, (x, y, z))
g (generic function with 1 method)

julia> @code_typed g(1.0, 1, "")
CodeInfo(
1 ─     return false
) => Bool

Oops, yes, there’s a bug in the function. But your example falls apart as soon as there are a few more arguments:

g(args...) = mapreduce(x->x isa Integer, &, args)
@code_typed g(1, 1.0, "", 1, 1.0, "")
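For reference, the bug acknowledged above in `test`: `hasmethod` expects the argument types as a `Tuple` type, but `hasmethod(zero, x)` passes a value, so that call throws, which is consistent with the `::Union{}` result. A corrected sketch (`all_have_zero` is a hypothetical name); it only fixes the call and does not settle whether the result infers, which is what the thread is debating.

```julia
# hasmethod takes the function and a Tuple type of argument types, not the values
all_have_zero(arguments...) = all(x -> hasmethod(zero, Tuple{typeof(x)}), arguments)
```

For example, `all_have_zero(1, 1.0)` is `true`, while `all_have_zero(1, "")` is `false` because there is no `zero(::String)` method.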