In retrospect, this is a fairly obvious transformation that I had forgotten about. You can treat falses as zeros and trues as ones (or vice versa) and solve the problem by sorting the array by the associated value. I had also forgotten that mergesort does not need a temporary vector.
However, your solution is the one that takes the most time and memory, by at least an order of magnitude compared with the others. I ran the code below, and this is what I got:
# one-liner that creates two separate vectors using two filters
7.490 μs (19 allocations: 16.69 KiB)
# first post code, uses aux vector for permutation
3.255 μs (3 allocations: 15.91 KiB)
# first post code, but I take the permutation vector as parameter
# and discovered that permute! makes a copy of the permutation
# vector so I use Base.permute!! instead
2.385 μs (1 allocation: 32 bytes)
# my code that does not preserve order and returns views
507.447 ns (3 allocations: 128 bytes)
# your code based on mergesort
49.923 μs (4995 allocations: 234.14 KiB)
The problem with your code is that it recurses while allocating views, and your base case is a single element. So in the end you make about `2*length(arr)` calls and allocate considerably more than `2*length(arr)` views (because there are not only the views for the recursive calls but also the ones for `reverse!`). Each view is allocated on the heap and takes more memory than a `Float64`, which is the `eltype` of the array I used for testing, so this allocates far more than an auxiliary vector, which is what @louisponet wanted to avoid. Maybe a pair of `Int64` indices could be passed as parameters to the recursive calls and to `reverse!` instead; this way no heap allocation happens, and at most about `log2(length(arr))` tuples of parameters will be on the stack at any time.
Also, @louisponet, from what I have seen of the `permute!` implementation (https://github.com/JuliaLang/julia/blob/ac8c5b70514fc0bd743f4e5df6a13dccdaccddb9/base/combinatorics.jl#L97), it seems to enforce one-based indexing, so unfortunately your effort to make the code index-agnostic is not sufficient.
The code used for the tests is below:
using BenchmarkTools
"""
    sep(f, A)

Return the tuple `(accepted, rejected)`: two newly allocated vectors holding the
elements of `A` for which `f` returns `true` and `false`, respectively.

Each output is produced by its own `filter` pass, so `f` is evaluated twice per
element. `A` is not modified.
"""
function sep(f::Function, A::AbstractVector)
    accepted = filter(f, A)
    rejected = filter(!f, A)
    return accepted, rejected
end
"""
    separate(f, A)

Reorder `A` in place so that every element for which `f` returns `true` comes
before every element for which it returns `false`, keeping the relative order
inside each group.

Returns `(A, k)`, where `k` is one more than the number of `true` elements, i.e.
the position at which the `false` group starts (`length(A) + 1` when there is none).

An index vector of length `length(A)` is allocated to hold the permutation.
The positions come from `enumerate`, which counts from 1, so this does not work
for vectors whose indices do not start at 1.
"""
function separate(f::Function, A::AbstractVector)
    n = length(A)
    perm = Vector{Int}(undef, n)
    front = 1  # next free slot for the position of a `true` element (fills rightwards)
    back = n   # next free slot for the position of a `false` element (fills leftwards)
    for (pos, elem) in enumerate(A)
        if f(elem)
            perm[front] = pos
            front += 1
        else
            perm[back] = pos
            back -= 1
        end
    end
    # The `false` positions were written back to front; flip that tail so the
    # original relative order is restored.
    reverse!(perm, front, n)
    permuted = permute!(A, perm)
    return permuted, front
end
# Apply permutation `p` to `a` in place (`a[k]` becomes the old `a[p[k]]`) by
# following its cycles. `p` doubles as the "already moved" marker: every visited
# entry is overwritten with 0, so `p` is destroyed. Requires one-based `a` and `p`.
function _permute_with_scratch!(a, p)
    n = length(a)
    moved = 0
    start = 0
    while moved < n
        # Next entry that has not been consumed yet starts a new cycle.
        start = findnext(!iszero, p, start + 1)::Int
        slot = start
        saved = a[start]
        next = p[start]
        moved += 1
        while next != start
            a[slot] = a[next]
            p[slot] = 0
            slot = next
            next = p[next]
            moved += 1
        end
        a[slot] = saved
        p[slot] = 0
    end
    return a
end

"""
    separate2(f, A, B)

Reorder `A` in place so that every element for which `f` returns `true` comes
before every element for which it returns `false`, keeping the relative order
inside each group.

`B` is caller-provided scratch space for the permutation. It must be a one-based
integer vector with at least `length(A)` entries; its first `length(A)` entries
are overwritten (they are left as zeros on return).

Returns `(A, n_true)`, where `n_true` is the number of elements satisfying `f`.

The positions come from `enumerate`, which counts from 1, so this does not work
for vectors whose indices do not start at 1.
"""
function separate2(f::Function, A::AbstractVector, B::AbstractVector)
    n = length(A)
    i, j = 0, n + 1
    for (k, a) in enumerate(A)
        if f(a)
            B[i += 1] = k   # `true` positions fill B[1:i] in order
        else
            B[j -= 1] = k   # `false` positions fill B[i+1:n] back to front
        end
    end
    # Only the `false` tail is stored in reverse, and it starts at i + 1.
    # Starting at i would drag the last `true` position into the reversal and
    # would be out of bounds when no element satisfies `f`.
    reverse!(B, i + 1, n)
    return _permute_with_scratch!(A, B), i
end
"""
    sep3(f, A)

Partition `A` **in place** so that the elements for which `f` returns `true`
come first, and return `(trues, falses)` as two views into `A`.

The relative order of the elements is *not* preserved: misplaced elements are
swapped pairwise from both ends. No copy of the data is made, so the returned
views alias `A`.

Bounds are taken from `firstindex(A)`/`lastindex(A)`, so the indices used under
`@inbounds` are valid for vectors whose indices do not start at 1.
"""
function sep3(f::Function, A::AbstractVector)
    lo, hi = firstindex(A), lastindex(A)
    i, j = lo, hi
    @inbounds while i < j
        left_true = f(A[i])
        right_true = f(A[j])
        if left_true && right_true
            # A[j] is misplaced: advance i to the first `false` to trade with.
            i += 1
            while i < j && f(A[i])
                i += 1
            end
            # The scan stops at i == j only if everything up to j is `true`.
            i == j && return view(A, lo:j), view(A, (j + 1):hi)
            A[i], A[j] = A[j], A[i]
        elseif !left_true && !right_true
            # A[i] is misplaced: retreat j to the last `true` to trade with.
            j -= 1
            while j > i && !f(A[j])
                j -= 1
            end
            # The scan stops at j == i only if everything from i on is `false`.
            i == j && return view(A, lo:(i - 1)), view(A, i:hi)
            A[i], A[j] = A[j], A[i]
        elseif right_true
            # `false` on the left, `true` on the right: both misplaced, trade them.
            A[i], A[j] = A[j], A[i]
        end
        # Here A[i] is `true` and A[j] is `false`; both are settled.
        i += 1
        j -= 1
    end
    # Either the cursors crossed (i == j + 1, split right before i) or they met on
    # a single unclassified element, which decides which side it belongs to.
    last_true = (i == j && f(A[i])) ? i : i - 1
    return view(A, lo:last_true), view(A, (last_true + 1):hi)
end
"""
    partition!(pred, arr)

Reorder `arr` in place so that the elements satisfying `pred` come first, keeping
the relative order inside both groups, and return how many elements satisfy `pred`.

Works by divide and conquer: each half is partitioned recursively through a view,
then the block of `false`s ending the lower half is exchanged with the block of
`true`s starting the upper half by means of three reversals.
"""
function partition!(pred, arr)
    n = length(arr)
    n === 0 && return 0
    if n === 1
        # Base case: a single element is trivially partitioned.
        return pred(arr[1]) ? 1 : 0
    end

    # Partition each half on its own.
    half = n ÷ 2
    trues_lo = partition!(pred, view(arr, 1:half))
    trues_hi = partition!(pred, view(arr, (half + 1):lastindex(arr)))

    # arr now reads [T_lo | F_lo | T_hi | F_hi]. Exchange the two middle blocks:
    # reverse each of them, then reverse their concatenation.
    falses_from = trues_lo + 1
    trues_upto = half + trues_hi
    reverse!(view(arr, falses_from:half))
    reverse!(view(arr, (half + 1):trues_upto))
    reverse!(view(arr, falses_from:trues_upto))

    return trues_lo + trues_hi
end
"""
    run_bench(N = 10^3)

Benchmark every partitioning routine on the same `N` random `Float64` values with
the predicate `x -> x > 0.5`, printing one `@btime` line per routine.

The in-place routines reorder their input, so each measurement runs on a fresh
copy of the data (`setup` together with `evals=1`). Without this, every evaluation
after the first would time an input that is already partitioned.
"""
function run_bench(N::Integer = 10^3)
    b = rand(N)
    pred = x -> x > 0.5
    # Scratch permutation buffer for `separate2`; it is refilled on every call.
    B = Vector{Int64}(undef, N)

    # `sep` does not modify its input, so the source data can be used directly.
    @btime sep($pred, $b)
    @btime separate($pred, a) setup=(a = copy($b)) evals=1
    @btime separate2($pred, a, $B) setup=(a = copy($b)) evals=1
    @btime sep3($pred, a) setup=(a = copy($b)) evals=1
    @btime partition!($pred, a) setup=(a = copy($b)) evals=1
    return nothing
end
run_bench()