Why is this code slow?

I am trying to draw a random index from a vector of probabilities. Here is my code:

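import Random
using Random: AbstractRNG

struct FastCategorical
    _cdf::Vector{Float64}   # cumulative probabilities
    function FastCategorical(pdf)
        cdf = cumsum(pdf)
        @assert last(cdf) ≈ 1
        new(cdf)
    end
end

# Inverse-CDF sampling: draw p ~ U[0,1) and return the first index i with cdf[i] >= p
@inline function Random.rand(rng::AbstractRNG, o::FastCategorical)
    p = rand(rng)
    searchsortedfirst(o._cdf, p)
end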

using BenchmarkTools
using LinearAlgebra
pdf = normalize!(rand(10^5), 1)
d = FastCategorical(pdf)
rng = Random.GLOBAL_RNG

@btime rand($rng, $d) #   128.708 ns (0 allocations: 0 bytes)
@btime searchsortedfirst($(d._cdf), x) setup=(x = rand()) #   39.083 ns (0 allocations: 0 bytes)
@btime rand($(rng)) #   2.865 ns (0 allocations: 0 bytes)

Observe that rand is much slower than its two lines timed individually (about 129 ns versus roughly 39 + 3 ns).

Apparently this problem has little to do with the structure of your code:

using BenchmarkTools, LinearAlgebra
function test()
    cdf = cumsum(normalize!(rand(10^5), 1))
    println("Summed times")
    @btime searchsortedfirst($cdf, rand())   # rand() is evaluated inside the timed expression
    println("Single times")
    @btime searchsortedfirst($cdf, x) setup=(x=rand())
    @btime rand()
end
test()


Summed times
  187.690 ns (0 allocations: 0 bytes)
Single times
  56.369 ns (0 allocations: 0 bytes)
  4.125 ns (0 allocations: 0 bytes)

Thanks to @giordano for making the testing rigorous.

julia> function test3()
           cdf = cumsum(normalize!(rand(10^5), 1))
           @btime searchsortedfirst($cdf, $(rand()))
           @btime searchsortedfirst($cdf, $f)
           @btime searchsortedfirst($cdf, $(pi/10))
           @btime rand()
       end
test3 (generic function with 1 method)

julia> test3()
  44.017 ns (0 allocations: 0 bytes)
  44.198 ns (0 allocations: 0 bytes)
  44.163 ns (0 allocations: 0 bytes)
  3.012 ns (0 allocations: 0 bytes)

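Cf. "PSA: Microbenchmarks remember branch history": when a benchmark searches the same input over and over, the branch predictor learns the sequence of comparisons the binary search makes, so the timing looks much better than it would with fresh inputs.

It is even worse here, since your cdf (10^5 Float64 values, about 800 KB) does not fit into the L1 cache: a repeated search keeps the few cache lines on its path hot, while a fresh search pays for cache misses. You can probably get a significant speedup by using a cache-oblivious layout of the implicit search tree instead of a linearly ordered vector.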

Unfortunately, I am not aware of any Julia packages implementing search-optimized layouts of sorted lists.
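For illustration, here is a minimal sketch of a related but simpler layout, the Eytzinger (BFS) order, in which node k of the implicit tree has children 2k and 2k+1. Strictly speaking this layout is cache-friendly rather than cache-oblivious (a van Emde Boas layout would be the cache-oblivious variant), and `EytzingerCDF` and `search_first` are hypothetical names, not from any package. The comparison result feeds the index arithmetic instead of a conditional jump, so there is little for the branch predictor to mispredict.

```julia
# Sorted values stored in Eytzinger (BFS) order: node k has children 2k and 2k+1.
struct EytzingerCDF
    vals::Vector{Float64}   # cdf values in Eytzinger order
    pos::Vector{Int}        # pos[k] = index of vals[k] in the original sorted cdf
end

# In-order traversal of the implicit tree; `i` is the next sorted index to place.
function _fill_eytzinger!(vals, pos, src, i, k)
    if k <= length(vals)
        i = _fill_eytzinger!(vals, pos, src, i, 2k)          # left subtree
        vals[k] = src[i]
        pos[k] = i
        i = _fill_eytzinger!(vals, pos, src, i + 1, 2k + 1)  # right subtree
    end
    return i
end

function EytzingerCDF(cdf::AbstractVector{<:Real})
    src = Vector{Float64}(cdf)   # 1-based copy of the input
    issorted(src) || throw(ArgumentError("cdf must be sorted"))
    n = length(src)
    vals = Vector{Float64}(undef, n)
    pos = Vector{Int}(undef, n)
    _fill_eytzinger!(vals, pos, src, 1, 1)
    return EytzingerCDF(vals, pos)
end

# Same result as searchsortedfirst(cdf, x): first index i with cdf[i] >= x, or n + 1.
function search_first(e::EytzingerCDF, x::Real)
    vals = e.vals
    n = length(vals)
    k = 1
    while k <= n
        # go right (append bit 1) if vals[k] < x, otherwise left (append bit 0)
        @inbounds k = 2k + (vals[k] < x)
    end
    # drop the trailing right turns plus the last left turn: that node is the answer
    k >>= trailing_ones(k) + 1
    return k == 0 ? n + 1 : e.pos[k]
end
```

Usage would be `e = EytzingerCDF(d._cdf)` and then `search_first(e, rand(rng))` in place of `searchsortedfirst`. Whether this beats the plain sorted vector at 10^5 elements depends on the CPU; published fast versions typically also prefetch the next tree levels, which this sketch omits, so it is worth measuring with fresh inputs.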


@giordano thanks! My mistake: pi/3 must also be interpolated,
but rand() was intentionally left uninterpolated so that it is re-evaluated on every evaluation.


@foobar_lv2's explanatory post is right on the money.

You can use the setup keyword of @btime in that case; see https://github.com/JuliaCI/BenchmarkTools.jl#quick-start

I oversimplified and ended up making the test pointless. I edited it according to your suggestions, thanks.

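Sidestepping the whole optimization question: if you want a lot of draws from a discrete/categorical distribution, it is worth building an alias table, which takes O(n) time to build and then gives O(1) draws instead of an O(log n) binary search per draw.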

It is implemented in Distributions.jl.
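To show why each draw is O(1), here is a minimal sketch of the alias method (Vose's variant of Walker's method): a draw picks one column uniformly and then flips one biased coin. `AliasSampler` is a hypothetical name for illustration; Distributions.jl has its own implementation, whose names and API are not reproduced here.

```julia
using Random

# Alias table (Vose's variant of Walker's alias method):
# O(n) construction, then O(1) per draw.
struct AliasSampler
    prob::Vector{Float64}   # probability of keeping column i itself
    alias::Vector{Int}      # index returned when column i is not kept
end

function AliasSampler(pdf::AbstractVector{<:Real})
    n = length(pdf)
    scaled = Float64.(pdf) .* (n / sum(pdf))   # rescale so the average weight is 1
    prob = ones(n)
    alias = collect(1:n)
    small = Int[]   # columns with weight < 1 (need topping up)
    large = Int[]   # columns with weight >= 1 (can donate)
    for i in 1:n
        push!(scaled[i] < 1 ? small : large, i)
    end
    while !isempty(small) && !isempty(large)
        s = pop!(small)
        l = pop!(large)
        prob[s] = scaled[s]                      # keep s with probability scaled[s] ...
        alias[s] = l                             # ... otherwise return l
        scaled[l] = (scaled[l] + scaled[s]) - 1  # l donated 1 - scaled[s]
        push!(scaled[l] < 1 ? small : large, l)
    end
    # Leftovers have weight ~1 up to rounding: they keep prob = 1 and alias = self.
    return AliasSampler(prob, alias)
end

# One draw: pick a column uniformly, then flip a biased coin.
function Random.rand(rng::AbstractRNG, a::AliasSampler)
    i = rand(rng, 1:length(a.prob))
    return rand(rng) < a.prob[i] ? i : a.alias[i]
end
```

With `pdf` and `rng` as in the first post, this would be used as `a = AliasSampler(pdf)` followed by `rand(rng, a)`. The draw does two random numbers and two array reads regardless of n, so there is no data-dependent search path for the branch predictor or the cache to struggle with.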


But then why is @btime searchsortedfirst($cdf, x) setup=(x=rand()) fast? Are you telling me the CPU learns to predict the branches of searchsortedfirst from the rand call??

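Because you also need to set evals=1. The setup expression runs once per sample, not once per evaluation, so by default the same x is searched many times in a row (978 evaluations per sample in the first run below) and the branch predictor learns its path. This is imo a design flaw in the BenchmarkTools API, but currently that's how it works.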

julia> @benchmark searchsortedfirst($cdf, x) setup=(x=rand())
  memory estimate:  0 bytes
  allocs estimate:  0
  minimum time:     66.402 ns (0.00% GC)
  median time:      71.518 ns (0.00% GC)
  mean time:        71.189 ns (0.00% GC)
  maximum time:     152.288 ns (0.00% GC)
  samples:          10000
  evals/sample:     978

julia> @benchmark searchsortedfirst($cdf, x) setup=(x=rand()) evals=1
  memory estimate:  0 bytes
  allocs estimate:  0
  minimum time:     149.000 ns (0.00% GC)
  median time:      283.000 ns (0.00% GC)
  mean time:        297.651 ns (0.00% GC)
  maximum time:     11.499 μs (0.00% GC)
  samples:          10000
  evals/sample:     1

julia> g(a)=searchsorted(a, rand())
g (generic function with 2 methods)

julia> @benchmark g($cdf)
  memory estimate:  0 bytes
  allocs estimate:  0
  minimum time:     238.247 ns (0.00% GC)
  median time:      245.880 ns (0.00% GC)
  mean time:        249.613 ns (0.00% GC)
  maximum time:     5.811 μs (0.00% GC)
  samples:          10000
  evals/sample:     425
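Setting evals=1 gives every sample a fresh x, but each sample then times a single ~100 ns call, so timer resolution and overhead become a noticeable part of the number. Another way to check, sketched below with the standard library only, is to draw many queries up front and time one loop over all of them. `time_fresh_queries` is a hypothetical helper, not part of BenchmarkTools. Note that it measures throughput over independent searches, which may come out lower than single-call latency because the CPU can overlap consecutive searches; for drawing many samples, throughput is arguably the number that matters.

```julia
using Random

# Hypothetical helper (not part of BenchmarkTools): average time of
# searchsortedfirst over n random queries. The queries are drawn before the
# clock starts, so rand() is not timed, and each query is searched only once,
# so the branch predictor cannot learn a single repeated search path.
function time_fresh_queries(cdf::AbstractVector, n::Integer;
                            rng::AbstractRNG = Random.default_rng())
    qs = rand(rng, n)
    acc = 0                      # checksum, keeps the searches from being optimized away
    t0 = time_ns()
    for q in qs
        acc += searchsortedfirst(cdf, q)
    end
    t1 = time_ns()
    return (t1 - t0) / n, acc    # (nanoseconds per search, checksum)
end

p = rand(10^5)
cdf = cumsum(p ./ sum(p))
ns_per_search, _ = time_fresh_queries(cdf, 10^6)
```

As with any hand-rolled timing, it is worth running it a few times and looking at the minimum.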