If blocks with code generation

Hi everyone,

I’ve started using Julia recently and stumbled upon a case which seemed well suited to metaprogramming.

Use case

In words: I have a train function which does roughly the same thing no matter the type of the model – declare variables, define a loss function, call this loss function – but still differs slightly depending on the model.
Basically, my code is as follows (MRE below):

abstract type Meta end

struct A <: Meta end
struct B <: Meta end 

function train(x, t::A)
    # Block 1: declare variables
    model = ...

    function loss(model)
        # Block 2: compute things
    end 

    # Block 3: log values
    cur_loss = loss(model)
end 

function train(x, t::B)
    # Block 1: declare variables...
    model = ...
    # >> Add other variables 
    other_values = ...

    function loss(model, other_values)
        # Block 2: compute things
        # >> Specific things with other_values
    end 

    # Block 3: log values
    cur_loss = loss(model, other_values)
    # >> Custom metrics 
end 

To be clear, each train function has the same blocks (1, 2, 3 in the example above) but differs marginally from the others. Sometimes I need to compute other variables, sometimes I need to call the loss with different values.
I think this would work nicely with metaprogramming.

Reproducible example

Let’s say I want to factor out the common structure in the following

abstract type Meta end

struct A <: Meta end
struct B <: Meta end

function train(x, t::A)
    model = x
    function loss(model)
        model * 2
    end
    loss(model)
end

function train(x, t::B)
    model = x
    function loss(model)
        model * 3
    end
    loss(model)
end

I would love to be able to write

for T ∈ (:A, :B)
    @eval function train_ifbranches(x, t::$T)
        model = x
        if $T == :B
            # declare other things
        end
        function loss(model)
            # with some @eval or whatever in front
            if $T == :A
                model * 2
            elseif $T == :B
                model * 3
            else
                0
            end
        end
        loss(model)
    end
end

What I managed to do

for T ∈ (:A, :B)
    @eval function train_meta(x, t::$T)
        model = x
        function loss(model)
            model * $(T == :A ? 2 : 3)
        end
        loss(model)
    end
end

However, I have multiple subtypes of Meta and can’t simply rely on the ternary operator ?. I can also declare an expression beforehand and interpolate it, like

for T ∈ (:A, :B)
    expr = if T == :A
        2
    else
        3
    end
    @show expr
    @eval function train_ifbranches(x, t::$T)
        model = x
        function loss(model)
            model * ($expr)
        end
        loss(model)
    end
end

but I would need to declare a lot of these exprs, and it would become unreadable.

Final words

Clearly, I have not precisely understood the difference between the interpolation operator $ and the eval function, even though I thought they were equivalent (according to the docs). Any help would be greatly appreciated! Of course this example is trivial; I did my best to simplify what I wanted to achieve, and thought that would be better than pasting 100 lines of code. But I’d be happy to add details if that helps.

Also: What I’ve tried besides metaprogramming

At first, I simply put if blocks everywhere, depending on the type of t. However, it was really cumbersome and induced some unrelated problems. Notably, I had to define several loss functions (otherwise they would be overwritten), and I was afraid that the autodiff package would suffer a performance penalty from checking multiple if branches. Anyway, I would be happy to see how MP works on this case.
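
In rough strokes, that abandoned version looked like this on the MRE above (a reconstruction, not my exact code; it reuses the Meta/A/B types defined earlier):

function train_if(x, t::Meta)
    model = x
    function loss(model)
        # Runtime branching on the type of t, repeated in every block:
        if t isa A
            model * 2
        elseif t isa B
            model * 3
        else
            0
        end
    end
    loss(model)
end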

I only skimmed your post, so I hope I didn’t misunderstand something, but are you aware that you can nest either ifelse or the ternary operator arbitrarily deep?

julia> f(n) = ifelse(iszero(n), 10, ifelse(isone(n), 100, 1000))
f (generic function with 1 method)

julia> g(n) = iszero(n) ? 10 : isone(n) ? 100 : 1000
g (generic function with 1 method)

julia> map(f, (0, 1, 2, 3))
(10, 100, 1000, 1000)

julia> map(g, (0, 1, 2, 3))
(10, 100, 1000, 1000)

I wouldn’t use metaprogramming for that. I would pass some additional parameters to your train function (perhaps a function as one of the parameters) and structure the code such that the differences can be managed by the parameters.
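
For concreteness, here is a minimal sketch of that idea on your MRE (the compute_loss keyword is a name I just made up):

function train(x; compute_loss = model -> model * 2)
    model = x
    # The per-model difference arrives as an argument, so one method
    # covers every case with no code generation.
    loss(model) = compute_loss(model)
    loss(model)
end

train(1.0)                                    # default: 2.0
train(1.0; compute_loss = model -> model * 3) # "B"-style: 3.0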

1 Like

I agree with @lmiq. Metaprogramming is super useful when it’s needed, but it can also make your code harder to maintain, so it should be reserved for cases where it’s really needed. Here, you can get a lot done with multiple dispatch on helper functions. For example:

You can have a get_other_values function, with get_other_values(t::B) returning a named tuple of the other values, and get_other_values(t::A) returning nothing. Then within loss, you can have a do_things_with_other_values call, with one method accepting named tuples and computing things with their values, and another method being do_things_with_other_values(::Nothing) = nothing.

These are just examples to illustrate the idea; a lot depends on the specifics of your code. Exploiting the type system and dispatch can give you performant, generic code, while being much more readable and maintainable than a metaprogramming approach.
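
To make that concrete, here is a rough sketch on the types from your MRE (the scale field is just a placeholder for whatever B actually carries):

abstract type Meta end
struct A <: Meta end
struct B <: Meta end

get_other_values(::A) = nothing
get_other_values(::B) = (scale = 3,)  # placeholder named tuple

# One method per shape of the extra values; dispatch picks the right one.
do_things_with_other_values(model, ::Nothing) = model * 2
do_things_with_other_values(model, other::NamedTuple) = model * other.scale

function train(x, t::Meta)
    model = x
    other_values = get_other_values(t)
    loss(model) = do_things_with_other_values(model, other_values)
    loss(model)
end

train(1.0, A())  # 2.0
train(1.0, B())  # 3.0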

2 Likes

Thanks for your answers!
Indeed, that is what I tried in the first place, but my main concern was not being able to write neat type-stable functions. On top of that, I had difficulties obtaining performant code with Enzyme for the autodiff, which made me a bit paranoid about the structure of the loss function. All this made me think that MP would be more efficient. But I’ll try your way, @digital_carver, and will let you know how that works!
If someone else has suggestions on the MP approach, I’d still be interested :slight_smile:

Yeah, those are understandable concerns.

I’d suggest trying out creative non-MP approaches first, including multiple dispatch and creating and using generic types where needed. But if it turns out you ultimately do need metaprogramming, @generated might be the tool you need (and CompTime.jl provides a nice layer on top of it to make it easier to use).
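
For illustration, a @generated version could look like this (just a sketch on the Meta/A/B toy types from earlier in the thread, not your real loss):

# Assumes the Meta/A/B definitions from the MRE above are in scope.
@generated function train_gen(x, t::Meta)
    # Inside an @generated function, t is bound to the *type* of the
    # argument, so this branch is resolved at compile time and each
    # compiled method body is branch-free.
    factor = t === A ? 2 : 3
    return :(x * $factor)
end

train_gen(1.0, A())  # 2.0
train_gen(1.0, B())  # 3.0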

I ended up doing as @lmiq and @digital_carver suggested, and indeed I can see how using MP makes the code harder to maintain. I don’t even have any type instability, so that’s good. Thanks for your help!

2 Likes