Allocation-free weighted samples

Given an SVector of weights, is there a way to take a single weighted sample without any allocations? For example, these work great but each requires an allocation:

using StaticArrays, StatsBase, BenchmarkTools

x = [1]                  # preallocated output for wsample!
wts = @SVector rand(16)
@btime wsample(1:16, $wts)
@btime wsample!(1:16, $wts, $x)[1]

julia> @btime wsample(1:16, $wts)
  52.173 ns (1 allocation: 144 bytes)
13

julia> @btime wsample!(1:16, $wts, $x)[1]
  62.325 ns (1 allocation: 144 bytes)
3

In my simple test, using an SVector increases the allocation size (144 bytes vs. 32 bytes for a Vector). I don’t know why.

julia> @btime wsample(1:16, $wts)
  27.587 ns (1 allocation: 144 bytes)
9
julia> wt = rand(16)
julia> @btime wsample(1:16, $wt)
  29.242 ns (1 allocation: 32 bytes)
6

I think the allocation comes from instantiating the mutable Weights wrapper:

julia> @btime StatsBase.Weights($wts);
  14.358 ns (1 allocation: 144 bytes)

julia> ismutabletype(Weights) # usually uses 8 bytes for pointer
true

julia> sizeof(typeof( StatsBase.Weights(wts))) # 8+136=144
136

A bit of type piracy can elide the instantiation when a preallocated Weights instance is passed in, but I don’t know whether this is sound:

julia> StatsBase.weights(w::Weights) = w

julia> @btime wsample(1:16, $(Weights(wts)))
  37.277 ns (0 allocations: 0 bytes)
6

Thanks, can you explain how this works?

Edit:
I quickly tried it in my function, but it had no effect on the allocations. Likely I’m doing it wrong. E.g., this still allocates once per iteration:

using Accessors, StaticArrays, StatsBase, BenchmarkTools
StatsBase.weights(w::Weights) = w
function plusone(wts, res)
    for i in 1:100
        idx = wsample(Weights(wts))
        @reset res[idx] += 1
    end
    return res
end

wts = @SVector rand(16)
res = @SVector fill(0, 16)
@btime plusone($wts, $res)

If you look into the source code, you’ll see that wsample(w) calls sample(default_rng(), weights(w)), where weights(w) creates an instance of the mutable Weights. As @Benny pointed out, this allocates.

Relevant part of sampling.jl: lines 1044-1056
"""
    wsample([rng], [a], w)

Select a weighted random sample of size 1 from `a` with probabilities proportional
to the weights given in `w`. If `a` is not present, select a random weight from `w`.

Optionally specify a random number generator `rng` as the first argument
(defaults to `Random.$(VERSION < v"1.3" ? "GLOBAL_RNG" : "default_rng()")`).
"""
wsample(rng::AbstractRNG, w::AbstractVector{<:Real}) = sample(rng, weights(w))
wsample(w::AbstractVector{<:Real}) = wsample(default_rng(), w)
wsample(rng::AbstractRNG, a::AbstractArray, w::AbstractVector{<:Real}) = sample(rng, a, weights(w))
wsample(a::AbstractArray, w::AbstractVector{<:Real}) = wsample(default_rng(), a, w)

Relevant part of weights.jl: lines 4-23; 69; 82-89
"""
    @weights name

Generates a new generic weight type with specified `name`, which subtypes `AbstractWeights`
and stores the `values` (`V<:AbstractVector{<:Real}`) and `sum` (`S<:Real`).
"""
macro weights(name)
    return quote
        mutable struct $name{S<:Real, T<:Real, V<:AbstractVector{T}} <: AbstractWeights{S, T, V}
            values::V
            sum::S
            function $(esc(name)){S, T, V}(values, sum) where {S<:Real, T<:Real, V<:AbstractVector{T}}
                isfinite(sum) || throw(ArgumentError("weights cannot contain Inf or NaN values"))
                return new{S, T, V}(values, sum)
            end
        end
        $(esc(name))(values::AbstractVector{T}, sum::S) where {S<:Real, T<:Real} = $(esc(name)){S, T, typeof(values)}(values, sum)
        $(esc(name))(values::AbstractVector{<:Real}) = $(esc(name))(values, sum(values))
    end
end

@weights Weights

"""
    weights(vs::AbstractArray{<:Real})

Construct a `Weights` vector from array `vs`.
See the documentation for [`Weights`](@ref) for more details.
"""
weights(vs::AbstractArray{<:Real}) = Weights(vec(vs))
weights(vs::AbstractVector{<:Real}) = Weights(vs)

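By defining StatsBase.weights(w::Weights) = w and passing a Weights directly to wsample, you skip this instantiation: the new method is more specific than weights(vs::AbstractVector{<:Real}), so the existing instance is returned instead of being wrapped in a new Weights. Creating our own Weights instance still allocates, but wsample(w) itself does not when w is already a Weights.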
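To see why the extra method avoids the allocation, here is a minimal Base-only sketch of the same dispatch pattern. MyWeights and myweights are hypothetical stand-ins for StatsBase’s Weights and weights (not the real types), assuming the same shape: a mutable wrapper that is itself an AbstractVector, plus a generic function that wraps any vector.

```julia
# Hypothetical stand-in for StatsBase.Weights (NOT the real type): a mutable
# wrapper that is itself an AbstractVector and caches the sum of its values.
mutable struct MyWeights{V<:AbstractVector{Float64}} <: AbstractVector{Float64}
    values::V
    sum::Float64
end
MyWeights(v::AbstractVector{Float64}) = MyWeights(v, sum(v))
Base.size(w::MyWeights) = size(w.values)
Base.getindex(w::MyWeights, i::Int) = w.values[i]

# Generic fallback (hypothetical stand-in for StatsBase.weights): always builds
# a fresh wrapper, which is a heap allocation because the type is mutable.
myweights(v::AbstractVector{Float64}) = MyWeights(v)

# More specific method for an existing wrapper: hand it back unchanged.
# Dispatch prefers this method for MyWeights arguments, so nothing is built.
myweights(w::MyWeights) = w
```

Without the myweights(w::MyWeights) method, myweights(w) would hit the generic fallback (a MyWeights is an AbstractVector{Float64}) and wrap w in yet another MyWeights, which corresponds to the extra allocation in the benchmarks above.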

So to improve your code snippet, you just need to move the Weights(wts) outside of the loop:

using Accessors, StaticArrays, StatsBase, BenchmarkTools
StatsBase.weights(w::Weights) = w
function plusone(wghts, res)
    for i in 1:100
        idx = wsample(wghts)
        @reset res[idx] += 1
    end
    return res
end

wts = @SVector rand(16)
wghts = Weights(wts)
res = @SVector fill(0, 16)
@btime plusone($wghts, $res);
    #  2.400 μs (0 allocations: 0 bytes)

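The reason @Benny saw no allocations in @btime wsample(1:16, $(Weights(wts))) is the interpolation with $: the interpolated expression Weights(wts) is evaluated once, before the benchmark runs, so its allocation is not measured. For example: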

julia> @btime wsample(1:16, $(Weights(wts)));
  18.737 ns (0 allocations: 0 bytes)

julia> @btime wsample(1:16, Weights($wts));
  30.452 ns (1 allocation: 144 bytes)

julia> @btime $(rand(10^6));
  2.200 ns (0 allocations: 0 bytes)

Weights(w) stores w in the values field of the newly created Weights instance, so the difference in allocation size comes from the difference between storing an SVector and storing a Vector. An SVector{16, Float64} is stored inline (16 × 8 B = 128 B), whereas a Vector is stored as a pointer (8 B); together with the 8 B sum field, this gives 136 B and 16 B respectively.
I’m not completely sure about the details, but I assume the rest of the difference (144 B and 32 B actually allocated) has to do with object headers and memory alignment.
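The field-size arithmetic can be checked without StaticArrays. In this sketch, Wrapper is a hypothetical type with the same field layout as Weights (a values field followed by a Float64 sum), and NTuple{16, Float64} stands in for SVector{16, Float64}, which stores its entries in an NTuple internally. The numbers assume a 64-bit machine.

```julia
# Hypothetical mutable wrapper with the same field layout as StatsBase.Weights.
mutable struct Wrapper{V}
    values::V
    sum::Float64
end

# An NTuple of Float64 is an isbits value, so it is stored inline in the object:
# 16 * 8 bytes of entries + 8 bytes for `sum` = 136 bytes of field storage.
inline_size = sizeof(Wrapper{NTuple{16, Float64}})

# A Vector lives on the heap, so the field holds only a pointer:
# 8 bytes for the pointer + 8 bytes for `sum` = 16 bytes (on 64-bit).
pointer_size = sizeof(Wrapper{Vector{Float64}})
```

The 136 matches the sizeof(typeof(StatsBase.Weights(wts))) result reported earlier in the thread; the allocated sizes (144 and 32 bytes) are larger than these field sizes.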

Thanks, but the weights must be updated every iteration. If I understand correctly, with this method I would still need to call Weights inside the loop in order to use wsample. I.e., I can’t just preallocate a vector to store the weights and then update it.

Assuming the type (including length for an SVector) of the weights does not change, you could exploit the mutability of Weights to update it in-place:

using Accessors, StaticArrays, StatsBase, BenchmarkTools
StatsBase.weights(w::Weights) = w
function plusone(wghts, res)
    for i in 1:100
        # Update in any way compatible with typeof(wghts).
        # Here this is Weights{Float64, Float64, SVector{16, Float64}}.
        # (The parameters are the types of the sum, the entries, and the weights vector.)
        wghts.values = @SVector rand(length(wghts.values))
        wghts.sum = sum(wghts.values)
        idx = wsample(wghts)
        @reset res[idx] += 1
    end
    return res
end

wts = @SVector rand(16)
wghts = Weights(wts)  # wts is not used directly in this example; it only fixes the type
res = @SVector fill(0, 16)
@btime plusone($wghts, $res);
    # 5.250 μs (0 allocations: 0 bytes)

To avoid forgetting to update wghts.sum, you could also use a helper function:

function update!(w::Weights{S, T, V}, new_wts::V) where {S, T, V}
    w.values = new_wts
    w.sum = sum(w.values)   # keep the cached sum in sync with the new values
    return w
end
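The in-place approach relies on the field types staying fixed. A minimal Base-only sketch, using a hypothetical TupleWeights type whose NTuple field stands in for an SVector, shows that the type parameter pins the length: a replacement with a different number of weights is rejected by dispatch instead of being stored.

```julia
# Hypothetical Base-only analogue of Weights{Float64, Float64, SVector{N, Float64}}:
# the length N is part of the type, just like for an SVector.
mutable struct TupleWeights{N}
    values::NTuple{N, Float64}
    sum::Float64
end
TupleWeights(v::NTuple{N, Float64}) where {N} = TupleWeights{N}(v, sum(v))

# Replace all weights in place; the sum is recomputed so both fields stay consistent.
# Only tuples of the same length N are accepted.
function update!(w::TupleWeights{N}, new_vals::NTuple{N, Float64}) where {N}
    w.values = new_vals
    w.sum = sum(new_vals)
    return w
end
```

Because the object is mutated rather than rebuilt, a loop that calls update! on one preallocated wrapper never needs to construct a new one.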

Maybe an approach with reservoir sampling could fit your use case:

julia> using StreamSampling, BenchmarkTools

julia> function update_sample!(s, iter)
           for x in iter 
               update!(s, x, wts(x))
           end
           return s
       end;

julia> wts(x) = x;

julia> sample = ReservoirSample(Int, algAExpJ);

julia> iter = 1:16;

julia> @btime empty!(update_sample!($sample, $iter));
  62.166 ns (0 allocations: 0 bytes)

Here the weights are given by a function (which could also look them up in an array if needed).
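For comparison, the simplest weighted reservoir scheme for a single item fits in a few lines of Base Julia. This is a sketch of the classic size-1 weighted reservoir update (keep the new item with probability w / running total), not StreamSampling’s algAExpJ, which uses exponential jumps to skip items; reservoir_wsample is a hypothetical name.

```julia
using Random

# Weighted reservoir sampling with a reservoir of size 1 (hypothetical helper).
# Streams over `iter` once and returns one item chosen with probability
# proportional to wfun(item), or `nothing` if every weight is zero.
function reservoir_wsample(rng::AbstractRNG, iter, wfun)
    total = 0.0
    chosen = nothing
    for x in iter
        w = float(wfun(x))
        w > 0 || continue            # zero-weight items can never be chosen
        total += w
        # Replace the current choice with probability w / total.
        if rand(rng) * total < w
            chosen = x
        end
    end
    return chosen
end
reservoir_wsample(iter, wfun) = reservoir_wsample(Random.default_rng(), iter, wfun)
```

Item j survives with probability (w_j / W_j) times the product of (1 - w_k / W_k) over later items, where W_k is the running total; the product telescopes to w_j / W_n, which is the desired weighted probability.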


I see, I was misusing @reset:

StatsBase.weights(w::Weights) = w
function plusone(wghts, res)
    for i in 1:100
        @reset wghts.values = @SVector rand(16)
        @reset wghts.sum    = sum(wghts.values)
        idx = wsample(wghts)
        @reset res[idx] += 1
    end
    return res
end


wts = @SVector rand(16);
wghts = Weights(wts);
res = @SVector fill(0, 16);
@btime plusone($wghts, $res);

Thanks, it seems this will work perfectly.

I assume this would be a better solution, since it involves no type piracy. I’ll try it out.

I’m not actually sure why Weights is mutable. StatsBase started early in Julia’s history, so it may have designed the Weights wrapper around mutable AbstractArrays whose sums can change, caching the sum in a mutable sum field; still, supporting immutable arrays seems justified. I can’t think of one, but if there is an internal reason it needs separate mutable instances, then that type piracy could get dangerous.


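I suppose it depends on the use case, but you could also write a simple non-allocating sampler along these lines:

function mysample(w)
    r = rand() * sum(w)          # uniform point in [0, total weight)
    for (i, x) in pairs(w)
        r -= x
        r < 0 && return i        # first index whose cumulative weight exceeds the draw
    end
    # Floating-point rounding can leave r >= 0 after the loop; fall back to the
    # last index with positive weight instead of returning nothing.
    return findlast(>(0), w)
end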
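If reproducibility matters, the same linear scan can take an explicit RNG. This is a sketch of a hypothetical extra two-argument method of mysample (the signature is an assumption, not from any package), followed by a quick frequency check.

```julia
using Random

# Linear-scan weighted sampling, same idea as mysample(w), with an explicit RNG.
# Works for any indexable collection of nonnegative weights (e.g. Vector, SVector).
function mysample(rng::AbstractRNG, w)
    r = rand(rng) * sum(w)       # uniform point in [0, total weight)
    for (i, x) in pairs(w)
        r -= x
        r < 0 && return i
    end
    return findlast(>(0), w)     # rounding fallback: last index with positive weight
end
```

With weights [1.0, 0.0, 3.0] and a seeded MersenneTwister, index 2 should never appear and index 3 should come up in roughly 75% of draws.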

Note that if you only update a small number of the weights between samples, there may be faster alternatives.
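One such alternative, assuming this is the kind of speedup meant, is a Fenwick (binary indexed) tree over the weights: changing one weight and drawing a sample both cost O(log n) instead of an O(n) rescan. FenwickSampler, setweight!, total, and sample_index below are a hypothetical sketch, not from any package, and assume nonnegative weights.

```julia
using Random

# Hypothetical Fenwick-tree sampler over n nonnegative weights.
struct FenwickSampler
    tree::Vector{Float64}   # tree[i] = sum of weights in (i - lowbit(i), i]
    w::Vector{Float64}      # current weights, kept to compute update deltas
end

function FenwickSampler(weights::AbstractVector{<:Real})
    n = length(weights)
    fs = FenwickSampler(zeros(n), zeros(n))
    for i in 1:n
        setweight!(fs, i, weights[i])
    end
    return fs
end

# Set weight i to v in O(log n) by propagating the change up the tree.
function setweight!(fs::FenwickSampler, i::Int, v::Real)
    delta = v - fs.w[i]
    fs.w[i] = v
    n = length(fs.tree)
    while i <= n
        fs.tree[i] += delta
        i += i & -i          # move to the next node covering index i
    end
    return fs
end

# Total weight, i.e. the prefix sum up to n, in O(log n).
function total(fs::FenwickSampler)
    s = 0.0
    i = length(fs.tree)
    while i > 0
        s += fs.tree[i]
        i -= i & -i
    end
    return s
end

# Draw an index with probability proportional to its weight, in O(log n)
# and without allocating: descend the tree to find the smallest index
# whose prefix sum exceeds a uniform draw in [0, total).
function sample_index(rng::AbstractRNG, fs::FenwickSampler)
    n = length(fs.tree)
    r = rand(rng) * total(fs)
    pos = 0
    step = prevpow(2, n)     # largest power of two <= n
    while step > 0
        nxt = pos + step
        if nxt <= n && fs.tree[nxt] <= r
            pos = nxt
            r -= fs.tree[nxt]
        end
        step >>= 1
    end
    return min(pos + 1, n)   # guard against rounding pushing past the end
end
```

After setweight!(fs, i, v), the next sample_index call already reflects the new weight, so updating a few entries between draws never requires a full pass over the weights.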
