Type stability with variable arguments

Is there a way to maintain type stability with variable arguments? Below is some simplified code illustrating the issue. I’d like f(x,y,z) to perform like g(x,y,z):

function f( args... )
    s = 0.0
    for arg in args
        for i = 1:length(arg)
            s += arg[i]
        end
    end
    return s
end

function g( x::Vector{Float64}, y::Vector{Float64}, z::Vector{Int} )
    s = 0.0
    for arg in (x,y,z)
        for i = 1:length(arg)
            s += arg[i]
        end
    end
    return s
end

x = randn(10_000_000);
y = randn(10_000_000);
z = rand( 1:10, 10_000_000);

@time f( x,y )
@time f( x,y,z )
@time g( x,y,z )

This might not be possible in your use case, but moving the inner loop into a separate function (a "function barrier") lets the compiler optimize your code. For example:

function new_sum(s, arg)
    for i in 1:length(arg)
        s += arg[i]
    end
    return s
end

function h(args...)
    s = 0.0
    for arg in args
        s = new_sum(s, arg)
    end
    return s
end

This will be just as fast as your g function (on my computer, slightly faster).

Not sure if I understood, but the source of type instability is not that a function has many distinct methods. Instead, the problem arises inside method bodies in which a given combination of input types does not map to a single combination of output types.
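
To make the distinction concrete, here is a minimal sketch with two hypothetical functions (unstable and stable are illustrative names, not from this thread). Both take an Int, but in unstable the type of the result depends on the value of x, so inference can only conclude a Union of Int and Float64; stable always returns a Float64. Base.return_types, a non-exported reflection function in Base, reports what inference concluded for the given argument types.

```julia
# Hypothetical functions: both accept an Int, only one has a stable output type.

# Type-unstable: returns x itself (an Int) when x > 0, otherwise 0.0 (a Float64),
# so the output type depends on the *value* of x, not only on its type.
unstable(x::Int) = x > 0 ? x : 0.0

# Type-stable: both branches produce a Float64.
stable(x::Int) = x > 0 ? float(x) : 0.0

# What inference concludes for an Int argument:
println(only(Base.return_types(unstable, (Int,))))  # a Union of Int64 and Float64
println(only(Base.return_types(stable, (Int,))))    # Float64
```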

I don’t know how general the solution is supposed to be, but this version of the algorithm works pretty well:

f2(args...) = sum(sum(arg) for arg in args)

In general, putting inner calculations into a different function (here sum from Base) helps a lot with optimisation.

Thanks. That does solve the simplified problem, but I don’t see how to apply it to my real use cases. One of them is sorting on multiple large vectors using lexicographic ordering, as illustrated below. Do you see something similar here? I feel like I somehow need to unroll the loop so that Julia can assign a stable type to the loop variable.

function lexicographic( vs... )
    function lt( x::Int, y::Int )
        for v in vs
            if v[x] < v[y]
                return true
            elseif v[x] > v[y]
                return false
            end
        end
        return false
    end
end

n = 2_000_000;
x = rand([:a,:b],n);
y = rand(1:10, n);
z = randn(n);

indices = collect(1:n);
@time sort!( indices, lt=lexicographic( x, y, z ) );

This example is considerably different from the previous ones: you are returning a closure. Closures in Julia may have performance problems because they are created at a point in time at which the arguments (the objects passed as parameters) are not yet available, so they cannot specialize on the arguments of the function that returns the closure. The types of the arguments of lexicographic would need to be inferred before the function is called, which is not possible with a Vararg parameter.

I am not sure variable arguments are a factor here. I experimented with

function lexicographic2(a,b,c )
    function lt( x::Int, y::Int )
        for v in (a,b,c)
            if v[x] < v[y]
                return true
            elseif v[x] > v[y]
                return false
            end
        end
        return false
    end
end

and @time sort!( indices, lt=lexicographic2( x, y, z ) ); gives the same timing as the original function (Julia 1.7).

For this case, using by seems to be much better than using lt:

julia> let
           n = 200_000
           x = rand([:a,:b], n)
           y = rand(1:10, n)
           z = randn(n)

           indices = collect(1:n)
           @time sol1 = sort(indices, lt=lexicographic(x, y, z))
           @time sol2 = sort(indices, by = i -> (x[i], y[i], z[i]))
           sol1 == sol2
       end
1.904496 seconds (46.49 M allocations: 711.346 MiB, 4.01% gc time, 0.46% compilation time)
0.214283 seconds (97.60 k allocations: 6.968 MiB, 58.60% compilation time)
true

Ordering for tuples is already lexicographic as noted by @tomerarnon, so you can just map into a tuple:

lexicographic(x...) = i -> getindex.(x, i)

and then call sort(..., by=lexicographic(args...)).
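
A small usage sketch of that definition on made-up data (the parameter is renamed to vs here only to avoid clashing with the vector x). The key function maps an index to a tuple, and tuples compare lexicographically, so sorting the indices by that key orders them by x, then y, then z.

```julia
# Key function from the post above: index i -> (x[i], y[i], z[i]).
# Broadcasting getindex over a tuple of vectors returns a tuple.
lexicographic(vs...) = i -> getindex.(vs, i)

# Made-up example data.
x = [:b, :a, :a]
y = [1, 2, 1]
z = [0.5, 0.1, 0.2]

key = lexicographic(x, y, z)
key(2)                                 # (:a, 2, 0.1)

# collect is needed: sort on a range does not accept the by keyword.
sorted = sort(collect(1:3), by = key)  # [3, 2, 1]
```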

In general, it is instructive to see how Base implements functions acting on variable-length arguments or variable-length tuples in a type-stable way. The key thing is that, because the compiler knows the length of a tuple, if you write things in the correct way then it can completely unroll loops as if you had written it out by hand.

For example, here is how isless is implemented for tuples: base/tuple.jl:isless. Note that it is recursive, but the compiler can “unroll” the recursion completely (if it is not too deep). map is implemented similarly, as is broadcasting on tuples, which was exploited above in the getindex.(...) call. Another useful pattern is the ntuple function with a Val{N} argument, where N comes from a type parameter (e.g. a Vararg parameter for variable arguments); in that case the compiler can completely unroll the loop. A simple example of this can be found in the size(a::Array) function, and much more extensive use can be found in the multidimensional reverse! implementation.
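
Both patterns can be sketched in a few lines. The names persum and tuplesum are hypothetical, chosen for illustration: persum uses ntuple with Val(N), where N is the length of the Vararg, and tuplesum uses the peel-off-the-first-element recursion (first plus Base.tail) in the style of the tuple methods in Base. In both cases the number of steps is fixed by the tuple type, which is what makes unrolling possible.

```julia
# Hypothetical example of the ntuple + Val pattern. N is the number of
# arguments, taken from the Vararg type parameter, so it is known at compile
# time and ntuple(f, Val(N)) can be expanded into (f(1), ..., f(N)).
persum(args::Vararg{Any,N}) where {N} = ntuple(i -> sum(args[i]), Val(N))

# Hypothetical example of the recursive pattern used by Base's tuple methods:
# handle the first element, recurse on the rest, stop at the empty tuple.
tuplesum(::Tuple{}) = 0.0
tuplesum(t::Tuple) = sum(first(t)) + tuplesum(Base.tail(t))

persum([1, 2], [1.5, 2.5])      # (3, 4.0)
tuplesum(([1, 2], [1.5, 2.5]))  # 7.0
```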

In your new example, the types of the lexicographic2 parameters are also not known at parse time, only at call time. However, this does not seem to matter as much as I thought. For me, in global scope, both lexicographic2 and a new version with declared parameter types have the same performance (about 11.5 s on my machine), while the original lexicographic takes about 14.5 s (I am using Julia 1.5.4). I really expected more of the overhead to come from closures. However, using let to create a local scope brings all the times very close to 1 s, and over multiple runs each of the three implementations sometimes comes in last (so there is no significant difference).

Thanks for the response.

You do need to reinitialize indices after sorting the first time, because I’ve noticed that Julia’s sort is faster on already sorted vectors. I think that explains the difference in speed you’re seeing. Notice that you have tens of millions of allocations for vectors of length only 200,000; I assume this is due to type instability. stevengj’s approach does eliminate the allocations, though I’m not 100% sure of the reason for the difference.
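
One way to make sure every timing starts from the same unsorted input is to copy it before each trial. The helper below, time_fresh, is hypothetical (not part of Base); it uses only copy and @elapsed, and reports the fastest trial so that the first, compilation-heavy run does not dominate.

```julia
# Hypothetical helper (not part of Base): time f! on a fresh copy of `input`
# in every trial, so no trial ever sees data already sorted by an earlier one.
# Returns the fastest trial in seconds; the first trial includes compilation.
function time_fresh(f!, input; trials::Int = 5)
    best = Inf
    for _ in 1:trials
        data = copy(input)         # the copy is made outside the timed region
        t = @elapsed f!(data)
        best = min(best, t)
    end
    return best
end

# Example usage inside a local scope (let), with made-up data:
let n = 10_000
    x = rand([:a, :b], n); y = rand(1:10, n); z = randn(n)
    indices = collect(1:n)
    t = time_fresh(ind -> sort!(ind, by = i -> (x[i], y[i], z[i])), indices)
    println("fastest of 5 trials: ", t, " s")
end
```

If BenchmarkTools is available, passing a setup=(ind = copy(indices)) expression together with evals=1 to @btime gives each sample its own fresh copy in the same way.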

That’s not how type inference works in Julia. The key thing is what happens at compile-time, not at parse time (the parser only knows how things are spelled).

It’s totally fine to declare lexicographic2(a,b,c) with no types for the arguments: the compiler knows the argument types where the function is called, and (assuming the call site is type-stable/grounded/inferable) it compiles a specialized version of lexicographic2 (and hence a specialized version of the returned function lt) for those argument types.

The problem with lexicographic2 is that the local variable v is type-unstable — it takes on 3 different types because (a,b,c) has 3 different types. If instead you write out the loop:

function lexicographic3(a,b,c )
    function lt( x::Int, y::Int )
        if a[x] < a[y]
            return true
        elseif a[x] > a[y]
            return false
        elseif b[x] < b[y]
            return true
        elseif b[x] > b[y]
            return false
        elseif c[x] < c[y]
            return true
        elseif c[x] > c[y]
            return false
        else
            return false
        end
    end
    return lt
end

then the type-instability of v goes away and the performance is good:

julia> @btime sort!( $indices, lt=$(lexicographic3( x, y, z )));
48.989 ms (0 allocations: 0 bytes)

(Note the lack of allocations — if you have any allocations here, that is indicative of a type-instability.)

The trick is to get the compiler to unroll the loop for you, for any number of arguments. I usually use map or ntuple for this, since they are specialized for tuples, and it is fine to construct a tuple result that is then discarded (the compiler will eliminate the tuple allocation). I’d rather use foreach, but that is currently not specialized for tuples (julia#31901):

function lexicographic4(args...) # no type declarations needed here: the compiler will figure it out
    function lt(x, y) # similarly, no type-declarations needed here either
        map(args) do v  # the compiler will completely unroll a map of a tuple
            if v[x] < v[y]
                return true
            elseif v[x] > v[y]
                return false
            end
        end # map result (a tuple) is discarded … compiler will elide it
        return false
    end
    return lt
end

which gives:

julia> @btime sort!( $indices, lt=$(lexicographic4( x, y, z )));
44.779 ms (0 allocations: 0 bytes)

as desired.

Yes, but since the main concern here is the type-instability, it doesn’t really matter. The thing you’re looking for is that sort! should report 0 allocations, and should perform similarly to a hand-unrolled comparison.

I was talking about the specific problem with closures, not type-instability in general.

There’s no problem in this case with closures, as my lexicographic3 and lexicographic4 examples demonstrate.

Maybe I’m looking at this really wrong, but lexicographic4 seems to behave differently. In the other versions, all the returns belong to the nested function lt. In lexicographic4, however, the first couple of returns would belong to the do-block function, right? So lexicographic4's lt must map that function over every array in args/(a, b, c)/vs and finally return false, whereas the other versions of lt may return true or false, likely without checking all the arrays.

Whoops, right. Probably you need a mapreduce-like call here. Or implement it recursively like isless for tuples. In this particular case it’s a lot easier to just map to a tuple and use sort!(..., by=...) as in this post.
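
A sketch of the recursive variant suggested here, modelled on how Base implements isless for tuples; the names lexless and lexicographic5 are hypothetical. Each call compares on the first vector and only recurses on Base.tail of the tuple when the two elements are equal, so it keeps the early exit of the original loop while every level works with a single concrete vector type.

```julia
# Hypothetical names (lexless, lexicographic5), recursive approach in the
# style of Base's isless for tuples.

# Base case: every vector compared equal at x and y, so x is not less than y.
lexless(x, y, ::Tuple{}) = false

# Compare on the first vector; only if the elements are equal, recurse on the
# remaining vectors. The tuple length is part of its type, so the compiler
# can usually inline the whole chain (if it is not too deep).
function lexless(x, y, vs::Tuple)
    v = first(vs)
    v[x] < v[y] && return true
    v[x] > v[y] && return false
    return lexless(x, y, Base.tail(vs))
end

function lexicographic5(vs...)
    lt(x, y) = lexless(x, y, vs)
    return lt
end

# Made-up example data:
x = [:b, :a, :a]
y = [1, 2, 1]
z = [0.5, 0.1, 0.2]
sort(collect(1:3), lt = lexicographic5(x, y, z))  # [3, 2, 1]
```

Whether the recursion is fully inlined depends on the compiler's heuristics, so it is still worth benchmarking on the real data.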

The basic point remains that you need to get the compiler to unroll the loop for iteration over a heterogeneous tuple to be fast.
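
The map-based idea can also be repaired without recursion, in the spirit of the mapreduce-like suggestion: map each vector to a comparison result with cmp (which returns -1, 0 or 1), giving a homogeneous tuple of Ints, and then loop over that tuple. The name lexicographic_cmp is hypothetical. Unlike lexicographic3, this sketch always evaluates every comparison instead of stopping at the first difference, a trade-off that may or may not matter for a given number of vectors.

```julia
# Hypothetical name lexicographic_cmp. The map over the heterogeneous tuple of
# vectors is unrolled; its result is a homogeneous tuple of Ints, so the loop
# over that result has a single, stable element type.
function lexicographic_cmp(vs...)
    function lt(x, y)
        cs = map(v -> cmp(v[x], v[y]), vs)  # each entry is -1, 0 or 1
        for c in cs
            c != 0 && return c < 0           # first non-tie decides (returns from lt)
        end
        return false                         # all equal: x is not less than y
    end
    return lt
end

# Made-up example data:
let x = [:b, :a, :a], y = [1, 2, 1], z = [0.5, 0.1, 0.2]
    println(sort(collect(1:3), lt = lexicographic_cmp(x, y, z)))  # [3, 2, 1]
end
```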