Hi everyone,
I’ve started using Julia recently and stumbled upon a case that seemed well suited to metaprogramming.
## Use case
In words: I have a `train` function which does roughly the same thing no matter the type of the model – declare variables, define a loss function, call this loss function – but still differs slightly depending on the model.
Basically, my code is as follows (MRE below):
```julia
abstract type Meta end
struct A <: Meta end
struct B <: Meta end

function train(x, t::A)
    # Block 1: declare variables
    model = ...
    function loss(model)
        # Block 2: compute things
    end
    # Block 3: log values
    cur_loss = loss(model)
end

function train(x, t::B)
    # Block 1: declare variables...
    model = ...
    # >> Add other variables
    other_values = ...
    function loss(model, other_values)
        # Block 2: compute things
        # >> Specific things with other_values
    end
    # Block 3: log values
    cur_loss = loss(model, other_values)
    # >> Custom metrics
end
```
To be clear, each `train` method has the same blocks (1, 2, 3 in the example above) but differs marginally from the others. Sometimes I need to compute extra variables, sometimes I need to call the loss with different arguments.
I think this would work nicely with metaprogramming.
## Reproducing example

Let’s say I want to factorize the following:
```julia
abstract type Meta end
struct A <: Meta end
struct B <: Meta end

function train(x, t::A)
    model = x
    function loss(model)
        model * 2
    end
    loss(model)
end

function train(x, t::B)
    model = x
    function loss(model)
        model * 3
    end
    loss(model)
end
```
I would love to be able to write something like:
```julia
for T ∈ (:A, :B)
    @eval function train_ifbranches(x, t::$T)
        model = x
        if $T == :B
            # declare other things
        end
        function loss(model)
            # with some @eval or whatever in front
            if $T == :A
                model * 2
            elseif $T == :B
                model * 3
            else
                0
            end
        end
        loss(model)
    end
end
```
## What I managed doing
```julia
for T ∈ (:A, :B)
    @eval function train_meta(x, t::$T)
        model = x
        function loss(model)
            model * $(:($T) == :A ? 2 : 3)
        end
        loss(model)
    end
end
```
However, I have multiple subtypes of `Meta` and can’t simply rely on the ternary operator `?`. I can also declare an expression beforehand and interpolate it, like:
```julia
for T ∈ (:A, :B)
    expr = if T == :A
        2
    else
        3
    end
    @show expr
    @eval function train_ifbranches(x, t::$T)
        model = x
        function loss(model)
            model * ($expr)
        end
        loss(model)
    end
end
```
but I would need to declare a lot of `expr` variables, and it would become unreadable.
## Final words
Clearly, I have not precisely understood the difference between the interpolation operator `$` and the `eval` function, even though I thought they were equivalent (according to the docs). Any help would be greatly appreciated! Of course this example is trivial; I did my best to simplify what I want to achieve and thought that would be better than pasting 100 lines of code. But I’d be happy to add details if that helps.
## Also: What I’ve tried besides metaprogramming
At first, I simply put `if` blocks everywhere, depending on the type of `t`. However, it was really cumbersome and induced some unrelated problems. Notably, I had to define several `loss` functions (otherwise they would be overwritten), and I was afraid that the autodiff package would suffer a performance penalty from checking multiple `if` branches. Anyway, I would be happy to see how metaprogramming works on this case.