Hey there!
I come to you with another performance nerdsnipe, whose solution could have a big impact on the autodiff ecosystem (via SparseConnectivityTracer.jl and DifferentiationInterface.jl).
You may remember my first question on pseudo-set data structures for fast unions:
One interesting proposal from that discussion (thanks @Tarny_GG_Channie) is a lazy set union, containing only pointers to the two sets it unites. I want to see how fast we can make this approach.
The code below is a simple version, and here are some of my performance questions:
1. Which format should I choose for `RecursiveSet`?
    - an immutable struct where the children are `Union{Nothing,RecursiveSet}`
    - a self-referential mutable struct where the children are always `RecursiveSet`, possibly equal to the parent
2. I don’t necessarily need to `collect` at the end, only to iterate. Would that be faster?
3. Is there any point in using AbstractTrees.jl?
Implementation:
```julia
struct RecursiveSet{T<:Number}
    s::Union{Nothing,Set{T}}                 # actual Set for leaves, `nothing` for unions
    child1::Union{Nothing,RecursiveSet{T}}   # `nothing` for leaves
    child2::Union{Nothing,RecursiveSet{T}}
    function RecursiveSet(s::Set{T}) where {T}
        return new{T}(s, nothing, nothing)
    end
    function RecursiveSet(rs1::RecursiveSet{T}, rs2::RecursiveSet{T}) where {T}
        return new{T}(nothing, rs1, rs2)
    end
end
Base.eltype(::Type{RecursiveSet{T}}) where {T} = T
# Lazy union: O(1), only stores pointers to both operands
function Base.union(rs1::RecursiveSet{T}, rs2::RecursiveSet{T}) where {T}
    return RecursiveSet(rs1, rs2)
end
# Materialize: walk the tree and accumulate every leaf Set into a single Set
function Base.collect(rs::RecursiveSet{T}) where {T}
    accumulator = Set{T}()
    collect_aux!(accumulator, rs)
    return collect(accumulator)
end
function collect_aux!(accumulator::Set{T}, rs::RecursiveSet{T}) where {T}
    if !isnothing(rs.s)
        union!(accumulator, rs.s::Set{T})
    else
        collect_aux!(accumulator, rs.child1::RecursiveSet{T})
        collect_aux!(accumulator, rs.child2::RecursiveSet{T})
    end
    return nothing
end
```
Benchmark:
```julia
using Chairmarks
init(nb_sets) = [RecursiveSet(Set(i)) for i in 1:nb_sets]
# `_` stands for the result of the setup expression init(1000)
@be init(1000) collect(reduce(union, _))
```
Will you not statically know how many different sets might be in your lazy union? If you are able to know it statically, I’d suggest just storing a tuple, i.e.
```julia
struct LazyUnion{T<:Number, N}
    sets::NTuple{N, Set{T}}
end
```
but it depends a lot on the use-case.
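As a sketch of how such a tuple-backed union could be consumed, the hypothetical `materialize` helper below builds the eager union only when it is actually needed. It assumes the `LazyUnion` definition above (repeated so the snippet runs on its own); the tuple length `N` is part of the type, so the compiler knows statically how many sets there are.

```julia
struct LazyUnion{T<:Number, N}
    sets::NTuple{N, Set{T}}
end

# Hypothetical helper: materialize the distinct elements of a tuple-backed lazy union
function materialize(lu::LazyUnion{T}) where {T}
    acc = Set{T}()
    for s in lu.sets      # loop over the N stored sets; the inputs are only read
        union!(acc, s)
    end
    return acc
end

lu = LazyUnion((Set([1, 2]), Set([2, 3]), Set([5])))
materialize(lu)   # == Set([1, 2, 3, 5])
```

The catch, as noted, is that `N` must be known when the value is built, so a chain of unions would produce a new tuple type at every step.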
I don’t think that’s what you want to benchmark here, since that’s mostly just going to end up measuring the time that union! takes, and if that’s what you want then there’s no point in using the lazy union.
My understanding here is that you have a problem where you need to construct a set union, and then do some operations with that union. If you take a lazy union, then the construction will be basically free, but the operations (i.e. checking if an item is a member of the set, and inserting new items) will be slower than if you did the union eagerly.
So what you need to do is figure out, for the sorts of workloads you’re interested in, whether the time you spend constructing an eager union is so large that it outweighs the more costly operations of a lazy union.
Can you expand on why you would assume self-referential is better?
On the one hand, child access naturally becomes type-stable, but on the other hand it is now a mutable struct instead of an immutable one.
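To make the second option concrete, here is a minimal sketch of a self-referential variant, assuming a leaf is marked by pointing its children at itself. The names `SelfRefSet`, `isleaf` and `collect_selfref!` are hypothetical. Both child fields are concretely typed, so no `Union{Nothing,…}` assertion is needed when descending; the payload still needs a small `Union`. Whether this beats the immutable version is something only a benchmark can tell.

```julia
# Hypothetical variant of RecursiveSet: a leaf points to itself instead of to `nothing`
mutable struct SelfRefSet{T<:Number}
    s::Union{Nothing,Set{T}}      # payload, only used by leaves
    child1::SelfRefSet{T}         # concretely typed children
    child2::SelfRefSet{T}
    function SelfRefSet(s::Set{T}) where {T}
        node = new{T}(s)          # child fields start undefined...
        node.child1 = node        # ...and are then pointed at the node itself
        node.child2 = node
        return node
    end
    SelfRefSet(a::SelfRefSet{T}, b::SelfRefSet{T}) where {T} = new{T}(nothing, a, b)
end

isleaf(node::SelfRefSet) = node.child1 === node

# Accumulate the elements of every leaf below `node` into `acc`
function collect_selfref!(acc::Set{T}, node::SelfRefSet{T}) where {T}
    if isleaf(node)
        union!(acc, node.s::Set{T})
    else
        collect_selfref!(acc, node.child1)   # no type assertion needed on the children
        collect_selfref!(acc, node.child2)
    end
    return acc
end

a = SelfRefSet(Set([1, 2])); b = SelfRefSet(Set([2, 3]))
c = SelfRefSet(a, b)
collect_selfref!(Set{Int}(), c)   # == Set([1, 2, 3])
```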
It’s always gonna be exactly two (the two children), unless the RecursiveSet represents an actual Set, in which case the children are meaningless.
But of course the recursive nature means that in the end, a RecursiveSet might be a union over any number of actual Sets.
Maybe I should have a different type for the actual Sets (leaves of the tree) than for the RecursiveSets (internal nodes)?
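One way to try the leaf/internal-node split is sketched below, assuming two hypothetical types `Leaf` and `Node` (not from any package). Internal nodes store their children as a small `Union{Leaf{T},Node{T}}`, which the Julia compiler can usually union-split instead of falling back to dynamic dispatch.

```julia
# Hypothetical split: leaves hold actual Sets, internal nodes only hold children
struct Leaf{T}
    s::Set{T}
end

struct Node{T}
    child1::Union{Leaf{T},Node{T}}   # small Union: a candidate for union-splitting
    child2::Union{Leaf{T},Node{T}}
end

# Lazy union of any two trees
Base.union(a::Union{Leaf{T},Node{T}}, b::Union{Leaf{T},Node{T}}) where {T} = Node{T}(a, b)

# Recursively add every leaf's elements to `acc`
function collect_leaves!(acc::Set{T}, x::Union{Leaf{T},Node{T}}) where {T}
    if x isa Leaf
        union!(acc, x.s)
    else
        collect_leaves!(acc, x.child1)
        collect_leaves!(acc, x.child2)
    end
    return acc
end

x = reduce(union, [Leaf(Set([i, i + 1])) for i in 1:4])
collect_leaves!(Set{Int}(), x)   # == Set(1:5)
```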
It is interesting to benchmark because I’m only materializing the union! at the very end, and not for every intermediate variable. In this example, if reduce proceeds in the normal order, we would have $n$ intermediate computations: $x_1$, $x_1 \cup x_2$, $(x_1 \cup x_2) \cup x_3$, etc. The lazy approach only materializes $x_1 \cup \dots \cup x_n$, although I could make it more efficient by iterating instead of collecting.
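To put rough numbers on this, assume singleton inputs $x_i = \{i\}$ and a left-to-right fold. Julia’s non-mutating `union` of two `Set`s builds a fresh `Set` holding both, so the step producing $x_1 \cup \dots \cup x_k$ inserts about $k$ elements, whereas the lazy version allocates one node per step and touches each element once at the end:

$$
\text{eager: } \sum_{k=2}^{n} k = \frac{n(n+1)}{2} - 1 = O(n^2) \text{ insertions}, \qquad
\text{lazy: } (n-1) \text{ nodes} + n \text{ insertions} = O(n).
$$

For $n = 1000$ that is roughly $5 \times 10^5$ insertions for the eager fold versus about $10^3$ for the lazy one.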
The only operation I’m interested in at the end is iterating over the non-duplicated elements of my lazy set. Nothing else matters to me, not checking or inserting. So you’re right, union should be basically free but it’s the collect / iterate bit that I want to make fast. Does that make sense?
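If iteration is all that matters, one option is to visit each distinct element through a callback instead of building the output `Vector`. The sketch below assumes the `RecursiveSet` type from the first post (repeated so it runs standalone) and a hypothetical `foreach_unique` helper. It uses an explicit stack instead of recursion, but it still needs a `seen` set to drop duplicates, so the saving over `collect` is probably modest.

```julia
struct RecursiveSet{T<:Number}
    s::Union{Nothing,Set{T}}
    child1::Union{Nothing,RecursiveSet{T}}
    child2::Union{Nothing,RecursiveSet{T}}
    RecursiveSet(s::Set{T}) where {T} = new{T}(s, nothing, nothing)
    RecursiveSet(rs1::RecursiveSet{T}, rs2::RecursiveSet{T}) where {T} = new{T}(nothing, rs1, rs2)
end

# Hypothetical helper: call `f` once per distinct element, without building an output Vector
function foreach_unique(f, rs::RecursiveSet{T}) where {T}
    seen = Set{T}()                    # still needed to skip duplicates
    stack = RecursiveSet{T}[rs]        # explicit stack instead of recursion
    while !isempty(stack)
        node = pop!(stack)
        s = node.s
        if s !== nothing               # leaf: visit its elements
            for x in s
                if !(x in seen)
                    push!(seen, x)
                    f(x)
                end
            end
        else                           # internal node: push children (child1 is popped first)
            push!(stack, node.child2::RecursiveSet{T})
            push!(stack, node.child1::RecursiveSet{T})
        end
    end
    return nothing
end

leaves = [RecursiveSet(Set([i, i + 1])) for i in 1:5]
rs = reduce((a, b) -> RecursiveSet(a, b), leaves)   # same tree as reduce(union, leaves)
out = Int[]
foreach_unique(x -> push!(out, x), rs)
sort(out)   # [1, 2, 3, 4, 5, 6]
```

If the same subtree can appear several times (e.g. `x ∪ x`), it is walked once per occurrence; tracking visited nodes in an `IdSet` would avoid that.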
Indeed, BitVector is a very efficient option to represent sets. My goal here is to explore another option, one that could scale to really sparse sets in really high dimension.
In the end, our package SparseConnectivityTracer.jl will likely offer various types of sets, and the users will have to pick the best one for their specific application.
My point is that there is no “right size / sparsity”. We are developing a generic sparse autodiff library, so we want people to be able to use it with very small or very big models, very sparse or not so sparse.
For large scales and very sparse matrices, Set does better, and SortedVector (discussed in the other thread) even better, albeit on a different benchmark case.
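The SortedVector from the other thread is not shown here; the sketch below only assumes its core idea, a sorted, duplicate-free `Vector`, to show why its union is cheap and non-mutating: a single linear merge. `sorted_union` is a hypothetical name.

```julia
# Hypothetical helper: union of two sorted, duplicate-free vectors by a linear merge.
# The inputs are never mutated; a fresh vector is returned.
function sorted_union(a::Vector{T}, b::Vector{T}) where {T}
    out = Vector{T}(undef, length(a) + length(b))
    i, j, k = 1, 1, 0
    while i <= length(a) && j <= length(b)
        k += 1
        if a[i] < b[j]
            out[k] = a[i]; i += 1
        elseif b[j] < a[i]
            out[k] = b[j]; j += 1
        else                          # equal: keep one copy, advance both
            out[k] = a[i]; i += 1; j += 1
        end
    end
    while i <= length(a)              # copy whatever remains of `a`
        k += 1; out[k] = a[i]; i += 1
    end
    while j <= length(b)              # ...or of `b`
        k += 1; out[k] = b[j]; j += 1
    end
    return resize!(out, k)            # drop the slots freed by duplicates
end

sorted_union([1, 3, 5], [2, 3, 6])   # [1, 2, 3, 5, 6]
```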
Another thing is that we are not allowed to modify the sets in place, because in our framework they play the same role as numbers. Think of forward-mode autodiff, but instead of dual numbers we’re propagating sets of indices, and replacing every $x_1 + x_2$ with $s_1 \cup s_2$.
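A toy version of this idea fits in a few lines. The `IndexSetTracer` type below is hypothetical and is not SparseConnectivityTracer.jl’s actual API: each number is replaced by the set of input indices it depends on, and arithmetic returns a new set via non-mutating `union`.

```julia
# Hypothetical tracer: each number is replaced by the set of input indices it depends on
struct IndexSetTracer
    s::Set{Int}
end

# Both operations propagate dependencies by (non-mutating) set union
Base.:+(a::IndexSetTracer, b::IndexSetTracer) = IndexSetTracer(union(a.s, b.s))
Base.:*(a::IndexSetTracer, b::IndexSetTracer) = IndexSetTracer(union(a.s, b.s))

# Example function f: R^3 -> R^2
f(x) = [x[1] + x[2], x[2] * x[3]]

x = [IndexSetTracer(Set([i])) for i in 1:3]   # x_i depends only on itself
z = f(x)
[sort(collect(zj.s)) for zj in z]   # Jacobian sparsity rows: [[1, 2], [2, 3]]
```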
EDIT: The original post had a major correctness flaw in Base.union. This edit fixes that, but it doesn’t solve the problem of mutating the input. Leaving the corrected code up for posterity.
Here is one that uses a linked list with mutable nodes!
```julia
# Linked list of Sets. NOTE: `union` below mutates both of its arguments.
mutable struct ListSet{T<:Number}
    s::Set{T}
    next::Union{Nothing,ListSet{T}}
    first::ListSet{T}
    function ListSet(s::Set{T}) where {T}
        n = new{T}(s, nothing)   # `first` is left undefined...
        n.first = n              # ...then pointed at the node itself
        return n
    end
    function ListSet(s::Set{T}, previous::ListSet{T}) where {T}
        n = new{T}(s, nothing, get_first(previous))
        previous.next = n
        return n
    end
end
# Follow `first` pointers until reaching the head of the list
function get_first(cs::ListSet{T}) where {T}
    f = cs.first
    if f === cs
        return cs
    end
    return get_first(f)
end
Base.eltype(::Type{ListSet{T}}) where {T} = T
Base.length(cs::ListSet) = sum(length(s) for s in cs)
# Splice cs2's list after cs1; assumes cs1 is the tail of its list (true for a left-to-right reduce)
function Base.union(cs1::ListSet{T}, cs2::ListSet{T}) where {T}
    cs1.next = get_first(cs2)
    cs2.first = get_first(cs1)
    return cs2
end
# Iteration yields the stored Sets, starting from the head of the list
function Base.iterate(cs::ListSet{T}) where {T}
    f = get_first(cs)
    return (f.s, f.next)
end
function Base.iterate(::ListSet{T}, state) where {T}
    if isnothing(state)
        return nothing
    end
    return (state.s, state.next)
end
function Base.collect(cs::ListSet{T}) where {T}
    return collect(Set{T}(Iterators.flatten(cs)))
end
```
Benchmarks:
```julia
julia> using Chairmarks

julia> init(nb_sets) = [RecursiveSet(Set(i)) for i in 1:nb_sets]
init (generic function with 1 method)

julia> @be init(1000) collect(reduce(union, _))
Benchmark: 272 samples with 1 evaluation
 min    103.500 μs (1064 allocs: 282.703 KiB)
 median 145.350 μs (1064 allocs: 282.703 KiB)
 mean   186.126 μs (1064 allocs: 282.703 KiB, 0.36% gc time)
 max    7.984 ms (1064 allocs: 282.703 KiB, 96.75% gc time)

julia> init2(nb_sets) = [ListSet(Set(i)) for i in 1:nb_sets]
init2 (generic function with 1 method)

# I have to add init2 to the benchmark because `union` is mutating for my implementation.
julia> @be collect(reduce(union, init2(1000)))
Benchmark: 323 samples with 1 evaluation
 min    164.800 μs (5009 allocs: 456.336 KiB)
 median 183.600 μs (5009 allocs: 456.336 KiB)
 mean   255.769 μs (5009 allocs: 456.336 KiB, 0.87% gc time)
 max    10.229 ms (5009 allocs: 456.336 KiB, 96.69% gc time)
```
Wow, that is blazingly fast indeed! Unfortunately we cannot mutate either of the sets that we feed to union. In our setting, the sets play the same role as numbers: we compute with them and store them in arrays, but we never modify them in place.
But I think there is a way to make it work with non-mutable sets, by storing a tape. This is a suggestion from Brian Chen and it is starting to make a lot of sense.
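The tape idea is not spelled out here, so the following is only one possible reading of "storing a tape": every set value is an integer index into a shared, append-only list of nodes, and a union appends a new node that references its two operands. Nothing already on the tape is ever modified. `TapeNode`, `TAPE`, `leaf!`, `union_node!` and `tape_collect` are all hypothetical names, and a single global tape like this is not thread-safe.

```julia
# Hypothetical tape: every set value is an index into one append-only node list
struct TapeNode
    leaf::Union{Nothing,Set{Int}}   # actual elements, for leaf nodes only
    left::Int                       # operand indices for union nodes (0 for leaves)
    right::Int
end

const TAPE = TapeNode[]

leaf!(s::Set{Int}) = (push!(TAPE, TapeNode(s, 0, 0)); length(TAPE))
union_node!(a::Int, b::Int) = (push!(TAPE, TapeNode(nothing, a, b)); length(TAPE))

# Collect the distinct elements reachable from tape entry `i`, visiting each entry at most once
function tape_collect(i::Int)
    acc = Set{Int}()
    visited = falses(i)             # operands always sit at lower indices than their union
    stack = [i]
    while !isempty(stack)
        k = pop!(stack)
        visited[k] && continue
        visited[k] = true
        node = TAPE[k]
        l = node.leaf
        if l !== nothing
            union!(acc, l)
        else
            push!(stack, node.left, node.right)
        end
    end
    return acc
end

a = leaf!(Set([1, 2])); b = leaf!(Set([2, 3]))
c = union_node!(a, b); d = union_node!(c, c)   # shared subtrees are visited only once
tape_collect(d)   # == Set([1, 2, 3])
```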
I’ve done some tests with the ListSet structure, but I get results that I can’t interpret: it seems to me that instead of the union, it only returns the last of the Sets involved.
Regarding this aspect, using BitVectors inherently avoids mutating the sets being merged: an empty set (falses(10^6) for the universe under consideration) serves as the accumulator, and the input sets are only read.
Below are some comparisons, along with a check that both approaches give the same result.
I’m reporting them only to find out whether they are relevant to the problem at hand, off track (in the sense that they do not measure the quantities of interest), or even wrong.
BitArray vs RecursiveSet
```julia
julia> function bvu(v2::BitArray, v1::Array)
           for b in v1
               @inbounds v2[b] = true
           end
           v2
       end
bvu (generic function with 1 method)

julia> using BenchmarkTools

julia> init(nb_sets) = [RecursiveSet(Set(i)) for i in 1:nb_sets]
init (generic function with 1 method)

julia> init10(nb_sets) = [RecursiveSet(Set(rand(1:10^6,10))) for i in 1:nb_sets]
init10 (generic function with 1 method)

julia> @btime collect(reduce(union, is)) setup=(is=$init(1000));
  94.200 μs (1064 allocations: 282.70 KiB)

julia> @btime collect(reduce(union, is)) setup=(is=$init10(1000));
  378.500 μs (1039 allocations: 399.38 KiB)

julia> x_i=[[i] for i in 1:1000];

julia> x_i10=[rand(1:10^6,10) for i in 1:1000];

julia> @btime findall(reduce((s,c)->bvu(s,c), x_i, init=$falses(10^6)));
  13.800 μs (5 allocations: 130.16 KiB)

julia> @btime findall(reduce((s,c)->bvu(s,c), x_i10, init=$falses(10^6)));
  56.500 μs (6 allocations: 200.02 KiB)

julia> is=init10(1000);

julia> x_i10=[[e.s...] for e in is];

julia> findall(reduce((s,c)->bvu(s,c), x_i10, init=falses(10^6))) == sort([e for e in collect(reduce(union, is))])
true
```
This is a good idea! In our sparsity detection setting we do know the universe_size ahead of time, so we could do the unions by iteratively setting some bits to true in a fixed size BitVector. That’s the gist of your proposal, right?
The problem is that we cannot construct this global BitVector ahead of time. We have to construct it for each operation between numbers (aka sets).
To clarify the setting, imagine you have a vector-to-vector function $f(x) = z$. We want the sparsity pattern of the Jacobian of $f$, and we get it by replacing every number in the computation with a set of indices: the set of all $i$ such that the component $x_i$ influences said number at first order. To apply your technique, we would need to create as many global BitVectors as there are components $z_j$ of $z$.
Now imagine the function $f$ goes from $x$ to $z$ by computing an intermediate vector $y$. When you work out a component $y_k$, you do not yet know which components $z_j$ it will contribute to, so you don’t know which of the BitVectors for $z$ to modify. Therefore you need to allocate new BitVectors for every intermediate computation as well.
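A minimal sketch of what "a new BitVector for every intermediate computation" looks like, assuming a hypothetical universe of 10 inputs and hypothetical helpers `indicator` and `bv_union`. Each union is non-mutating, but it allocates a full-length BitVector (one `UInt64` per 64 inputs) no matter how few indices are actually set.

```julia
# Hypothetical setting: every intermediate value carries its own BitVector over all inputs
const UNIVERSE_SIZE = 10

# Indicator BitVector for a list of input indices
function indicator(indices)
    b = falses(UNIVERSE_SIZE)
    b[indices] .= true
    return b
end

# Non-mutating union: `.|` allocates a fresh BitVector and leaves the inputs untouched
bv_union(a::BitVector, b::BitVector) = a .| b

x = [indicator([i]) for i in 1:UNIVERSE_SIZE]   # input x_i depends only on itself
y1 = bv_union(x[1], x[2])                       # intermediate value y_1 = x_1 + x_2
z1 = bv_union(y1, x[5])                         # output z_1 = y_1 + x_5
findall(z1)   # [1, 2, 5]
```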
Thank you for the catch – I shouldn’t post midnight nerdsnipe results. Must test in the morning!
After edits, it is hard to compare apples-to-apples against the reference implementation since it mutates the input. If I add init to the reference benchmark, I get:
```julia
julia> init2(nb_sets) = [ListSet(Set(i)) for i in 1:nb_sets]
init2 (generic function with 1 method)

julia> @be collect(reduce(union, init2(1000)))
Benchmark: 342 samples with 1 evaluation
 min    168.400 μs (5009 allocs: 456.336 KiB)
 median 192.350 μs (5009 allocs: 456.336 KiB)
 mean   270.844 μs (5009 allocs: 456.336 KiB, 0.83% gc time)
 max    10.327 ms (5009 allocs: 456.336 KiB, 96.61% gc time)

julia> init(nb_sets) = [RecursiveSet(Set(i)) for i in 1:nb_sets]
init (generic function with 1 method)

julia> @be collect(reduce(union, init(1000)))
Benchmark: 110 samples with 1 evaluation
 min    334.400 μs (7065 allocs: 728.141 KiB)
 median 536.900 μs (7065 allocs: 728.141 KiB)
 mean   804.078 μs (7065 allocs: 728.141 KiB, 0.85% gc time)
 max    13.014 ms (7065 allocs: 728.141 KiB, 93.63% gc time)
```
So ListSet has about 2000 fewer allocations.
Obviously – I recognize that mutating the input means that this solution doesn’t qualify, but hopefully it’s of some help!
Below are the timings of the individual operations (unions and decoding).
If you have to do a lot of unions (unit cost about 18.0 ns, with 0 allocations), this approach seems advantageous.
I tried to do better than findall(), but I couldn’t.
Maybe someone knows how to improve bv2idx().
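```julia
# Decode a BitVector into the sorted indices of its true bits
function bv2idx(b0)
    n = count(b0)
    nc = length(b0.chunks)
    idx = Vector{Int}(undef, n)
    i = 1
    for c in 1:nc   # include the last chunk too (its unused trailing bits are always zero)
        b = b0.chunks[c]
        if b != 0
            s = (c-1) << 6                 # offset of this 64-bit chunk
            po = tz = trailing_zeros(b) + 1
            while b > 0
                idx[i] = s + po
                b = b >> tz                # shift out the bit just recorded
                i += 1
                tz = trailing_zeros(b) + 1
                po += tz
            end
        end
    end
    idx
end
```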
Look at what Base does for logical indexing. The main point is that your loop has `trailing_zeros` (TZCNT) plus a shift on the dependency path, while the standard loop (which Base uses) replaces that with BLSR. The loop is latency-bound, not throughput-bound, i.e. most of your execution ports sit idle most of the time.
Per Agner Fog, TZCNT has 3 cycles of latency on e.g. Ice Lake, as opposed to 1 cycle for BLSR.
The same picture holds on other Intel/AMD chips, and I think (?) also on ARM.
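The trick behind BLSR is that `b & (b - 1)` clears the lowest set bit, and compilers targeting x86 with BMI1 typically emit it as a single BLSR instruction. The hypothetical `set_bit_positions` helper below shows the loop shape being described: the TZCNT result is only used to record a position, while the loop-carried dependency on `chunk` is just the BLSR step.

```julia
# `b & (b - 1)` clears the lowest set bit of `b`
b = UInt64(0b101100)
b & (b - one(b))   # == 0x0000000000000028, i.e. 0b101000

# Hypothetical helper: 1-based positions of the set bits in one 64-bit chunk
function set_bit_positions(chunk::UInt64)
    pos = Int[]
    while chunk != 0
        push!(pos, trailing_zeros(chunk) + 1)   # TZCNT result is recorded, not fed back into `chunk`
        chunk &= chunk - one(chunk)             # loop-carried dependency: only this BLSR step
    end
    return pos
end

set_bit_positions(UInt64(0b101100))   # [3, 4, 6]
```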
If I understood correctly, it doesn’t change much. But maybe I implemented it too hastily.
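```julia
function bv2idx_bslr(b0)
    n = count(b0)
    nc = length(b0.chunks)
    idx = Vector{Int}(undef, n)
    i = 1
    for c in 1:nc   # include the last chunk too (its unused trailing bits are always zero)
        b = b0.chunks[c]
        if b != 0
            s = (c-1) << 6
            tz = trailing_zeros(b) + 1
            while b > 0
                idx[i] = s + tz
                b = b & (b - 1)            # BLSR: clear the lowest set bit
                i += 1
                tz = trailing_zeros(b) + 1
            end
        end
    end
    idx
end
```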
```julia
julia> bv2idx_bslr(b0) == findall(b0)
true

julia> @btime b = bv2idx_bslr(b0)
  8.633 μs (1 allocation: 896 bytes)
100-element Vector{Int64}:
 1901
```