Performant and user-friendly function only through metaprogramming?

Hi, I have a question on how to write performant functions that are still somewhat convenient for the user.

I have an application where I need to run optimizers on custom functions, i.e., the functions are called very often and thus need to be as performant as possible.
Usually, these functions have the same structure but different coefficients.
Here is a simple example where the custom function is a parabola in d dimensions:
$$f(x_1, \dots, x_d) = c_1 x_1 + c_2 x_1^2 + c_3 x_2 + c_4 x_2^2 + c_5 x_1 x_2 + \dots + c_{N-1} x_d + c_N x_d^2$$
Here the $c_i$ are known numerical coefficients (their values differ from application to application) and the $x_i$ are the unknown parameters to be scanned/optimized.

Here is an MWE illustrating my problem:

```julia
# This is the user input (consider an example with 5 parameters named A, B, C, D, E)
using NamedArrays

# user input: (linear, quadratic) coefficients for each parameter
my_coefs = (
    A = (-0.104422, 0.004119),
    B = (0.109999, 0.005719),
    C = (0.068538, 0.004162),
    D = (-0.004123, 0.004186),
    E = (-0.002558, 0.000930),
)

names = collect(keys(my_coefs))
d = length(my_coefs)
my_crossm = NamedArray(rand(d, d), (names, names))  # user input (random values just for the purpose of this example)
```
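For reference, the function that the implementations below evaluate can be written in terms of these inputs (notation introduced here for clarity: $a_k$ and $q_k$ are the linear and quadratic entries of `my_coefs` for the $k$-th parameter, and $M$ is `my_crossm`; the quadratic sum is dropped when `use_quadratic=false` and the cross sum when `use_crossterms=false`):

$$f(x) = \sum_{k=1}^{d} \left( a_k x_k + q_k x_k^2 \right) + \sum_{j=1}^{d} \sum_{k \ne j} M_{jk}\, x_j x_k$$

Since both $(j, k)$ and $(k, j)$ appear in the cross sum, each product $x_j x_k$ effectively gets the coefficient $M_{jk} + M_{kj}$, and the diagonal of $M$ is never used.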

I want to implement this as a flexible function, i.e. one that works for a varying number of parameters, lets me choose whether to include cross terms, etc.
So something like this:

```julia
function user_friendly(coefs, crossm; use_quadratic=true, use_crossterms=true)
    x -> begin
        result = 0.
        for c in keys(coefs) # linear terms
            result += getindex(coefs, c)[1] * getindex(x, c)
        end

        if use_quadratic # quadratic terms
            for c in keys(coefs)
                result += getindex(coefs, c)[2] * getindex(x, c)^2
            end
        end

        if use_crossterms # cross terms
            for c1 in keys(coefs), c2 in keys(coefs)
                if c1 != c2
                    result += crossm[c1, c2] * getindex(x, c1) * getindex(x, c2)
                end
            end
        end

        return result
    end
end
```

Here is another implementation that is much less flexible, since changing the number of parameters or switching off the cross terms means rewriting it by hand:

```julia
function hardcoded(coefs, crossm)
    x -> begin
        return (
        coefs.A[1] * x.A + coefs.B[1] * x.B + coefs.C[1] * x.C + coefs.D[1] * x.D + coefs.E[1] * x.E
        + coefs.A[2] * x.A^2  + coefs.B[2] * x.B^2  + coefs.C[2] * x.C^2  + coefs.D[2] * x.D^2  + coefs.E[2] * x.E^2
        + crossm[:A, :B] * x.A * x.B + crossm[:A, :C] * x.A * x.C + crossm[:A, :D] * x.A * x.D + crossm[:A, :E] * x.A * x.E
        + crossm[:B, :A] * x.B * x.A + crossm[:B, :C] * x.B * x.C + crossm[:B, :D] * x.B * x.D + crossm[:B, :E] * x.B * x.E
        + crossm[:C, :A] * x.C * x.A + crossm[:C, :B] * x.C * x.B + crossm[:C, :D] * x.C * x.D + crossm[:C, :E] * x.C * x.E
        + crossm[:D, :A] * x.D * x.A + crossm[:D, :B] * x.D * x.B + crossm[:D, :C] * x.D * x.C + crossm[:D, :E] * x.D * x.E
        + crossm[:E, :A] * x.E * x.A + crossm[:E, :B] * x.E * x.B + crossm[:E, :C] * x.E * x.C + crossm[:E, :D] * x.E * x.D
        )
    end
end
```

The speed comparison yields:

```julia
using BenchmarkTools

f_hardcoded = hardcoded(my_coefs, my_crossm)
f_user_friendly = user_friendly(my_coefs, my_crossm)

rparams = rand(5)
test_params = (A=rparams[1], B=rparams[2], C=rparams[3], D=rparams[4], E=rparams[5])

@btime f_hardcoded(test_params) # 156.367 ns (1 allocation: 16 bytes)
@btime f_user_friendly(test_params) # 2.463 μs (121 allocations: 4.39 KiB)
```
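When timing with BenchmarkTools, non-constant globals such as `f_hardcoded` and `test_params` are normally interpolated into the benchmark with `$`; otherwise the cost of looking up untyped globals is included in the measurement, which may account for part of the reported time and allocations. A minimal sketch with hypothetical stand-ins `g` and `p` (for the MWE it would read `@btime $f_hardcoded($test_params)`):

```julia
using BenchmarkTools

# Hypothetical stand-ins: `g` plays the role of f_hardcoded, `p` of test_params.
# Both are non-constant globals, just like in the MWE.
g = x -> 0.5 * x.A + 0.25 * x.A^2 + 0.1 * x.A * x.B
p = (A = 0.3, B = 0.7)

@btime g(p)      # also measures the dynamic lookup of the untyped globals `g` and `p`
@btime $g($p)    # `$` splices in the current values, so essentially only the call is timed
```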

So the hardcoded function is more than 15 times faster and allocates far less memory.
I’m pretty sure the user_friendly function can be optimized, but of the things I tried, writing for loops etc. was generally slower than just spelling out something like `return (coef[1] * x.A + ...)`. Maybe I’m missing something here…

But since decisions like how many parameters to use or whether to include cross terms have to be made only once (before running the optimizer), I thought about using metaprogramming to automatically generate a function that is then passed to the optimizer.

```julia
function generate_f(coefs, crossm; use_quadratic=true, use_crossterms=true)
    S = ""

    for c in keys(coefs) # linear terms
        S *= " + coefs.$c[1] * x.$c"
    end

    if use_quadratic # quadratic terms
        for c in keys(coefs)
            S *= " + coefs.$c[2] * x.$c^2"
        end
    end

    if use_crossterms # cross terms
        for c1 in keys(coefs), c2 in keys(coefs)
            if c1 != c2
                S *= " + crossm[:$c1, :$c2] * x.$c1 * x.$c2"
            end
        end
    end

    return Meta.parse(S)
end
```

```julia
@eval my_generated_func(coefs, crossm) = x -> begin $(generate_f(my_coefs, my_crossm)) end
f_generated = my_generated_func(my_coefs, my_crossm)

@btime f_generated(test_params) # 157.945 ns (1 allocation: 16 bytes)
```

This is then as performant as the hardcoded function. However, I’m not really experienced with using metaprogramming, so here are my questions:

  • Is using metaprogramming the right way to deal with this kind of problem?
    (or is there a way to get the user_friendly function as performant as the hardcoded one?)
  • If metaprogramming is the way to go, is there a better way to generate this f_generated function?

Why not just let the user pass in a function argument? Simulating higher-order functions is a classic misuse of metaprogramming.
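A minimal sketch of that idea, where `crude_minimize` is a hypothetical stand-in for whatever optimizer is actually used: the library side only ever calls `f(x)`, and the user supplies the objective as an ordinary Julia function. Because `f` is called inside the body, Julia compiles a specialized version of the routine for each function passed in, so a hand-written objective like `user_f` should be as fast as writing the expression out by hand, with no code generation involved.

```julia
# Hypothetical optimizer (a crude random search, for illustration only).
# It takes the objective `f` as an ordinary argument.
function crude_minimize(f, x0::NamedTuple; iters = 10_000, step = 0.1)
    best_x, best_f = x0, f(x0)
    for _ in 1:iters
        cand = map(v -> v + step * randn(), best_x)   # perturb every parameter
        fc = f(cand)
        if fc < best_f                                 # keep only improvements
            best_x, best_f = cand, fc
        end
    end
    return best_x, best_f
end

# The user simply writes the objective as Julia code (coefficients from the question).
user_f(x) = -0.104422 * x.A + 0.004119 * x.A^2 + 0.109999 * x.B + 0.005719 * x.B^2

best_x, best_f = crude_minimize(user_f, (A = 0.0, B = 0.0))
```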

Or, if you specifically want to handle quadratic functions $f(x) = x^T A x + b^T x + \alpha$, why not have the user pass in the matrix $A$, the vector $b$, and the scalar $\alpha$, and then call `dot(x, A, x) + dot(b, x) + α` as needed? This has the advantage that you can use specialized methods exploiting the quadratic structure. (They can even pass in StaticArrays, and then `dot(x, A, x) + dot(b, x) + α` will be completely unrolled by the compiler. Or they can pass in sparse arrays if $A$ is large and only a few entries are nonzero.)
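Applied to the question's setup, this could look like the following sketch. Assumptions: the parameters are kept in a fixed order as plain vectors rather than NamedTuples, and the names `lin`, `quad`, `cross`, `make_quadratic` as well as the cross-term values are made up for illustration. The linear coefficients become $b$, the quadratic coefficients become the diagonal of $A$, and the cross terms fill its off-diagonal, so $x^T A x + b^T x$ reproduces the sum computed by the loops in `user_friendly`.

```julia
using LinearAlgebra

# Coefficients in a fixed parameter order (A, B, C). Linear and quadratic values are
# taken from the question; the cross-term matrix is made up for illustration.
lin   = [-0.104422, 0.109999, 0.068538]      # linear coefficients    -> b
quad  = [0.004119, 0.005719, 0.004162]       # quadratic coefficients -> diagonal of A
cross = [0.9 0.2 0.1
         0.3 0.9 0.4
         0.5 0.6 0.9]                        # cross terms -> off-diagonal of A (diagonal ignored)

# Build A and b once, before optimizing
A = copy(cross)
A[diagind(A)] .= quad                        # overwrite the (unused) diagonal
b = lin
α = 0.0                                      # no constant term in the question

# Closure over A, b, α: this is the objective handed to the optimizer
make_quadratic(A, b, α) = x -> dot(x, A, x) + dot(b, x) + α
f_quad = make_quadratic(A, b, α)

# Same value computed term by term, as the loops in `user_friendly` do
x = [0.3, -1.2, 0.8]
explicit = sum(lin[k] * x[k] + quad[k] * x[k]^2 for k in 1:3) +
           sum(cross[j, k] * x[j] * x[k] for j in 1:3, k in 1:3 if j != k)
f_quad(x) ≈ explicit
```

Switching off the quadratic or cross terms then just means zeroing the corresponding entries of $A$ once, before the optimizer runs.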

If you are getting lots of allocations when evaluating a simple quadratic function, then you are doing something wrong, e.g. you have type-unstable code.
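To make the type-stability point concrete, here is a sketch of a loop-based version that keeps the switches of `user_friendly` but is type-stable. Assumptions: the coefficients are passed as plain `Vector{Float64}`/`Matrix{Float64}` in a fixed parameter order instead of a NamedTuple and a NamedArray, and `user_friendly_stable` is a hypothetical name. `@inferred` from the Test standard library checks that the closure's return type is inferred concretely.

```julia
using Test   # for @inferred

# Hypothetical type-stable variant of `user_friendly`. All inputs have concrete types
# and are fixed before the closure is created, so the closure never works on untyped data.
function user_friendly_stable(lin::Vector{Float64}, quad::Vector{Float64},
                              crossm::Matrix{Float64};
                              use_quadratic = true, use_crossterms = true)
    d = length(lin)
    return x -> begin
        result = 0.0
        for i in 1:d
            result += lin[i] * x[i]                  # linear terms
            if use_quadratic
                result += quad[i] * x[i]^2           # quadratic terms
            end
        end
        if use_crossterms
            for i in 1:d, j in 1:d
                if i != j                            # cross terms, diagonal skipped
                    result += crossm[i, j] * x[i] * x[j]
                end
            end
        end
        return result
    end
end

lin    = [-0.104422, 0.109999]
quad   = [0.004119, 0.005719]
crossm = [0.0 0.2
          0.3 0.0]

f_stable = user_friendly_stable(lin, quad, crossm)
x = [0.5, -1.0]
@inferred f_stable(x)      # throws if the return type is not inferred as Float64
```

The same closure also accepts a NamedTuple such as `test_params`, since NamedTuples can be indexed by position.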

PS. Never do metaprogramming using strings; always work with the AST (i.e. `Expr` objects) directly. But I don’t think you should be using metaprogramming at all.
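If code generation is used anyway, here is a minimal sketch of what working with the AST instead of strings could look like: each term is built as an `Expr` by quoting and interpolating with `$`, and the terms are combined into a single `Expr(:call, :+, ...)`, so nothing is ever parsed from text. The helper names `fieldexpr`, `generate_f_expr` and `make_generated` are hypothetical, and unlike `generate_f` this sketch assumes `crossm` is indexed by integer position in the same order as the parameter names (a plain `Matrix` here).

```julia
# AST for field access `obj.name` (the same Expr that `Meta.parse("obj.name")` produces)
fieldexpr(obj::Symbol, name::Symbol) = Expr(:., obj, QuoteNode(name))

# Hypothetical AST-based counterpart of `generate_f`: returns an `Expr`, never a string.
function generate_f_expr(names; use_quadratic = true, use_crossterms = true)
    terms = Any[]
    for c in names                                   # linear terms: coefs.c[1] * x.c
        push!(terms, :($(fieldexpr(:coefs, c))[1] * $(fieldexpr(:x, c))))
    end
    if use_quadratic                                 # quadratic terms: coefs.c[2] * x.c^2
        for c in names
            push!(terms, :($(fieldexpr(:coefs, c))[2] * $(fieldexpr(:x, c))^2))
        end
    end
    if use_crossterms                                # cross terms: crossm[i, j] * x.c1 * x.c2
        for (i, c1) in enumerate(names), (j, c2) in enumerate(names)
            if i != j
                push!(terms, :(crossm[$i, $j] * $(fieldexpr(:x, c1)) * $(fieldexpr(:x, c2))))
            end
        end
    end
    return Expr(:call, :+, terms...)                 # a single sum: +(term1, term2, ...)
end

# Example input: two parameters, made-up cross terms as a plain matrix
demo_coefs  = (A = (-0.104422, 0.004119), B = (0.109999, 0.005719))
demo_crossm = [0.0 0.2
               0.3 0.0]

body = generate_f_expr(keys(demo_coefs))
@eval make_generated(coefs, crossm) = x -> $body   # splice the Expr directly, no Meta.parse
f_generated2 = make_generated(demo_coefs, demo_crossm)

f_generated2((A = 0.5, B = -1.0))
```

Printing `body` (for example with `println(body)`) shows the generated code before it is evaluated.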
