Have slightly modified function compiled without copying code

I keep running into cases where I have a function that has a base case, but sometimes I want to call it with an extra parameter so it behaves slightly differently. However, I don’t want to write two functions (lots of duplicate code), and I don’t want the runtime overhead of the adjustment when it’s not called with the extra parameter.

An oversimplified example:

```julia
# This works:
test(x, extra=0) = 2 * x + extra
# but I want the equivalent of:
test(x) = 2 * x
test(x, extra) = 2 * x + extra
```

I want two compiled versions: one that computes `2x` and another that computes `2x + extra`. Is there an existing mechanism for this? I can see how it could be done with a macro, but I'm checking whether something already exists or whether there is a better way to do it.

Note: even if the compiler is smart enough to compile away the +0 in this particular example, that’s just a contrived example. Other cases are more complex.

I'm aware of `@generated`, but it seems like overkill for this case. Though maybe there is an easy way to accomplish it with `@generated`?
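For comparison, here is a minimal sketch of what the `@generated` route could look like for the toy example. The name `test_gen` is hypothetical. Inside a generated function's body the arguments are bound to their *types*, not their values, so the `if` below runs once per type signature and returns the expression that gets compiled for that signature.

```julia
# Hypothetical example: a generated function that emits different code
# depending on the *type* of `extra`. Inside the generator, `x` and `extra`
# are the argument types, not the runtime values.
@generated function test_gen(x, extra)
    if extra <: Nothing
        return :(2 * x)            # code compiled when extra is `nothing`
    else
        return :(2 * x + extra)    # code compiled for any other type
    end
end

# One-argument form forwards `nothing`, the way a default argument would.
test_gen(x) = test_gen(x, nothing)

println(test_gen(3))     # 6
println(test_gen(3, 1))  # 7
```

Generated functions come with restrictions (the generator is expected to be pure), so when plain branching or dispatch on the type of `extra` is enough, as in the replies below, those are usually the simpler choice.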

Perhaps not in a more complicated case, but in your simplified example you can do this:

```julia
test(x) = 2*x
test(x,extra) = test(x) + extra
```

Here that doesn't save much, but in a more complex example it could.
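When the adjustment sits in the middle of a longer function, the same idea can be stretched by splitting the body into the part before the adjustment and the part after it. A sketch, with hypothetical helper names `pre_step` and `post_step` and a made-up body:

```julia
# Hypothetical split of a longer function around the optional middle step.
pre_step(x) = 2 * x     # everything up to the adjustment
post_step(b) = sqrt(b)  # everything after the adjustment

test_split(x) = post_step(pre_step(x))
test_split(x, extra) = post_step(pre_step(x) + extra)

println(test_split(8.0))     # 4.0
println(test_split(7.0, 2))  # 4.0
```

This can get awkward when the two halves share many local variables, since everything that crosses the split has to be passed or returned explicitly.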


Right, thanks, but it's the more complex case I'm concerned about, for example a 10-20 line function with the optional operation in the middle. So how do I deal with cases where refactoring would not be an ideal solution?

You could use something like this:

```julia
function test1(x, extra=nothing)
    a = 2*x
    b = if isnothing(extra)
        a
    else
        a + extra
    end
    c = sqrt(b)
    return c
end
```

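The compiler knows the type of `extra` for each call: `nothing` is the only value of type `Nothing`, whether it is passed explicitly or comes from the default. When `test1` is specialized for `extra::Nothing`, `isnothing(extra)` is a compile-time constant, so not only does this work, but the branch is resolved at compile time rather than at run time.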
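One way to check this claim is to look at the optimized, type-inferred code with `code_typed` (or `@code_typed` at the REPL). The sketch below uses a hypothetical stand-in, `scaled`, with the same shape as `test1` but without `sqrt` (whose own domain check would add an unrelated branch), plus a hypothetical helper `has_branch` that looks for conditional jumps (`Core.GotoIfNot`) in the compiled IR. The exact IR varies between Julia versions, so treat this as a diagnostic, not a guarantee.

```julia
# Hypothetical stand-in with the same shape as test1: an optional step
# guarded by isnothing, with `nothing` as the default.
function scaled(x, extra=nothing)
    a = 2 * x
    b = isnothing(extra) ? a : a + extra
    return b
end

# True if the optimized, inferred code for these argument types still
# contains a conditional branch.
function has_branch(f, types)
    ci, _ = only(code_typed(f, types))
    return any(stmt -> stmt isa Core.GotoIfNot, ci.code)
end

println(scaled(3), " ", scaled(3, 1))        # 6 7
println(has_branch(scaled, (Int, Nothing)))  # false if the check was folded away
println(has_branch(scaled, (Int, Int)))      # false if the check was folded away
```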

Alternatively, the same concept using dispatch (which will also be resolved at compile time):

```julia
function test2(x, extra=nothing)
    a = 2*x
    b = applyextra(a, extra)
    c = sqrt(b)
    return c
end

applyextra(a, extra::Nothing) = a
applyextra(a, extra) = a + extra
```
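A closely related variant, sketched here as an alternative rather than as part of the reply above, is to pass the middle step in as a function, with `identity` (from Base) as the default. The name `test3` and the closure in the usage lines are hypothetical.

```julia
# Hypothetical variant of test2: the optional middle step is passed in as a
# function. `identity` returns its argument unchanged, so the default call
# does no extra work.
function test3(x, adjust=identity)
    a = 2 * x
    b = adjust(a)   # the middle step; compiled separately per function type
    c = sqrt(b)
    return c
end

println(test3(8.0))              # 4.0
println(test3(7.0, a -> a + 2))  # 4.0
```

Every Julia function, including each anonymous function, has its own type, so `test3` gets a separate compiled specialization for `identity` and for each function passed in. Julia's specialization heuristics can skip `Function` arguments that are only passed through to another call; here `adjust` is called directly, so that does not apply.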