I need huge numbers of Bernoulli samples, so I tried to do better than the simple implementation in Distributions.jl (which just does `rand() < p`), using a binary arithmetic decoder.

Some timing results for `p = 0.99`:

```
julia> using BenchmarkTools, Distributions
julia> p = 0.99;
julia> d = Bernoulli(p);
julia> d′ = BernoulliBitStream(p);
julia> @btime rand($d); # legacy
9.961 ns (0 allocations: 0 bytes)
julia> @btime rand($d′); # new
6.943 ns (0 allocations: 0 bytes)
```

For uniform bits (`p == 0.5`), my implementation is slightly slower.

```
julia> d′ = BernoulliBitStream(0.5);
julia> @btime rand($d′);
10.867 ns (0 allocations: 0 bytes)
```

The Distributions.jl sampler requires 52 bits of entropy per sample, whereas mine requires on average only about h₂(p) bits per sample (h₂ is the binary entropy function, so this is at most one bit, reached at `p == 0.5`), plus a few simple floating-point operations. Although this implementation helps me a little, I’m surprised it’s not faster. The Julia RNG must be very well optimised.
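
To put a number on the entropy claim, here is a minimal sketch of the binary entropy function. The name `h2` is just something I picked for illustration; it isn't part of the sampler code.

```julia
# Binary entropy in bits, with the usual convention h2(0) = h2(1) = 0.
h2(p) = (p == 0 || p == 1) ? 0.0 : -p * log2(p) - (1 - p) * log2(1 - p)

h2(0.5)   # 1.0 (a fair coin needs one full bit per sample)
h2(0.99)  # ≈ 0.0808 (roughly one input bit per twelve samples)
```

So for `p = 0.99` the decoder should, in principle, need well under a tenth of a bit per sample, which is why I expected a bigger speedup.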

Here’s the code, if anyone’s interested. If there are any obvious ways I can speed this up, please do point them out, thanks!

```
import Base.rand

# Produce uniform random bits, one at a time.
mutable struct BitStream
    x::UInt64 # 64 bits of entropy at a time
    cnt::Int # counter
    BitStream() = new(rand(UInt64), 1)
end

function rand(bs::BitStream)
    b = Bool(bs.x & 1)
    bs.x >>>= 1
    bs.cnt += 1
    if bs.cnt > 64
        bs.x = rand(UInt64)
        bs.cnt = 1
    end
    b
end

# Produce Bernoulli random bits from unbiased source.
# Requires only h₂(p) input bits per output bit (h₂ is binary entropy function).
mutable struct BernoulliBitStream
    p::Float64
    bs::BitStream
    lo::Float64
    hi::Float64
    α::Float64 # precompute 1/p
    β::Float64 # precompute 1/(1-p)
    BernoulliBitStream(p) = new(p, BitStream(), 0.0, 1.0, 1/p, 1/(1-p))
end

# Binary arithmetic decoder
function rand(bbs::BernoulliBitStream)
    p = bbs.p
    while true
        if bbs.lo >= p
            bbs.lo = (bbs.lo - p) * bbs.β
            bbs.hi = (bbs.hi - p) * bbs.β
            return false
        elseif bbs.hi <= p
            bbs.lo *= bbs.α
            bbs.hi *= bbs.α
            return true
        else
            mid = (bbs.lo + bbs.hi) / 2
            if rand(bbs.bs) # get 1 bit of entropy
                bbs.lo = mid
            else
                bbs.hi = mid
            end
        end
    end
end
```
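
As a sanity check on the decoder, here is a rough, self-contained sketch that runs the same interval logic while counting how many fair bits it consumes. `measure_bits` is a hypothetical helper, not part of the code above. It assumes a seeded `MersenneTwister` as the bit source (instead of my `BitStream`) and divides by `p` and `1 - p` directly rather than using the precomputed `α` and `β`, so it says nothing about speed.

```julia
using Random

# Hypothetical helper: same interval logic as the decoder above, but it
# also counts the fair bits drawn. Returns (fraction of `true` outputs,
# fair bits consumed per output).
function measure_bits(p::Float64, n::Int; rng = Random.MersenneTwister(1))
    lo, hi = 0.0, 1.0
    bits = 0
    trues = 0
    for _ in 1:n
        while true
            if lo >= p            # interval entirely above p: emit false
                lo = (lo - p) / (1 - p)
                hi = (hi - p) / (1 - p)
                break
            elseif hi <= p        # interval entirely below p: emit true
                lo /= p
                hi /= p
                trues += 1
                break
            else                  # interval straddles p: halve it with one fair bit
                mid = (lo + hi) / 2
                if rand(rng, Bool)
                    lo = mid
                else
                    hi = mid
                end
                bits += 1
            end
        end
    end
    return trues / n, bits / n
end

freq, bits_per_sample = measure_bits(0.99, 10^6)
```

If the decoder is behaving, `freq` should land close to 0.99 and `bits_per_sample` should be close to h₂(0.99), i.e. somewhere around 0.08. That would suggest the remaining cost is in the floating-point updates and branching rather than in the random bits themselves.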