Making high-performance functions with flexible inputs

I’m writing some code to solve a couple of ODE systems I’m working with.

At the moment I have some code that looks like this:

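    # Forward Euler: add dt*ode_a(init) to init once for each t in 0:dt:T
    function timestep_a(init; T=1., dt=0.0005)
        for t in 0.:dt:T
            init = init + dt*ode_a(init)
        end
        return init
    end

    # Same integrator, hard-wired to ode_b
    function timestep_b(init; T=1., dt=0.0005)
        for t in 0.:dt:T
            init = init + dt*ode_b(init)
        end
        return init
    end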

    # Foo's equation
    function ode_a(input)
        return input*input - 0.5*input
    end

    # Bar's equation
    function ode_b(input)
        return -input*input + 2*input - 4.0
    end

There is a lot of redundancy in the timestep functions, so I would like to merge them into one function. However, if I do this and pass the specific ODE to solve as an argument, the code can’t optimise as well and is much slower.

Is there a way to have my cake and eat it? Can I get Julia to compile a separate timestep function with the same name for each ODE without actually writing it twice?

Usually Julia is smart enough to do this optimization. Are you sure that the redundant version is faster?
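The reason this usually works is that every Julia function is an instance of its own type, so `timestep(ode_a, x)` and `timestep(ode_b, x)` get compiled as separate specializations, much like writing `timestep_a` and `timestep_b` by hand. A minimal sketch follows; `timestep_forced` is a hypothetical name, and its `f::F ... where {F}` signature forces specialization on the function argument. The "Be aware of when Julia avoids specializing" section of the Julia performance tips recommends this when a function argument is only passed through to another call; here `f` is called directly, so a plain `f` argument should already be specialized.

```julia
# Right-hand sides from the question.
ode_a(u) = u*u - 0.5*u
ode_b(u) = -u*u + 2*u - 4.0

# Each function has its own (singleton) type, so a method taking a
# function argument can be compiled separately for each one.
println(typeof(ode_a) == typeof(ode_b))   # false
println(ode_a isa Function)               # true

# Hypothetical variant of `timestep`: the type parameter F forces Julia
# to specialize on the concrete type of `f`.
function timestep_forced(f::F, init; T=1.0, dt=0.0005) where {F}
    for t in 0.0:dt:T
        init = init + dt*f(init)
    end
    return init
end

u_a = timestep_forced(ode_a, 1.0)   # compiles a method instance for typeof(ode_a)
u_b = timestep_forced(ode_b, 1.0)   # and a separate one for typeof(ode_b)
```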

    using BenchmarkTools

    # Generic integrator: the ODE right-hand side is passed in as `f`
    function timestep(f, init; T=1., dt=0.0005)
        for t in 0.:dt:T
            init = init + dt*f(init)
        end
        return init
    end

    # Hand-written integrator for Bar's equation only
    function timestep_b(init; T=1., dt=0.0005)
        for t in 0.:dt:T
            init = init + dt*ode_b(init)
        end
        return init
    end

    # Foo's equation
    function ode_a(input)
        return input*input - 0.5*input
    end

    # Bar's equation
    function ode_b(input)
        return -input*input + 2*input - 4.0
    end

    init = 1.0
    @btime timestep($ode_a, $init)   # 10.626 μs (2 allocations: 32 bytes)
    @btime timestep_b($init)         # 11.576 μs (2 allocations: 32 bytes)
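The benchmark above times the generic `timestep` on `ode_a` against the hand-written `timestep_b` on `ode_b`, which are different right-hand sides. A like-for-like check passes the same ODE to both versions. The sketch below assumes only Base Julia; `timestep_a` mirrors the question's hand-written version. Because both loops perform the same floating-point operations in the same order (Julia does not reassociate or fuse them without `@fastmath` or `muladd`), the results should be bitwise identical.

```julia
# Foo's equation, from the question.
ode_a(u) = u*u - 0.5*u

# Generic integrator: the ODE right-hand side is passed in as `f`.
function timestep(f, init; T=1.0, dt=0.0005)
    for t in 0.0:dt:T
        init = init + dt*f(init)
    end
    return init
end

# Hand-written integrator hard-wired to ode_a, as in the question.
function timestep_a(init; T=1.0, dt=0.0005)
    for t in 0.0:dt:T
        init = init + dt*ode_a(init)
    end
    return init
end

init = 1.0
# Same arithmetic in the same order, so the results match exactly.
println(timestep(ode_a, init) == timestep_a(init))   # true
```

With BenchmarkTools loaded, `@btime timestep($ode_a, $init)` and `@btime timestep_a($init)` would then time the two versions on identical work.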