Can I speed up this recursive set implementation?

Hey there!
I come to you with another performance nerdsnipe, whose solution could have a big impact on the autodiff ecosystem (via SparseConnectivityTracer.jl and DifferentiationInterface.jl).
You may remember my first question on pseudo-set data structures for fast unions.

One interesting proposal from that discussion (thanks @Tarny_GG_Channie) is a lazy set union, containing only pointers to the two sets it unites. I want to see how fast we can make this approach.

The code below is a simple version, and here are some of my performance questions:

  • Which format should I choose for RecursiveSet?
    • an immutable struct where the children are Union{Nothing,RecursiveSet}
    • a self-referential mutable struct where the children are always RecursiveSet, possibly equal to the parent
  • I don’t necessarily need to collect at the end, only to iterate. Would that be faster? (see the sketch below the implementation)
  • Is there any point in using AbstractTrees.jl?

Implementation:

struct RecursiveSet{T<:Number}
    s::Union{Nothing,Set{T}}                # leaf: an actual set of elements
    child1::Union{Nothing,RecursiveSet{T}}  # internal node: left operand of the union
    child2::Union{Nothing,RecursiveSet{T}}  # internal node: right operand of the union

    function RecursiveSet(s::Set{T}) where {T}
        return new{T}(s, nothing, nothing)
    end

    function RecursiveSet(rs1::RecursiveSet{T}, rs2::RecursiveSet{T}) where {T}
        return new{T}(nothing, rs1, rs2)
    end
end

Base.eltype(::Type{RecursiveSet{T}}) where {T} = T

function Base.union(rs1::RecursiveSet{T}, rs2::RecursiveSet{T}) where {T}
    return RecursiveSet(rs1, rs2)
end

function Base.collect(rs::RecursiveSet{T}) where {T}
    accumulator = Set{T}()
    collect_aux!(accumulator, rs)
    return collect(accumulator)
end

function collect_aux!(accumulator::Set{T}, rs::RecursiveSet{T}) where {T}
    if !isnothing(rs.s)
        union!(accumulator, rs.s::Set{T})
    else
        collect_aux!(accumulator, rs.child1::RecursiveSet{T})
        collect_aux!(accumulator, rs.child2::RecursiveSet{T})
    end
    return nothing
end
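
To make the iterate question concrete: here is a rough, untested sketch of a foreach-style traversal (the foreach_element helper is only for illustration). The seen set still allocates, so this mainly saves the final collect into a Vector:

function foreach_element(f, rs::RecursiveSet{T}, seen::Set{T}=Set{T}()) where {T}
    if !isnothing(rs.s)
        for x in rs.s::Set{T}
            # visit each distinct element exactly once
            x in seen || (push!(seen, x); f(x))
        end
    else
        foreach_element(f, rs.child1::RecursiveSet{T}, seen)
        foreach_element(f, rs.child2::RecursiveSet{T}, seen)
    end
    return nothing
end

# Usage: count the distinct elements of a lazy union without collecting it.
# n = Ref(0); foreach_element(_ -> n[] += 1, rs); n[]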

Benchmark:

using Chairmarks

init(nb_sets) = [RecursiveSet(Set(i)) for i in 1:nb_sets]

@be init(1000) collect(reduce(union, _))

I would recommend self-referential, iterating, without AbstractTrees.

Do you statically know how many different sets might be in your lazy union? If you are able to know it statically, I’d suggest just storing a tuple, i.e.

struct LazyUnion{T<:Number, N}
    sets::NTuple{N, Set{T}}
end

but it depends a lot on the use-case.
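
For completeness, a union in that design could just concatenate the tuples (a sketch, not from the post above); note that each union produces a new N, so long union chains would trigger a lot of recompilation:

# Concatenating the tuples keeps everything lazy.
Base.union(a::LazyUnion{T}, b::LazyUnion{T}) where {T} =
    LazyUnion((a.sets..., b.sets...))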


I don’t think that’s what you want to benchmark here, since that’s mostly just going to end up measuring the time that union! takes, and if that’s what you want then there’s no point in using the lazy union.

My understanding here is that you have a problem where you need to construct a set union, and then do some operations with that union. If you take a lazy union, then the construction will be basically free, but the operations (i.e. checking if an item is a member of the set, and inserting new items) will be slower than if you did the union eagerly.

So what you need to do is figure out, for the sorts of workloads you’re interested in, whether the time you spend constructing an eager union is so big that it offsets the more costly operations of a lazy union.
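
To make the slower operations concrete, a membership test on the lazy union would look something like this sketch: it has to search every leaf, so it is linear in the number of united sets instead of a single hash lookup.

function Base.in(x, rs::RecursiveSet)
    if !isnothing(rs.s)
        return x in rs.s                         # leaf: one O(1) hash lookup
    else
        return x in rs.child1 || x in rs.child2  # internal node: search both subtrees
    end
end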

Can you expand on why you would assume self-referential is better?
On the one hand, child access becomes type-stable naturally, but on the other hand it is now a mutable struct instead of an immutable one.
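
For reference, the self-referential variant I have in mind looks roughly like this (untested sketch; the name RecursiveSetSR is just for illustration):

mutable struct RecursiveSetSR{T<:Number}
    s::Union{Nothing,Set{T}}
    child1::RecursiveSetSR{T}
    child2::RecursiveSetSR{T}

    function RecursiveSetSR(s::Set{T}) where {T}
        n = new{T}(s)   # incomplete initialization is allowed for mutable structs
        n.child1 = n    # leaves point to themselves
        n.child2 = n
        return n
    end

    function RecursiveSetSR(rs1::RecursiveSetSR{T}, rs2::RecursiveSetSR{T}) where {T}
        return new{T}(nothing, rs1, rs2)
    end
end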

It’s always gonna be exactly two (the two children), unless the RecursiveSet represents an actual Set, in which case the children are meaningless.
But of course the recursive nature means that in the end, a RecursiveSet might be a union over any number of actual Sets.
Maybe I should have a different type for the actual Sets (leaves of the tree) than for the RecursiveSets (internal nodes)?
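
If I went that route, I imagine something like this two-type design (untested sketch). Note that the child fields then have an abstract type, which reintroduces a type-stability cost:

abstract type AbstractLazySet{T<:Number} end

struct LeafSet{T} <: AbstractLazySet{T}
    s::Set{T}   # an actual set: a leaf of the tree
end

struct UnionNode{T} <: AbstractLazySet{T}
    child1::AbstractLazySet{T}   # either a leaf or another union node
    child2::AbstractLazySet{T}
end

Base.union(a::AbstractLazySet{T}, b::AbstractLazySet{T}) where {T} = UnionNode{T}(a, b)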

It is interesting to benchmark because I’m only materializing the union! at the very end, and not for every intermediate variable. In this example, if reduce proceeds in the normal order, we would have n intermediate computations: x_1, x_1 \cup x_2, (x_1 \cup x_2) \cup x_3, etc. The lazy approach only materializes x_1 \cup \dots \cup x_n, although I could make it more efficient by iterating instead of collecting.
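
For contrast, the eager version of the same benchmark would look like this (a sketch with a hypothetical init_eager helper); reduce with the non-mutating union allocates a fresh Set at every step:

using Chairmarks

init_eager(nb_sets) = [Set(i) for i in 1:nb_sets]

# Every step materializes a new Set: first x_1 ∪ x_2, then (x_1 ∪ x_2) ∪ x_3, etc.
@be init_eager(1000) reduce(union, _)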

The only operation I’m interested in at the end is iterating over the non-duplicated elements of my lazy set. Nothing else matters to me, not checking or inserting. So you’re right, union should be basically free but it’s the collect / iterate bit that I want to make fast. Does that make sense?

I couldn’t get this script to work.
But I wonder if simply using BitVectors isn’t enough for your case.
What are the (3?) dimensions involved?

universe_size = 10^6
set_size = 10^2
n_set_size = 10^2

julia> using BenchmarkTools

julia> universe_size = 10^6
1000000

julia> set_size = 10^2
100

julia> n_set_size = 10^2
100

julia> x_i=[rand(1:universe_size, set_size) for _ in 1:n_set_size]
100-element Vector{Vector{Int64}}:
 [708682, 918664, 175331, 916558, 233886, 629813, 178439, 892554, 573764, 986587  …  
julia> function bvu(v1::Array, v2::BitArray=falses(10^6))
           for b in v1
               @inbounds v2[b]=true
           end
           v2
       end
bvu (generic function with 2 methods)

julia> bΦ=falses(10^6);

julia> @btime reduce((s,c)->bvu(c,s), x_i, init=$bΦ)
  6.760 μs (1 allocation: 16 bytes)
1000000-element BitVector:
 0
 0

In the end, with findall (or some alternative), the union materializes.

julia> @btime findall(reduce((s,c)->bvu(c,s), x_i, init=$bΦ))
  48.500 μs (3 allocations: 77.81 KiB)
9951-element Vector{Int64}:
    182
    508
    707
    758
...

Sorry, the closing parenthesis was missing; it’s fixed now.

Indeed, BitVector is a very efficient option to represent sets. My goal here is to explore another option, one that could scale to really sparse sets in really high dimension.

In the end, our package SparseConnectivityTracer.jl will likely offer various types of sets, and the users will have to pick the best one for their specific application.

Does the following example have the characteristics you are interested in? Sparsity and size?

julia> using BenchmarkTools

julia> universe_size = 10^7
10000000

julia> set_size = 10
10

julia> n_set_size=1000
1000

julia> x_i=[rand(1:universe_size, set_size) for _ in 1:n_set_size];

julia> function bvu(v2::BitArray, v1::Array)
           for b in v1
               @inbounds v2[b]=true
           end
           v2
       end
bvu (generic function with 1 method)

julia> bΦ=falses(10^7);

julia> @btime reduce((s,c)->bvu(s,c), x_i, init=$bΦ);
  10.800 μs (1 allocation: 16 bytes)

julia> @btime findall(reduce((s,c)->bvu(s,c), x_i, init=$bΦ));
  151.900 μs (3 allocations: 78.19 KiB)

I wasn’t able to tell clearly whether, among the various proposals, there is already one that does better than BitVector,

except for this approximation (it concatenates everything without removing duplicates):

julia> @btime [x_i ...; ]
  10.600 μs (3 allocations: 86.11 KiB)
10000-element Vector{Int64}:
 2571773
 1970931
 2889537
 5337319
  822413
 6896510

My point is that there is no “right size / sparsity”. We are developing a generic sparse autodiff library, so we want people to be able to use it with very small or very big models, very sparse or not so sparse.

For large scales and very sparse matrices, Set does better, and SortedVector (discussed in the other thread) even better, albeit on a different benchmark case.

Another thing is that we are not allowed to modify the sets in place, because in our framework they play the same role as numbers. Think of forward mode autodiff, but instead of dual numbers we’re propagating sets of indices, and replacing every x_1 + x_2 with s_1 \cup s_2.
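
As a toy illustration of this analogy (not the actual SparseConnectivityTracer.jl implementation), imagine a tracer type whose arithmetic unions index sets, the way dual numbers propagate derivatives:

# Hypothetical toy tracer: `inds` is the set of input indices that influence
# this value at first order; any set type with a non-mutating `union` works.
struct Tracer{S}
    inds::S
end

Base.:+(a::Tracer, b::Tracer) = Tracer(union(a.inds, b.inds))
Base.:*(a::Tracer, b::Tracer) = Tracer(union(a.inds, b.inds))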

EDIT: The original post had a major correctness flaw in Base.union. This edit fixes that, but it doesn’t change the input-mutation problem. Leaving the corrected code up for posterity 🙂

Here is one that uses a linked list with mutable nodes!

mutable struct ListSet{T<:Number}
    s::Set{T}                        # this node's actual set
    next::Union{Nothing,ListSet{T}}  # next node in the list, if any
    first::ListSet{T}                # points (possibly transitively) towards the head

    function ListSet(s::Set{T}) where {T}
        n = new{T}(s, nothing)
        n.first = n
        return n
    end
    function ListSet(s::Set{T}, previous::ListSet{T}) where {T}
        n = new{T}(s, nothing, get_first(previous))
        previous.next = n
        return n
    end
end

# Follow the `first` pointers until reaching the true head of the list.
function get_first(cs::ListSet{T}) where {T}
    f = cs.first
    if f === cs
        return cs
    end
    return get_first(f)
end

Base.eltype(::Type{ListSet{T}}) where {T} = Set{T}  # iteration yields the underlying Sets

# NB: counts with multiplicity if the same element appears in several nodes.
Base.length(cs::ListSet) = sum(length(s) for s in cs)

# Mutating union: splice the two lists together and return the new tail.
function Base.union(cs1::ListSet{T}, cs2::ListSet{T}) where {T}
    cs1.next = get_first(cs2)
    cs2.first = get_first(cs1)
    return cs2
end

# Iteration starts from the head of the list and yields each node's Set.
function Base.iterate(cs::ListSet{T}) where {T}
    f = get_first(cs)
    return (f.s, f.next)
end

function Base.iterate(::ListSet{T}, state) where {T}
    if isnothing(state)
        return nothing
    end
    return (state.s, state.next)
end

function Base.collect(cs::ListSet{T}) where {T}
    return collect(Set{T}(Iterators.flatten(cs)))
end

Benchmarks:

julia> using Chairmarks

julia> init(nb_sets) = [RecursiveSet(Set(i)) for i in 1:nb_sets]
init (generic function with 1 method)

julia> @be init(1000) collect(reduce(union, _))
Benchmark: 272 samples with 1 evaluation
min    103.500 μs (1064 allocs: 282.703 KiB)
median 145.350 μs (1064 allocs: 282.703 KiB)
mean   186.126 μs (1064 allocs: 282.703 KiB, 0.36% gc time)
max    7.984 ms (1064 allocs: 282.703 KiB, 96.75% gc time)

julia> init2(nb_sets) = [ListSet(Set(i)) for i in 1:nb_sets]
init2 (generic function with 1 method)

# I have to add init2 to the benchmark because `union` is mutating for my implementation.
julia> @be collect(reduce(union, init2(1000)))
Benchmark: 323 samples with 1 evaluation
min    164.800 μs (5009 allocs: 456.336 KiB)
median 183.600 μs (5009 allocs: 456.336 KiB)
mean   255.769 μs (5009 allocs: 456.336 KiB, 0.87% gc time)
max    10.229 ms (5009 allocs: 456.336 KiB, 96.69% gc time)

Wow, that is blazingly fast indeed! Unfortunately we cannot mutate either of the sets that we feed to union. In our setting, the sets play the same role as numbers: we compute with them and store them in arrays, but we never modify them in place.

But I think there is a way to make it work with immutable sets, by storing a tape. This is a suggestion from Brian Chen, and it is starting to make a lot of sense.
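
My rough reading of the tape idea, as an untested sketch (the real design may differ): every union appends a node to a shared tape instead of mutating its operands, so the existing sets stay immutable.

struct TapeSet{T<:Number}
    tape::Vector{Any}  # shared tape; entries are Set{T} (leaves) or (i, j) index pairs (unions)
    id::Int            # index of this set's entry on the tape
end

function TapeSet(s::Set{T}, tape::Vector{Any}=Any[]) where {T}
    push!(tape, s)
    return TapeSet{T}(tape, length(tape))
end

function Base.union(a::TapeSet{T}, b::TapeSet{T}) where {T}
    @assert a.tape === b.tape  # both operands must live on the same tape
    push!(a.tape, (a.id, b.id))
    return TapeSet{T}(a.tape, length(a.tape))
end

function Base.collect(ts::TapeSet{T}) where {T}
    acc = Set{T}()
    stack = [ts.id]
    while !isempty(stack)
        entry = ts.tape[pop!(stack)]
        entry isa Set ? union!(acc, entry) : append!(stack, entry)
    end
    return collect(acc)
end

# Usage: all leaves share one tape.
# tape = Any[]; sets = [TapeSet(Set(i), tape) for i in 1:1000]
# collect(reduce(union, sets))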

I’ve done some tests with the ListSet structure but I get results that I can’t interpret: it seems to me that instead of the union it only provides the last of the Sets involved.

Regarding this aspect, using BitVectors intrinsically avoids mutating the sets being merged: the corresponding empty set (falses(10^6) for the universe under consideration) is used as an accumulator, and the input sets are only read.

Below are some comparisons, plus a check that the results are equivalent.

I report them only to find out whether they are meaningful for the context of interest, off track (in the sense that they do not measure the quantities of interest), or even wrong.

BitArray vs RecursiveSet

julia> function bvu(v2::BitArray, v1::Array)
           for b in v1
               @inbounds v2[b]=true
           end
           v2
       end
bvu (generic function with 1 method)

julia> using BenchmarkTools
julia> init(nb_sets) = [RecursiveSet(Set(i)) for i in 1:nb_sets]
init (generic function with 1 method)

julia> init10(nb_sets) = [RecursiveSet(Set(rand(1:10^6,10))) for i in 1:nb_sets]
init10 (generic function with 1 method)

julia> @btime  collect(reduce(union, is)) setup=(is=$init(1000));
  94.200 μs (1064 allocations: 282.70 KiB)

julia> @btime  collect(reduce(union, is)) setup=(is=$init10(1000));
  378.500 μs (1039 allocations: 399.38 KiB)

julia> x_i=[[i] for i in 1:1000];

julia> x_i10=[rand(1:10^6,10) for i in 1:1000];

julia> @btime findall(reduce((s,c)->bvu(s,c), x_i, init=$falses(10^6)));       
  13.800 μs (5 allocations: 130.16 KiB)

julia> @btime findall(reduce((s,c)->bvu(s,c), x_i10, init=$falses(10^6)));     
  56.500 μs (6 allocations: 200.02 KiB)


is=init10(1000)
x_i10=[[e.s...] for e in is]

julia> findall(reduce((s,c)->bvu(s,c), x_i10, init=falses(10^6))) == sort([e for e in collect(reduce(union, is))])
true

This is a good idea! In our sparsity detection setting we do know the universe_size ahead of time, so we could do the unions by iteratively setting some bits to true in a fixed size BitVector. That’s the gist of your proposal, right?

The problem is that we cannot construct this global BitVector ahead of time. We have to construct it for each operation between numbers (aka sets).
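
To spell out the constraint: a non-mutating union over BitVector sets would have to allocate a fresh universe-length vector at every single operation, e.g.

# Sketch: each call allocates a new BitVector of length(a) bits.
bit_union(a::BitVector, b::BitVector) = a .| b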

Yeah, this is a relevant benchmark.

To clarify the setting, imagine you have a vector-to-vector function f(x) = z. We want the sparsity pattern of the Jacobian of f, and we get it by replacing every number in the computation with a set of indices: the indices i such that the component x_i influences said number at order 1. To apply your technique, we would need to create as many global BitVectors as there are components z_j of z.

Now imagine the function f goes from x to z by computing an intermediate vector y. When you work out a component y_k, you do not yet know for which components z_j it will play a role. So you don’t know which BitVector of z to update, and you end up allocating new BitVectors for every intermediate computation as well.
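
Concretely, with the toy Tracer sketched earlier (still hypothetical):

x = [Tracer(Set(i)) for i in 1:3]  # seed each input with its own index
y = x[1] + x[2]                    # intermediate: y.inds == Set([1, 2]), but we don't
                                   # yet know which z_j will depend on y
z = [y + x[1], y * x[3]]           # z[1].inds == Set([1, 2]), z[2].inds == Set([1, 2, 3])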

Does that make sense?

Thank you for the catch – I shouldn’t post midnight nerdsnipe results. Must test in the morning!

After the edits, it is hard to compare apples to apples against the reference implementation, since my ListSet mutates its input. If I include init in the reference benchmark as well, I get:

julia> init2(nb_sets) = [ListSet(Set(i)) for i in 1:nb_sets]
init2 (generic function with 1 method)

julia> @be collect(reduce(union, init2(1000)))
Benchmark: 342 samples with 1 evaluation
min    168.400 μs (5009 allocs: 456.336 KiB)
median 192.350 μs (5009 allocs: 456.336 KiB)
mean   270.844 μs (5009 allocs: 456.336 KiB, 0.83% gc time)
max    10.327 ms (5009 allocs: 456.336 KiB, 96.61% gc time)

julia> init(nb_sets) = [RecursiveSet(Set(i)) for i in 1:nb_sets]
init (generic function with 1 method)

julia> @be  collect(reduce(union, init(1000)))
Benchmark: 110 samples with 1 evaluation
min    334.400 μs (7065 allocs: 728.141 KiB)
median 536.900 μs (7065 allocs: 728.141 KiB)
mean   804.078 μs (7065 allocs: 728.141 KiB, 0.85% gc time)
max    13.014 ms (7065 allocs: 728.141 KiB, 93.63% gc time)

So ListSet has about 2000 fewer allocations.

Obviously – I recognize that mutating the input means that this solution doesn’t qualify, but hopefully it’s of some help!


Right!
It seemed to me applicable to this type of operation (which, as I understood it, is what is involved here):

\emptyset \cup x_1, (\emptyset \cup x_1) \cup x_2, ((\emptyset \cup x_1) \cup x_2) \cup x_3 , etc.

Below are the timings of the setup and of the individual union and decoding operations.
If you have to do a lot of unions (unit cost about 18.0 ns, 0 allocations), it seems advantageous.

bench bitvector
julia> using BenchmarkTools

julia> universe_size = 10^6
1000000

julia> set_size = 10
10

julia> function bvu(v2::BitArray, v1::Array)
           for b in v1
               @inbounds v2[b]=true
           end
           v2
       end
bvu (generic function with 1 method)

julia> A = rand(1:universe_size, set_size)
10-element Vector{Int64}:
 712148
 338682
 430007
 133129
 854854
 731234
 447168
 368404
 599018
 904474

julia> B = rand(1:universe_size, set_size)
10-element Vector{Int64}:
 584275
 686702
 529108
  60781
 660831
  70592
 239582
 643789
 640169
 991364

julia>        bE=falses(10^6);

julia>        bA=bvu(bE,A);

julia>        bB=bvu(bA, B);

julia>        findall(bB)
20-element Vector{Int64}:
  60781
  70592
 133129
 239582
 338682
 368404
 430007
 447168
 529108
 584275
 599018
 640169
 643789
 660831
 686702
 712148
 731234
 854854
 904474
 991364

julia> @btime bE=falses(10^6);
  3.471 μs (3 allocations: 122.20 KiB)

julia> @btime bA=bvu(bE,A);
  18.036 ns (0 allocations: 0 bytes)

julia> @btime bB=bvu(bA, B);
  17.954 ns (0 allocations: 0 bytes)

julia> @btime findall(bB)
  7.600 μs (1 allocation: 224 bytes)
20-element Vector{Int64}:
  60781
  70592
 133129
 239582
 338682
 368404
 430007
 447168
 529108
 584275
 599018
 640169
 643789
 660831
 686702
 712148
 731234
 854854
 904474
 991364

I tried to do better than findall(), but I couldn’t.
Maybe someone knows how to improve bv2idx().

# Decode the set bits of a BitVector into sorted indices, scanning its 64-bit chunks.
function bv2idx(b0)
  n=count(b0)
  nc=length(b0.chunks)
  idx=Vector{Int}(undef,n)
  i=1
  for c in 1:nc   # include the last chunk; its padding bits are guaranteed zero
    b=b0.chunks[c]
    if b!=0
      s=(c-1)<<6
      po=tz=trailing_zeros(b)+1
      while b>0 
        idx[i]=s+po
        b=b>>tz
        i+=1
        tz=trailing_zeros(b)+1
        po+=tz
      end
    end
  end
  idx
end

julia> b0=falses(2^20)
1048576-element BitVector:
 0
 ...
 0

julia> a=rand(1:2^20, 10)
10-element Vector{Int64}:
 606786
...
 832505

julia> b0[a].=true
10-element view(::BitVector, [606786, 74995, 320773, 915622, 407162, 371510, 310114, 993256, 82023, 832505]) with eltype Bool:
 1
...
 1

julia> @btime b= bv2idx(b0);
  8.167 μs (1 allocation: 144 bytes)

julia> @btime f= findall(b0);
  7.875 μs (1 allocation: 144 bytes)

julia> bv2idx(b0)==findall(b0)
true

Look at what Base does for logical indexing. The main point is that your loop has trailing_zeros/TZCNT plus a shift on the dependency path, while the standard loop (which Base uses) replaces that with BLSR. The thing is latency-bound, not throughput-bound, i.e. most of your execution ports idle most of the time.

Per Agner Fog, TZCNT has 3 cycles of latency on e.g. Ice Lake, as opposed to 1 cycle for BLSR.

The same picture holds on other Intel/AMD chips, and I think (?) also on ARM.


If that’s the case, it doesn’t change much. But maybe I did it too hastily.

# BLSR variant: the loop-carried update is b & (b-1), which clears the lowest set bit.
function bv2idx_bslr(b0)
  n=count(b0)
  nc=length(b0.chunks)
  idx=Vector{Int}(undef,n)
  i=1
  for c in 1:nc   # include the last chunk; its padding bits are guaranteed zero
    b=b0.chunks[c]
    if b!=0
      s=(c-1)<<6
      tz=trailing_zeros(b)+1
      while b>0 
        idx[i]=s+tz
        b=b&(b-1)
        i+=1
        tz=trailing_zeros(b)+1
      end
    end
  end
  idx
end

julia> bv2idx_bslr(b0)==findall(b0)
true
julia> @btime b= bv2idx_bslr(b0)
  8.633 μs (1 allocation: 896 bytes)
100-element Vector{Int64}:
    1901
...