How to pass compile-time user settings to a macro?

I have a module that builds a kernel function with a macro, based on settings that a user provides. Those settings need to be available at compile time, otherwise the macro won’t work. In the current example I do this with const globals, but I would like to make this more flexible: the user should be able to provide the settings in the script at the bottom, rather than set them inside the module. What is the Julian way of doing this?

## This is the module that will be part of the package.
module Kernels

export kernel!

const do_a = false
const do_b = true
const do_c = false

macro make_kernel()
    ex_rhs_list = []

    if do_a
        push!(ex_rhs_list, :(log.(a[:])))
    end

    if do_b
        push!(ex_rhs_list, :(- sin.(b[:])))
    end

    if do_c
        push!(ex_rhs_list, :(- cos.(c[:])))
    end

    if length(ex_rhs_list) == 0
        ex = quote 
            function kernel!(at, a, b, c)
            end
        end
    else
        if length(ex_rhs_list) == 1
            ex_rhs = ex_rhs_list[1]
        else
            ex_rhs = Expr(:call, :+, ex_rhs_list...)
        end

        ex = quote 
            function kernel!(at, a, b, c)
                @. at[:] += $ex_rhs
            end
        end
    end

    print(ex)
    return esc(ex)
end

@make_kernel

end


## This is the script on the user side.
using BenchmarkTools
using .Kernels

n = 2^16
at = rand(n); a = rand(n); b = rand(n); c = rand(n)

@btime kernel!($at, $a, $b, $c)

You could pass them via ENV variables (maybe?), but that’s not really how macros should be used, and you’ll run into a lot of problems with precompilation. It’s hard to suggest a solution without understanding your problem better, but FWIW multiple dispatch is a zero-cost abstraction when used well, so it might work better than your macro. E.g. have the user write

const my_kernel = Kernel{false,true,false}()

function my_big_foo()
    ...
    kernel!(my_kernel, at, a, b, c)
end

Then on your side, it’ll look like

struct Kernel{A, B, C}
end

function kernel!(::Kernel{A,B,C}, at, a,b,c) where {A, B, C}
    for ind in eachindex(at)
        if A
            at[ind] += log(a[ind])
        end
        ...
    end
end

Because A, B, C are struct parameters, they’ll behave as compile-time constants, and it should be just as efficient as your macro expansion, given some more work (maybe with @inbounds and @simd?)

Kernel{true, false, false} is not a particularly elegant design, but given the info you gave us, it looks like a good start…
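For reference, a minimal runnable sketch of this dispatch-based approach, filling in the elided branches with the log/sin/cos terms from your original macro (the branch bodies and the @inbounds/@simd annotations are my assumptions, not your actual kernel):

```julia
# Settings encoded as type parameters; the struct carries no runtime data.
struct Kernel{A,B,C} end

function kernel!(::Kernel{A,B,C}, at, a, b, c) where {A,B,C}
    @inbounds @simd for ind in eachindex(at)
        # A, B, C are compile-time constants here, so the compiler
        # removes the dead branches when it specializes this method.
        if A
            at[ind] += log(a[ind])
        end
        if B
            at[ind] -= sin(b[ind])
        end
        if C
            at[ind] -= cos(c[ind])
        end
    end
    return at
end
```

Called as kernel!(Kernel{false,true,false}(), at, a, b, c), this compiles to a loop containing only the sin term.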


Also, if you’re broadcasting over sufficiently large arrays, I would seriously check if this plain code isn’t just as efficient:

struct Kernel
    A::Bool
    B::Bool
    C::Bool
end

...
function kernel!(k::Kernel, at, a, b, c)
    at .+= ifelse(k.A, log.(a), 0) .+ ifelse(k.B,  ...
end
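A runnable sketch of this runtime-Bool variant, again filling in the elided terms from the original macro (note that ifelse is an ordinary function, so both of its value arguments are evaluated even when one is discarded):

```julia
# Settings stored as plain runtime fields, no type parameters.
struct Kernel
    A::Bool
    B::Bool
    C::Bool
end

function kernel!(k::Kernel, at, a, b, c)
    # ifelse picks the array or a scalar 0; the unused arrays are
    # still computed, so this trades some work for simplicity.
    at .+= ifelse(k.A, log.(a), 0) .-
           ifelse(k.B, sin.(b), 0) .-
           ifelse(k.C, cos.(c), 0)
    return at
end
```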

Thanks. My actual use case is building complex computational kernels that consist of nested loops. The macro builds the inner loop based on the settings provided by the user and tries to fuse operations depending on the compatibility of the chosen settings. I hope to avoid an if structure, as the number of permutations is rather large.

Would

@make_kernel false true false

be acceptable on the user side? Then your make_kernel can just be macro make_kernel(do_a,do_b,do_c) with a few modifications…
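To make that concrete, here is a sketch of what such an argument-taking macro could look like, reusing the expression-building logic from the original module (everything except the flag handling is carried over from your code):

```julia
# Same generator as before, but the flags arrive as macro arguments
# (boolean literals) instead of module-level consts.
macro make_kernel(do_a, do_b, do_c)
    ex_rhs_list = []
    do_a && push!(ex_rhs_list, :(log.(a[:])))
    do_b && push!(ex_rhs_list, :(- sin.(b[:])))
    do_c && push!(ex_rhs_list, :(- cos.(c[:])))

    if isempty(ex_rhs_list)
        ex = :(function kernel!(at, a, b, c) end)
    else
        ex_rhs = length(ex_rhs_list) == 1 ? ex_rhs_list[1] :
                 Expr(:call, :+, ex_rhs_list...)
        ex = quote
            function kernel!(at, a, b, c)
                @. at[:] += $ex_rhs
            end
        end
    end
    return esc(ex)
end

# The user picks the settings at the call site:
@make_kernel false true false
```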

I don’t think there’s a bullet-proof solution without seeing more of the kind of code you want to generate. What kind of nested loop? One elegant design is to store closures in the Kernel struct, like (x->x+2, y->y+3). You can get efficient and flexible code this way. For nested loops, you’d want a recursive function call…


This is my actual use case in case you are interested: https://github.com/Chiil/MicroHH.jl/blob/main/test/dynamics_kernel_switch.jl, to give you an idea of the type of functions I am trying to build.

You might look at Tullio.jl (https://github.com/mcabbott/Tullio.jl), which is configurable with keyword arguments.


I think that a tuple of functions would be very elegant in this case. Consider that it’s just as efficient:

julia> apply_add(fs::Tuple, args...) = mapreduce(f->f(args...), +, fs)
apply_add (generic function with 1 method)

julia> using BenchmarkTools

julia> f1(x, y) = sin(x) * y
f1 (generic function with 1 method)

julia> f2(x, y) = cos(x) - y
f2 (generic function with 1 method)

julia> f(x, y) = f1(x, y) + f2(x, y)
f (generic function with 1 method)

julia> const ftup = (f1, f2)
(f1, f2)

julia> @btime apply_add(ftup, 10.0, 5.0)
  1.538 ns (0 allocations: 0 bytes)
-8.559177083523302

julia> @btime f(10.0, 5.0)
  1.538 ns (0 allocations: 0 bytes)
-8.559177083523302

So you could have functions like

advec(     wt, u, v, w, s,
                visc,
                dxi, dyi, dzi, dzhi,
                is, ie, js, je, ks, ke) = - gradx(interpz(u) * interpx(w))
...

which could, with a macro for convenience, be

@def_kernel_fun advec =  - gradx(interpz(u) * interpx(w))

Then the user would specify

my_kernel = (advec, buoy)

instead of (true, false, true).
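For illustration, one possible sketch of such a @def_kernel_fun convenience macro (the macro itself and the fixed argument list are hypothetical, and gradx/interpz etc. would still need to be defined elsewhere):

```julia
# Hypothetical convenience macro: turns `name = expr` into a kernel
# function with the fixed argument list used above.
macro def_kernel_fun(ex)
    ex.head === :(=) || error("expected `name = expression`")
    name, rhs = ex.args
    # Escape the whole definition so both the argument names and the
    # user-written right-hand side resolve in the caller's scope.
    esc(:($name(wt, u, v, w, s, visc, dxi, dyi, dzi, dzhi,
                is, ie, js, je, ks, ke) = $rhs))
end
```

so e.g. @def_kernel_fun myadvec = u + w defines a 16-argument myadvec whose body is u + w.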

I don’t understand why you alternate the do_advec and do_diff terms in your ex_rhs sum. Isn’t the sum commutative, so you could group all the do_advec terms together? Or am I missing something?

In any case, your loop would be

@fast3d begin
    @fd (wt, u, v, w, s) wt += apply_add(kernel_function_tuple, wt, u, v, w, s, etc...)
end

Does that make sense?


That is an interesting approach, I will try that out! The alternating if statements need cleanup, sorry about that.
