Speeding up a function with ::Function and ::Vararg input

I need help understanding how/if I can speed up a function that has a ::Function argument and a ::Vararg argument.

I have several functions that do almost the same thing. They each receive a few vector arguments, calculate some new vectors of the same size, then do some nested while looping to set the values of an output array. The looping part is almost identical across the different functions. The only difference is the function used to set the value of the output array inside the loops. So, I’m hoping to write a single function that takes care of this looping part once, then reuse it.

But because that inner loop function could take different numbers of arguments, I need to use a ::Vararg input. I’ve tried and failed to get one that is as fast as having several nearly identical functions. An example is below.

test(x, y) = (x^2 + y^2)^(1/3)
test(x, y, z) = (x^2 + y^2 + z^2)^(1/4)

function funvar!(a::Vector, f::Function, V::Vararg{Vector{Float64},N}) where {N}
    T = Tuple(zip(V...))
    for i = 1:length(a)
        a[i] = f(T[i]...)
    end
end

function fun2!(a::Vector, f::Function, b::Vector, c::Vector)
    for i = 1:length(a)
        a[i] = f(b[i], c[i])
    end
end

function fun3!(a::Vector, f::Function, b::Vector, c::Vector, d::Vector)
    for i = 1:length(a)
        a[i] = f(b[i], c[i], d[i])
    end
end

N = 10000000
a = zeros(N)
b = rand(N)
c = rand(N)
d = rand(N)

@time funvar!(a, test, b, c)
@time funvar!(a, test, b, c, d)

@time fun2!(a, test, b, c)
@time fun3!(a, test, b, c, d)


  3.095956 seconds (50.03 M allocations: 1.343 GiB, 6.75% gc time)
  3.217017 seconds (60.04 M allocations: 1.641 GiB, 5.49% gc time)

  0.953007 seconds (16.32 k allocations: 902.345 KiB)
  0.899082 seconds (16.80 k allocations: 923.987 KiB)

The conversion to a Tuple should not be needed here and will be very expensive, especially in terms of compile times, if the Vs are big. Just zip(V...) should perform a lot better.


The issue is that I can’t actually just iterate through V. I need to be able to access specific elements of the varargs in the real function that I’m dealing with. Is there a way to get around that?

You can iterate through T just fine; it's just not indexable. It looks like you are looking for something like this:

function funvar!(a::Vector, f::Function, V::Vararg{Vector{Float64},N}) where {N}
    T = zip(V...)
    for (i, Ti) in zip(eachindex(a), T)
        a[i] = f(Ti...)
    end
end

(eachindex is typically preferable to 1:length(a), since it will also work for OffsetArrays)
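For example (a quick sketch, assuming OffsetArrays.jl is installed):

using OffsetArrays

v = OffsetVector([10.0, 20.0, 30.0], -1:1)  # valid indices are -1:1
# 1:length(v) would be 1:3, but v[2] and v[3] don't exist, so indexing with
# it throws a BoundsError; eachindex(v) yields -1:1 and stays in bounds:
for i in eachindex(v)
    v[i] *= 2
end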
In Julia 1.6, you could also write this as:

function funvar!(a::Vector, f::Function, V::Vararg{Vector{Float64},N}) where {N}
    for (i, Ti...) in zip(eachindex(a), V...)
        a[i] = f(Ti...)
    end
end

No, I actually need to be able to index the varargs in my real function, because it has more complicated things in it than just the one for loop. Anyway, I profiled the real function and the Tuple construction only took a small fraction of the execution time. It blows up the memory usage, but I don't know of a better way.

You can at least collect T into a Vector instead of Tuple, which should be more efficient for large arrays. In that situation, I would probably just skip the zip though and do

function funvar!(a::Vector, f::Function, V::Vararg{Vector{Float64},N}) where {N}
    for i in 1:length(a)
        a[i] = f(map(v -> v[i], V)...)
    end
end

directly.


Ok, good suggestion. If I use collect instead of Tuple, the times I get for the example above are

2.394363 seconds (40.03 M allocations: 764.435 MiB, 4.89% gc time)
2.361993 seconds (50.03 M allocations: 993.369 MiB, 3.34% gc time)
0.949158 seconds (16.24 k allocations: 899.298 KiB)
0.897008 seconds (16.71 k allocations: 919.925 KiB)

So it helps a little but definitely doesn’t solve the problem. Here’s the @profview for the second call on funvar!.

The collecting happens in that little blue box on the left and the math happens in the purple and green boxes on the right. Then there’s a bunch of time in between that is ascribed to funvar! itself. Don’t know what’s happening there.

Hmm, that is indeed odd. I can replicate this as well, even on Julia nightly. @code_warntype here looks fine, so it would be good to investigate this further. Would you mind opening an issue on GitHub?
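(The check I mean is just, e.g.:

@code_warntype funvar!(a, test, b, c)

which shows concretely inferred types here.)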

I could see two simple-minded solutions. Perhaps they don’t work in your context, but here they are:

If the possible number of arguments is small (as in your example), branching in the loop may be performant:

    if N == 2
        a[i] = f(V[1][i], V[2][i])
    else
        a[i] = f(V[1][i], V[2][i], V[3][i])
    end

Alternatively, all of the methods of test could accept the same number of arguments (if feasible). The unused arguments would be dummies that support indexing:

test(x, y, z::Missing) = x + y
test(x, y, z::Float64) = x + y + z

Then in the signature of funvar! you could hard-wire the unused trailing argument with a dummy as an optional default:

struct Dummy end
Base.getindex(x::Dummy, i::Integer) = missing

function funvar!(a, f, b::Vector{Float64}, c::Vector{Float64}, d = Dummy())
    for i = 1:length(a)
        a[i] = f(b[i], c[i], d[i])
    end
end
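
which could then be called as (a sketch reusing the names above):

funvar!(a, test, b, c)     # d defaults to Dummy(), so the ::Missing method is called
funvar!(a, test, b, c, d)  # with d supplied, the ::Float64 method is called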

Thanks for the ideas. The number of varargs is always going to be pretty small, so would be easy to use some if-else branches. If I change funvar! to

function funvar!(a::Vector, f::F, V::Vararg{Vector{Float64},N}) where {F<:Function,N}
    T = collect(zip(V...))
    for i = 1:length(a)
        if N == 2
            a[i] = f(T[i][1], T[i][2])
        elseif N == 3
            a[i] = f(T[i][1], T[i][2], T[i][3])
        end
    end
end

the speed is almost equal:

  0.817415 seconds (30.76 k allocations: 154.274 MiB)
  0.823445 seconds (33.78 k allocations: 230.716 MiB, 3.47% gc time)
  0.766421 seconds (16.26 k allocations: 898.938 KiB)
  0.712799 seconds (16.73 k allocations: 919.835 KiB)

The f function is going to be called a huge number of times though, and I’m trying to avoid any overhead. Does this mean that something about splatting in f, which is passed to funvar!, is causing the slowdown?
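
For reference, my understanding is that splatting a Tuple of statically known length is free, while splatting a container whose length the compiler can't see has runtime cost. A toy sketch:

test3(x, y, z) = x + y + z

t = (1.0, 2.0, 3.0)   # Tuple: length and element types known at compile time
v = [1.0, 2.0, 3.0]   # Vector: length only known at runtime

test3(t...)  # lowers to a plain three-argument call
test3(v...)  # unpacks the arguments at runtime and can allocate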

@simeonschaub’s answer is just missing an explicit specialization on f (see Performance Tips · The Julia Language). Here’s a small change that gives roughly equal performance to fun2! and fun3!:

function funvar_new!(a::Vector, f::F, V::Vararg{Vector{Float64},N}) where {F <: Function, N}
    for i in 1:length(a)
        a[i] = f(map(v -> v[i], V)...)
    end
end

Profiling results (note that I decreased N and used @btime from BenchmarkTools.jl):

N = 1000
a = zeros(N)
b = rand(N)
c = rand(N)
d = rand(N)
julia> @btime funvar!($a, $test, $b, $c);
  118.359 μs (4493 allocations: 133.08 KiB)

julia> @btime funvar!($a, $test, $b, $c, $d);
  127.373 μs (5494 allocations: 164.28 KiB)

julia> @btime fun2!($a, $test, $b, $c);
  44.855 μs (0 allocations: 0 bytes)

julia> @btime fun3!($a, $test, $b, $c, $d);
  46.708 μs (0 allocations: 0 bytes)

julia> @btime funvar_new!($a, $test, $b, $c);
  45.288 μs (0 allocations: 0 bytes)

julia> @btime funvar_new!($a, $test, $b, $c, $d);
  47.951 μs (0 allocations: 0 bytes)


Ah, that’s right. I did try this with the collect(zip(V...)) though and it didn’t make a difference there, which seems odd to me.

For me, the explicit branching solution has the same performance (approximately) as separate methods:

function funvar4!(a, f, V::Vararg{Vector{Float64},N}) where {N}
    n = length(V)
    for i = 1:length(a)
        if n == 2
            a[i] = f(V[1][i], V[2][i])
        else
            a[i] = f(V[1][i], V[2][i], V[3][i])
        end
    end
end

# With rdeits's inputs:
@btime funvar4!($a, $test, $b, $c);
  18.427 μs (0 allocations: 0 bytes)

But I see the drawback of the dynamic dispatch. This is why I thought the dummy solution might be useful (but in the end it runs at the same speed as rdeits’ more elegant solution).

What I don’t understand about rdeits’ solution is why explicitly declaring f makes a difference (it doesn’t for my suggestions). My reading is that the method should specialize, given that f is used in the body of funvar_new!. Perhaps the catch is in the “should”?
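
The usual forced-specialization pattern, for reference (a minimal sketch, separate from this thread):

g(f, x) = map(f, x)                # f is only passed through, so Julia may
                                   # choose not to specialize on typeof(f)
g2(f::F, x) where {F} = map(f, x)  # ::F where {F} forces specialization on
                                   # the concrete type of f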


Ok, thank you to everyone for the help and suggestions. I’m really glad to be getting some feedback. I followed @rdeits’ example in the larger, non-example functions that I’m trying to generalize, and I’m still seeing performance issues with the vararg function. So, here is an expanded example to demonstrate. I know it’s getting long, but it’s as close to the real functions I’m using as I can make it without a bunch of extra junk.

test(Δx, a, b) = Δx + (a^2 + b^2)^(1/3)
test(Δx, a, b, c) = Δx + (a^2 + b^2 + c^2)^(1/4)

function funvar!(σ::AbstractVector{<:Real},
                 x::Vector{Float64},
                 xl::Vector{Float64},
                 f::F,
                 A::Vararg{Vector{Float64},N}) where {F <: Function, N}
    L = length(xl)
    jstart = 1
    for i = 1:length(x)
        j = jstart
        while (j <= L) && abs(x[i] - xl[j]) > 0.1
            j += 1
        end
        if j <= L
            jstart = j
            while (j <= L) && abs(x[i] - xl[j]) < 0.1
                Δx = x[i] - xl[j]
                #this is the only line that changes to accommodate varargs
                σ[i] = f(Δx, map(a->a[j], A)...)
                j += 1
            end
        end
    end
end

function fun2!(σ::AbstractVector{<:Real},
               x::Vector{Float64},
               xl::Vector{Float64},
               a::Vector{Float64},
               b::Vector{Float64})
    L = length(xl)
    jstart = 1
    for i = 1:length(x)
        j = jstart
        while (j <= L) && abs(x[i] - xl[j]) > 0.1
            j += 1
        end
        if j <= L
            jstart = j
            while (j <= L) && abs(x[i] - xl[j]) < 0.1
                Δx = x[i] - xl[j]
                #the first method of the "test" function
                σ[i] = test(Δx, a[j], b[j])
                j += 1
            end
        end
    end
end

function fun3!(σ::AbstractVector{<:Real},
               x::Vector{Float64},
               xl::Vector{Float64},
               a::Vector{Float64},
               b::Vector{Float64},
               c::Vector{Float64})
    L = length(xl)
    jstart = 1
    for i = 1:length(x)
        j = jstart
        while (j <= L) && abs(x[i] - xl[j]) > 0.1
            j += 1
        end
        if j <= L
            jstart = j
            while (j <= L) && abs(x[i] - xl[j]) < 0.1
                Δx = x[i] - xl[j]
                #the second method of the "test" function
                σ[i] = test(Δx, a[j], b[j], c[j])
                j += 1
            end
        end
    end
end

N = 100
σ = zeros(N)
x = collect(LinRange(0, 1, N))

N = 1000
xl = sort(rand(N))
a = rand(N)
b = rand(N)
c = rand(N)

@btime funvar!($σ, $x, $xl, $test, $a, $b)
@btime funvar!($σ, $x, $xl, $test, $a, $b, $c)
@btime fun2!($σ, $x, $xl, $a, $b)
@btime fun3!($σ, $x, $xl, $a, $b, $c)

  8.343 ms (222765 allocations: 3.40 MiB)
  9.185 ms (241628 allocations: 3.69 MiB)
  1.537 ms (0 allocations: 0 bytes)
  1.481 ms (0 allocations: 0 bytes)

The difference from the initial example is the while loops to prevent extra calls to test when values of x and xl are too far apart. But the only difference between funvar! and the other functions is that one line where test/f is called. I have the explicit specialization on f and no more Tuple or collect calls, so I don’t understand why funvar! is slower. I’m also not sure why funvar! is triggering memory allocation. Maybe that’s the only thing slowing it down?

The goal is to avoid having a bunch of functions with this nearly identical for-while-while loop combination inside them.

(also I’m using version 1.5.3, in case that matters)

For me, the dummy solution runs at the speed of fun2! and fun3!:

struct Dummy end
Base.getindex(x::Dummy, i::Integer) = missing

dtest(Δx, a, b, c :: Missing) = Δx + (a^2 + b^2)^(1/3)
dtest(Δx, a, b, c :: Float64) = Δx + (a^2 + b^2 + c^2)^(1/4)

function funvar2!(σ::AbstractVector{<:Real},
    x::Vector{Float64},
    xl::Vector{Float64},
    f::F,
    a, b, c) where {F <: Function}

    L = length(xl)
    jstart = 1
    for i = 1:length(x)
        j = jstart
        while (j <= L) && abs(x[i] - xl[j]) > 0.1
            j += 1
        end
        if j <= L
            jstart = j
            while (j <= L) && abs(x[i] - xl[j]) < 0.1
                Δx = x[i] - xl[j]
                #the dummy argument dispatches to the matching dtest method
                σ[i] = f(Δx, a[j], b[j], c[j])
                j += 1
            end
        end
    end
end

julia> d1 = Dummy()
Dummy()

julia> @btime funvar2!($σ, $x, $xl, $dtest, $a, $b, $d1)
  336.124 μs (0 allocations: 0 bytes)
julia> @btime funvar2!($σ, $x, $xl, $dtest, $a, $b, $c)
  354.818 μs (0 allocations: 0 bytes)
julia> @btime fun3!($σ, $x, $xl, $a, $b, $c)
  349.150 μs (0 allocations: 0 bytes)

It also looks like the clearest solution to me. But then I’m not the expert here.

:+1: This seems like what I’ll do if there isn’t a really clean solution.

It seems like using A::Vararg{Vector{Float64},2} and A::Vararg{Vector{Float64},3} will also be fast. I’d basically have two versions of the same code, but that’s better than 5 or 6 copies.
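
i.e., something like this (a sketch of the two fixed-arity methods):

function funvar!(a::Vector, f::F, V::Vararg{Vector{Float64},2}) where {F<:Function}
    for i in eachindex(a)
        a[i] = f(V[1][i], V[2][i])
    end
end

function funvar!(a::Vector, f::F, V::Vararg{Vector{Float64},3}) where {F<:Function}
    for i in eachindex(a)
        a[i] = f(V[1][i], V[2][i], V[3][i])
    end
end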

I don’t understand the issue though. It must be splatting that causes the slowdown?

Don’t give up so soon! :wink:

The following change makes funvar! just as fast again:

                args = let j = j
                    ntuple(i -> A[i][j], N)
                end
                σ[i] = f(Δx, args...)

in context:

function funvar!(σ::AbstractVector{<:Real},
                 x::Vector{Float64},
                 xl::Vector{Float64},
                 f::F,
                 A::Vararg{Vector{Float64},N}) where {F <: Function, N}
    L = length(xl)
    jstart = 1
    for i = 1:length(x)
        j = jstart
        while (j <= L) && abs(x[i] - xl[j]) > 0.1
            j += 1
        end
        if j <= L
            jstart = j
            while (j <= L) && abs(x[i] - xl[j]) < 0.1
                Δx = x[i] - xl[j]
                args = let j = j
                    ntuple(i -> A[i][j], N)
                end
                σ[i] = f(Δx, args...)
                j += 1
            end
        end
    end
end

The let j = j thing is necessary to avoid the weird performance issue with boxed variables in closures (Performance Tips · The Julia Language).

Results:

julia> @btime funvar!($σ, $x, $xl, $test, $a, $b)
  890.985 μs (0 allocations: 0 bytes)

julia> @btime funvar!($σ, $x, $xl, $test, $a, $b, $c)
  938.801 μs (0 allocations: 0 bytes)

julia> @btime fun2!($σ, $x, $xl, $a, $b)
  851.805 μs (0 allocations: 0 bytes)

julia> @btime fun3!($σ, $x, $xl, $a, $b, $c)
  911.132 μs (0 allocations: 0 bytes)
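
To see the boxing issue in isolation (a toy sketch, separate from the benchmark above):

function boxed()
    j = 1
    g = () -> j    # the closure captures j
    j = 2          # j is reassigned after being captured, so Julia boxes it
    return g()     # @code_warntype shows j::Core.Box in here
end

function unboxed()
    j = 1
    j = 2
    g = let j = j  # let introduces a fresh binding that is never reassigned
        () -> j
    end
    return g()     # fully inferred; no Core.Box
end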

That does it! I’m still trying to understand the let block and variable boxes, but the goal is met.

Seems like I can also use

let k = j
    σ[i] = f(Δx, map(a->a[k], A)...)
end

without any speed loss.
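
Presumably that works for the same reason: k is bound exactly once inside the let, so the closure passed to map captures a variable that never gets reassigned and doesn’t need to be boxed.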