Faster Bernoulli sampling

I have a need for huge numbers of Bernoulli samples, so I tried to do better than the simple implementation in Distributions.jl (which just does rand() < p), using a binary arithmetic decoder.

Some timing results.

julia> using BenchmarkTools, Distributions
julia> p = 0.99;
julia> d = Bernoulli(p);
julia> d′ = BernoulliBitStream(p);
julia> @btime rand($d);     # legacy
  9.961 ns (0 allocations: 0 bytes)
julia> @btime rand($d′);     # new
  6.943 ns (0 allocations: 0 bytes)

For uniform bits (p==0.5), my implementation is slightly slower.

julia> d′ = BernoulliBitStream(0.5);
julia> @btime rand($d′);
  10.867 ns (0 allocations: 0 bytes)

The Distributions.jl sampler consumes 52 bits of entropy per sample, whereas mine needs at most one bit per sample on average (about h₂(p) bits, where h₂ is the binary entropy function), plus a few simple operations. This implementation helps me a little, but I’m surprised it’s not faster. The Julia RNG must be really well optimised.
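
To put rough numbers on the entropy claim, here is a minimal sketch of the binary entropy function. It gives the theoretical average number of fair input bits an ideal decoder needs per Bernoulli(p) sample. The definition is the standard one and is not taken from Distributions.jl.

```julia
# Binary entropy in bits: average number of fair bits needed per Bernoulli(p) sample.
# Valid for 0 < p < 1; the endpoints would need special-casing because log2(0) = -Inf.
h₂(p) = -p * log2(p) - (1 - p) * log2(1 - p)

h₂(0.5)    # 1.0: a fair coin costs one full bit per sample
h₂(0.99)   # roughly 0.08 bits per sample
```

So at p = 0.99 an ideal decoder draws a fresh input bit only about once every twelve samples, which is why the per-sample arithmetic rather than the RNG call is likely to dominate the cost.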

Here’s the code, if anyone’s interested. If there are any obvious ways I can speed this up, please do point them out, thanks!

import Base.rand

# Produce uniform random bits, one at a time.
mutable struct BitStream
    x::UInt64   # 64 bits of entropy at a time
    cnt::Int    # 1-based index of the next unused bit in x
    BitStream() = new(rand(UInt64), 1)
end

function rand(bs::BitStream)
    b = Bool(bs.x & 1)
    bs.x >>>= 1
    bs.cnt += 1
    if bs.cnt > 64
        bs.x = rand(UInt64)
        bs.cnt = 1
    end
    b
end

# Produce Bernoulli random bits from unbiased source.
# Requires only h₂(p) input bits per output bit (h₂ is binary entropy function).
mutable struct BernoulliBitStream
    p::Float64
    bs::BitStream  # source of unbiased bits
    lo::Float64
    hi::Float64
    α::Float64     # precompute 1/p
    β::Float64     # precompute 1/(1-p)
    BernoulliBitStream(p) = new(p, BitStream(), 0.0, 1.0, 1/p, 1/(1-p))
end

# Binary arithmetic decoder
function rand(bbs::BernoulliBitStream)
    p = bbs.p
    while true
        if bbs.lo >= p
            bbs.lo = (bbs.lo - p) * bbs.β
            bbs.hi = (bbs.hi - p) * bbs.β
            return false
        elseif bbs.hi <= p
            bbs.lo *= bbs.α
            bbs.hi *= bbs.α
            return true
        else
            mid = (bbs.lo + bbs.hi) / 2
            if rand(bbs.bs)  # get 1 bit of entropy
                bbs.lo = mid
            else
                bbs.hi = mid
            end
        end
    end
end
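
One way to check how many input bits the decoder really uses is to count them. The sketch below re-implements the same interval-halving loop as a standalone function with a counter. `bits_per_sample` is a hypothetical helper name, and it draws its bits from `rand(rng, Bool)` rather than from the `BitStream` type above, so it measures entropy consumption only, not speed.

```julia
using Random

# Hypothetical helper: run the interval-halving decoder for n outputs and
# return the mean number of fair bits drawn per Bernoulli(p) output.
function bits_per_sample(p::Float64, n::Int; rng = Random.default_rng())
    lo, hi = 0.0, 1.0
    α, β = 1 / p, 1 / (1 - p)
    nbits = 0   # fair bits consumed so far
    nout = 0    # Bernoulli outputs produced so far
    while nout < n
        if lo >= p                 # interval lies in the "false" region
            lo = (lo - p) * β
            hi = (hi - p) * β
            nout += 1
        elseif hi <= p             # interval lies in the "true" region
            lo *= α
            hi *= α
            nout += 1
        else                       # undecided: halve the interval with one fair bit
            mid = (lo + hi) / 2
            if rand(rng, Bool)
                lo = mid
            else
                hi = mid
            end
            nbits += 1
        end
    end
    return nbits / n
end

bits_per_sample(0.5, 100_000)     # exactly 1.0: every output costs one bit
bits_per_sample(0.99, 1_000_000)  # should land near 0.08, the binary entropy of 0.99
```

If the measured value stays close to the binary entropy, the decoder is using its input bits efficiently and any remaining cost comes from the floating-point updates and branches.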

I am not sure about this — I think it requires multiple bits on average, especially for p that is far from 0.5. Also, the “few simple operations” are costly too. In any case, the first thing I would check is

https://docs.julialang.org/en/v1/manual/performance-tips/

especially profiling and benchmarking parts. See also

which I find useful for this kind of benchmark.

Incidentally, are you aware of Random.bitrand? For p = 1/2, it is hard to beat.
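
A minimal sketch of the `bitrand` route for the p = 1/2 case, assuming the samples can be consumed in bulk rather than one at a time:

```julia
using Random

n = 1_000_000
bits = bitrand(n)            # BitVector of n fair bits, stored packed in 64-bit chunks
frac_true = count(bits) / n  # fraction of ones, expected to be close to 0.5
```

This only covers fair bits; for other values of p some extra step is still needed.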


What you do to consume only one bit of entropy at a time has been tried on MersenneTwister without success: it natively generates (2 times) 52 bits of entropy at once, very fast, so even simple operations to save individual bits for later are more costly than generating 52 more bits. Note that this is very specific to this particular RNG, and could very well change as soon as we get a new default RNG (which seems to be on its way). So IIRC, even rand(Bool) just takes 1 out of 52 bits and discards the others (in the scalar case; there are other possible optimizations when generating random arrays).

So one possible way (if you don’t mind using internals) for your example to get slightly faster is to use rand(Random.UInt52Raw()) and consume 52 bits one by one rather than rand(UInt64), as the latter actually consumes 104 bits of entropy internally (2*52), not 64.
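
On the array-generation remark above: when many samples are needed at once, one option is to generate the uniforms in bulk and threshold them. This is only a sketch using the public `rand!` function. `bernoulli_batch!` is a hypothetical helper name, it assumes `out` and `buf` have the same length, and whether it is actually faster depends on the RNG and the Julia version.

```julia
using Random

# Hypothetical helper: fill `out` with Bernoulli(p) draws using one bulk RNG call.
function bernoulli_batch!(out::AbstractVector{Bool}, buf::Vector{Float64}, p::Real;
                          rng = Random.default_rng())
    rand!(rng, buf)     # bulk-generate uniforms in [0, 1)
    out .= buf .< p     # true with probability p, same test as rand() < p
    return out
end

n = 1_000_000
buf = Vector{Float64}(undef, n)   # scratch space, reusable across calls
out = Vector{Bool}(undef, n)
bernoulli_batch!(out, buf, 0.99)
count(out) / n                    # expected to be close to 0.99
```

Reusing `buf` and `out` across calls avoids allocating on every batch.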


It might be better to just use another package, and get 64 bits.

Possibly:
https://sunoru.github.io/RandomNumbers.jl/stable/man/benchmark/

I was going to be helpful and find the best code to use, but actually I couldn’t even confirm the speed increase for sunoru’s implementation vs. the built-in RNG. For (there’s even better out there):

r = Xoroshiro128Plus(0x1234567890abcdef)


Out of curiosity, what is your application for this?


Interesting, that’s really helpful and supports my findings.
I was always under the misconception that RNG was quite computationally intensive, so that’s something I learned today!

Stochastic computing simulations.

“Need” is too strong a word; it would just be nice to have faster sims, and I was surprised that my strategy didn’t provide that.

Actually (I don’t know what I was thinking), rand(UInt64) doesn’t consume 104 bits of entropy anymore, because MersenneTwister stores an array of integers in order to benefit from the optimization of array generation. So far fewer bits are wasted. And IIRC, rand(Bool) uses rand(UInt32) internally (using fewer than 32 bits didn’t seem to improve performance).
