# Why is this code slow?

I am trying to draw a random index from a vector of probabilities. Here is my code:

``````
import Random
using Random: AbstractRNG

struct FastCategorical
    _cdf::Vector{Float64}
    function FastCategorical(pdf)
        cdf = cumsum(pdf)
        @assert last(cdf) ≈ 1
        new(cdf)
    end
end

@inline function Random.rand(rng::AbstractRNG, o::FastCategorical)
    p = rand(rng)
    searchsortedfirst(o._cdf, p)
end

using BenchmarkTools
using LinearAlgebra
pdf = normalize!(rand(10^5), 1)
d = FastCategorical(pdf)
rng = Random.GLOBAL_RNG

@btime rand($rng, $d) #   128.708 ns (0 allocations: 0 bytes)
@btime searchsortedfirst($(d._cdf), x) setup=(x = rand()) #   39.083 ns (0 allocations: 0 bytes)
@btime rand($(rng)) #   2.865 ns (0 allocations: 0 bytes)
``````

Observe that `rand(rng, d)` (about 129 ns) is much slower than the sum of its two lines benchmarked individually (about 39 ns + 3 ns).

Apparently this problem has little to do with how your code is structured:

``````
using BenchmarkTools, LinearAlgebra
function test()
    cdf = cumsum(normalize!(rand(10^5), 1))
    println("Summed times")
    @btime searchsortedfirst($cdf, rand())
    println("Single times")
    @btime searchsortedfirst($cdf, x) setup=(x=rand())
    @btime rand()
end
test()

--------

Summed times
187.690 ns (0 allocations: 0 bytes)
Single times
56.369 ns (0 allocations: 0 bytes)
4.125 ns (0 allocations: 0 bytes)
``````

Thanks to @giordano for making the testing rigorous.

``````
julia> function test3()
           cdf = cumsum(normalize!(rand(10^5), 1))
           f = rand()
           @btime searchsortedfirst($cdf, $(rand()))
           @btime searchsortedfirst($cdf, $f)
           @btime searchsortedfirst($cdf, $(pi/10))
           @btime rand()
       end
test3 (generic function with 1 method)

julia> test3()
44.017 ns (0 allocations: 0 bytes)
44.198 ns (0 allocations: 0 bytes)
44.163 ns (0 allocations: 0 bytes)
3.012 ns (0 allocations: 0 bytes)
``````

It is even worse here, since your cdf does not fit into the L1 cache: 10^5 `Float64` values occupy about 800 KB, far more than a typical L1 data cache. You can probably get a significant speedup by using a cache-oblivious layout of the implicit search tree instead of a linearly ordered vector.

Unfortunately, I am not aware of any Julia packages implementing a search-optimized layout of sorted lists.
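To make the layout idea concrete, here is a minimal sketch of one such layout. Note that it uses the Eytzinger (breadth-first) order, which is cache-friendly but is not the truly cache-oblivious van Emde Boas layout; it is just the simplest variant to write down. The names `EytzingerCDF`, `fill_eytzinger!` and `eytzinger_searchsortedfirst` are made up for illustration and do not come from any package. Whether it actually beats `searchsortedfirst` on a given machine would need to be measured.

``````julia
# Sorted values stored as an implicit binary tree: node k has children 2k and 2k+1.
struct EytzingerCDF
    tree::Vector{Float64}  # values in breadth-first (Eytzinger) order
    rank::Vector{Int}      # rank[k] = position of tree[k] in the original sorted vector
end

# In-order traversal of the implicit tree, consuming `sorted` from left to right.
function fill_eytzinger!(tree, rank, sorted, i, k)
    if k <= length(sorted)
        i = fill_eytzinger!(tree, rank, sorted, i, 2k)          # smaller values go left
        tree[k] = sorted[i]
        rank[k] = i
        i = fill_eytzinger!(tree, rank, sorted, i + 1, 2k + 1)  # larger values go right
    end
    return i
end

function EytzingerCDF(sorted::Vector{Float64})
    n = length(sorted)
    tree = Vector{Float64}(undef, n)
    rank = Vector{Int}(undef, n)
    fill_eytzinger!(tree, rank, sorted, 1, 1)
    return EytzingerCDF(tree, rank)
end

# Same result as searchsortedfirst(sorted, x): index of the first element >= x,
# or length(sorted) + 1 if there is none.
function eytzinger_searchsortedfirst(e::EytzingerCDF, x::Real)
    n = length(e.tree)
    k = 1
    best = 0                 # last node seen with tree[k] >= x
    while k <= n
        if e.tree[k] >= x
            best = k         # candidate; a smaller one may exist in the left subtree
            k = 2k
        else
            k = 2k + 1
        end
    end
    return best == 0 ? n + 1 : e.rank[best]
end
``````

The top levels of the tree sit next to each other at the start of `tree`, so the first few comparisons of every search touch the same few cache lines. The price is the extra `rank` vector and one additional lookup at the end.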


@giordano thanks! My mistake: `pi/3` must also be interpolated.
But `rand()` was intentionally left uninterpolated, so that it is evaluated on every call.

Edited:

@foobar_lv2's explanatory post is right on the money.

You can use the `setup` keyword of `@btime` in that case; see https://github.com/JuliaCI/BenchmarkTools.jl#quick-start

I was oversimplifying too much and ended up making the test pointless. I edited it according to your suggestions, thanks.

Sidestepping the whole optimization question: if you want a lot of draws from a discrete/categorical distribution, it is worth building an alias table.

It is implemented in Distributions.jl.
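For readers who have not seen the technique, here is a minimal self-contained sketch of an alias table (Vose's construction). It is not the Distributions.jl code; `AliasSampler` and `draw` are hypothetical names chosen for this example. Each draw costs one random index plus one uniform number, with no search at all.

``````julia
using Random

# Column i is kept with probability accept[i]; otherwise alias[i] is returned.
struct AliasSampler
    accept::Vector{Float64}
    alias::Vector{Int}
end

function AliasSampler(p::AbstractVector{<:Real})
    n = length(p)
    scaled = Float64.(p) .* (n / sum(p))   # rescale so the average weight is 1
    accept = ones(Float64, n)
    alias = collect(1:n)
    small = Int[]                          # columns with weight < 1
    large = Int[]                          # columns with weight >= 1
    for i in 1:n
        push!(scaled[i] < 1 ? small : large, i)
    end
    while !isempty(small) && !isempty(large)
        s = pop!(small)
        l = pop!(large)
        accept[s] = scaled[s]              # keep s with its own weight ...
        alias[s] = l                       # ... and top the column up with l
        scaled[l] -= 1 - scaled[s]         # l donated the missing part
        push!(scaled[l] < 1 ? small : large, l)
    end
    # Columns left over (rounding error) keep accept = 1.
    return AliasSampler(accept, alias)
end

# Constant-time draw: pick a column uniformly, then keep it or take its alias.
function draw(rng::AbstractRNG, s::AliasSampler)
    i = rand(rng, 1:length(s.accept))
    return rand(rng) < s.accept[i] ? i : s.alias[i]
end
``````

Usage would look like `s = AliasSampler(pdf); draw(Random.default_rng(), s)`. Building the table is O(n), so it only pays off when many draws are taken from the same distribution.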


But why is `@btime searchsortedfirst($cdf, x) setup=(x=rand())` fast? Are you telling me that the CPU learns to predict the `searchsortedfirst` branches from the `rand` call??

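Because you also need to set `evals=1`. The `setup` expression runs once per sample, not once per evaluation, so by default the same `x` is searched many times in a row (see `evals/sample` in the output) and the branch predictor learns that one search path. This is imo a design flaw in the BenchmarkTools API, but currently that’s how it works.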

``````
julia> @benchmark searchsortedfirst($cdf, x) setup=(x=rand())
BenchmarkTools.Trial:
memory estimate:  0 bytes
allocs estimate:  0
--------------
minimum time:     66.402 ns (0.00% GC)
median time:      71.518 ns (0.00% GC)
mean time:        71.189 ns (0.00% GC)
maximum time:     152.288 ns (0.00% GC)
--------------
samples:          10000
evals/sample:     978

julia> @benchmark searchsortedfirst($cdf, x) setup=(x=rand()) evals=1
BenchmarkTools.Trial:
memory estimate:  0 bytes
allocs estimate:  0
--------------
minimum time:     149.000 ns (0.00% GC)
median time:      283.000 ns (0.00% GC)
mean time:        297.651 ns (0.00% GC)
maximum time:     11.499 μs (0.00% GC)
--------------
samples:          10000
evals/sample:     1

julia> g(a)=searchsorted(a, rand())
g (generic function with 2 methods)

julia> @benchmark g($cdf)
BenchmarkTools.Trial:
memory estimate:  0 bytes
allocs estimate:  0
--------------
minimum time:     238.247 ns (0.00% GC)
median time:      245.880 ns (0.00% GC)
mean time:        249.613 ns (0.00% GC)
maximum time:     5.811 μs (0.00% GC)
--------------
samples:          10000
evals/sample:     425
``````
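The same effect can be checked without BenchmarkTools. The following is a rough sketch, not a rigorous benchmark: `compare_keys` is a hypothetical helper that times a loop searching for one fixed key against a loop searching for a fresh random key each time. The compiler is free to optimize the fixed-key loop further, so treat the numbers as indicative only.

``````julia
# Average nanoseconds per search: one repeated key vs. a new key every time.
function compare_keys(cdf::Vector{Float64}; n::Int = 10^6)
    acc = 0                                   # checksum so the searches are not discarded

    x = rand()
    t0 = time_ns()
    for _ in 1:n
        acc += searchsortedfirst(cdf, x)      # identical branch pattern on every pass
    end
    t_same = (time_ns() - t0) / n

    keys = rand(n)                            # keys generated outside the timed loop
    t0 = time_ns()
    for k in keys
        acc += searchsortedfirst(cdf, k)      # unpredictable branch pattern
    end
    t_fresh = (time_ns() - t0) / n

    return (same = t_same, fresh = t_fresh, checksum = acc)
end
``````

Calling it once to compile and then again, for example `compare_keys(cdf)` with the `cdf` from above, gives the two averages in the `same` and `fresh` fields.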