Counting the number of times an RNG is called

Hi,

I’m trying to measure the number of random numbers generated by a Monte Carlo estimator for a given level of precision. Is there an easy way I could do this?

I would like to do something like:

rng = MersenneTwister(seed)
estimator(rng) # generates many random numbers
rng.count # returns how many random numbers were generated since rng was created

Best,
Michael

Unfortunately, I found nothing in the manual.

If you know which methods your estimator uses, and it does not restrict the type of its argument, I suppose you could create a new type wrapping MersenneTwister and define methods for it that just pass the arguments to MersenneTwister and increment a global (or a captured reference) counter by the number of values generated. However, this is a terrible solution: you can forget some method used for generating random numbers and undercount, it depends on many assumptions, it takes considerable effort, etc.
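To make the "captured reference" variant concrete, here is a minimal sketch. The names counting_sampler and draw are made up for illustration, and it only covers the rand(rng, n) form, so it assumes the estimator draws all of its numbers through the returned closure.

```julia
using Random

# Hypothetical helper: returns a drawing closure plus the counter it updates.
function counting_sampler(rng::AbstractRNG)
    counter = Ref(0)                 # captured reference, shared with the closure
    function draw(n::Integer)
        counter[] += n               # count one per value requested
        return rand(rng, n)
    end
    return draw, counter
end

draw, counter = counting_sampler(MersenneTwister(1))
draw(3)
draw(4)
counter[]   # 7
```

Anything that calls rand on the underlying rng directly bypasses the counter, which is exactly the weakness described above.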

The other, almost as bad, solution is to look at the MersenneTwister source and check whether one of the fields (probably idxF, or idxI, or some combination of them) gives what you want.

There is no direct way to do that, and more information is needed to find a solution. For example, if you call only the scalar rand for Float64, you could overwrite the Random.set_mt_setfull!(r::MersenneTwister) method (which is called only when refilling the cache of Float64 in an MT object) so that it also updates a global counter, and at the end of your simulation multiply this counter by the size of the cache (adjusting with r.idxF).
But if you also use a rand method which produces arrays, it is quite difficult to do what you want. You may need to overwrite the rand methods themselves (to update a counter), or create a wrapper of MersenneTwister which has a counter as a field (the cleanest solution).

This is the best solution IMO.

This would also allow making various details about the counting more precise, e.g. does rand(Float64, 10) count as 1 or 10 calls?
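As a rough illustration of that choice, a wrapper can track both numbers and let the caller pick. This is only a sketch: DualCounter and its field names are hypothetical, and only the rand(c, Float64, n) form is forwarded.

```julia
using Random

# Hypothetical wrapper keeping two tallies: calls made and values produced.
mutable struct DualCounter{T<:AbstractRNG}
    calls::Int
    values::Int
    rng::T
end

function Base.rand(c::DualCounter, ::Type{Float64}, n::Integer)
    c.calls += 1      # one call, however many values it returns
    c.values += n     # n values generated by this call
    return rand(c.rng, Float64, n)
end

c = DualCounter(0, 0, MersenneTwister(1))
rand(c, Float64, 10)
rand(c, Float64, 5)
(c.calls, c.values)   # (2, 15)
```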

Thanks so much everyone for your suggestions. I would count rand(Float64, 10) as 10 calls in my case. I am passing rng to various samplers in the Distributions.jl package and I would ideally not want to look through the Distributions.jl source code to see how many times rng is used or what rand methods are used.

What would be the skeleton code for creating a wrapper of MersenneTwister so that it works with any methods that currently take a MersenneTwister rng?

For example, something along the lines of

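using Random

# Wrapper that counts the random numbers requested through it.
mutable struct TweakedMersenneTwister
    count::Int
    rng::MersenneTwister
end

# Forward rand(t, i) to the wrapped RNG, adding i to the counter first.
function Base.rand(t::TweakedMersenneTwister, i::Integer)
    t.count += i
    return rand(t.rng, i)
end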

which gives

julia> rng = TweakedMersenneTwister(0, MersenneTwister());

julia> rng.count
0

julia> rand(rng, 3)
3-element Array{Float64,1}:
 0.6887358336000433
 0.2909818216770961
 0.23089019888929863

julia> rand(rng, 4)
4-element Array{Float64,1}:
 0.8570114504250239
 0.9896625712375964
 0.0658943236725813
 0.33129437214959623

julia> rng.count
7
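To tie this back to the original question, here is a sketch of the wrapper inside a toy Monte Carlo estimator. The function estimate_pi is a made-up example, not anything from Distributions.jl, and it only works because it does not restrict the type of its rng argument and only ever calls rand(rng, 2).

```julia
using Random

mutable struct TweakedMersenneTwister
    count::Int
    rng::MersenneTwister
end

function Base.rand(t::TweakedMersenneTwister, i::Integer)
    t.count += i
    return rand(t.rng, i)
end

# Hypothetical estimator: fraction of uniform points inside the unit quarter circle.
function estimate_pi(rng, n)
    hits = 0
    for _ in 1:n
        x, y = rand(rng, 2)          # two uniforms per sample point
        hits += (x^2 + y^2 <= 1)
    end
    return 4 * hits / n
end

rng = TweakedMersenneTwister(0, MersenneTwister(42))
estimate = estimate_pi(rng, 1000)
rng.count   # 2000: two random numbers for each of the 1000 points
```

An estimator that declares its argument as rng::AbstractRNG would reject this wrapper, since TweakedMersenneTwister is not a subtype of AbstractRNG.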

Also, unless there is a specific reason to restrict to MersenneTwister, something like

mutable struct CountingRNG{T <: AbstractRNG}
    count::Int
    rng::T
end

should be more generic.
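A short sketch of what the generic version buys you, assuming the same single forwarding method as in the TweakedMersenneTwister example (so only rand(c, n) is counted):

```julia
using Random

mutable struct CountingRNG{T <: AbstractRNG}
    count::Int
    rng::T
end

# Forward only rand(c, n); other rand methods are deliberately not covered here.
function Base.rand(c::CountingRNG, n::Integer)
    c.count += n
    return rand(c.rng, n)
end

# The same wrapper now works with any AbstractRNG.
mt = CountingRNG(0, MersenneTwister(1))
rand(mt, 3)
rand(mt, 4)
mt.count    # 7

dev = CountingRNG(0, RandomDevice())
rand(dev, 2)
dev.count   # 2
```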

For fun, here’s a version that uses Cassette to count the number of calls to a particular method of rand

using Cassette, Random
import Cassette: overdub, recurse, @overdub
Cassette.@context CountCtx;

ctx = CountCtx(metadata=Ref(0))
const myrng = Random.GLOBAL_RNG # Change this for your RNG

"This method will be called at the very bottom"
function overdub(ctx::CountCtx, f::typeof(rand), rng::AbstractRNG, i::Random.UInt52)
    rng == myrng && (ctx.metadata[] += 1)
    recurse(ctx, f, rng, i)
end

julia> overdub(ctx, randn, myrng)

julia> ctx.metadata[]
1

julia> overdub(ctx, randn, myrng, 5)

julia> ctx.metadata[]
6

This method just counts the number of times rand is invoked with a UInt52 as the second argument, since this is the method that is called at the very bottom to produce a random number.

This can be used in a nicer way like this

function testfun(N)
    for n = 1:N
        randn(myrng, n)
    end
end

julia> @overdub ctx testfun(5)

julia> ctx.metadata[]
21
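Note that the counter stored in ctx is never reset between the calls above, so the 21 is cumulative. A quick check of the arithmetic, assuming (as the earlier outputs suggest) that each normal draw hit the counted rand method exactly once:

```julia
first_call  = 1          # overdub(ctx, randn, myrng)
second_call = 5          # overdub(ctx, randn, myrng, 5)
testfun_5   = sum(1:5)   # randn(myrng, n) for n = 1:5 draws 1+2+3+4+5 values

total = first_call + second_call + testfun_5   # 1 + 5 + 15 = 21
```

To count a single run in isolation you would set ctx.metadata[] = 0 before it.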

For the MT, yes; in general, it need not be.

Here is a robust solution:

using Random
import Random: Sampler, Repetition, rand, rand!, CloseOpen01, SamplerSimple

mutable struct CountingRNG{T<:AbstractRNG} <: AbstractRNG
    count::Int
    rng::T
end

Sampler(::Type{CountingRNG{T}}, X, n::Repetition) where {T} = Sampler(T, X, n)
# disambiguate
Sampler(::Type{CountingRNG{T}}, ::Type{X}, n::Repetition) where {T,X} = Sampler(T, X, n)
Sampler(::Type{CountingRNG{T}}, ::Type{X}, n::Repetition) where {T,X<:AbstractFloat} = Sampler(T, X, n)

# intercept Float64 generation
Sampler(::Type{CountingRNG{T}}, ::Type{Float64}, n::Repetition) where {T} =
    SamplerSimple(CloseOpen01(Float64), Sampler(T, Float64, n))

rand(rng::CountingRNG, sp::Sampler) = rand(rng.rng, sp)
# this is to benefit from possible optimizations implemented for rng.rng,
# instead of using the default rand! (simple loop)
rand!(rng::CountingRNG, A::AbstractArray, sp::Sampler) = rand!(rng.rng, A, sp)

function rand(rng::CountingRNG, sp::SamplerSimple{CloseOpen01{Float64}})
    rng.count += 1
    rand(rng.rng, sp.data)
end

function rand!(rng::CountingRNG, A::AbstractArray{Float64}, sp::SamplerSimple{CloseOpen01{Float64}})
    rng.count += length(A)
    rand!(rng.rng, A, sp.data)
end

I didn’t test much, so there may be more method ambiguities to resolve.

Nice. I am wondering if we could make this simpler when this part of the API is finalized (my understanding is that it is unofficial at the moment).

I don’t see what is unofficial at the moment (except maybe that typeof(SamplerSimple(x, data)) <: SamplerSimple{typeof(x)}), but I agree that reading the docs once may not be enough to come up with this solution.
What could be simplified is deleting the rand! methods, but you may lose a bit of performance.

What annoys me are the ambiguities, which force one to write redundant methods here. For the next major version of Random (which can be breaking), let’s see how this can be improved (also, check whether CloseOpen01 can be made unofficial, in favor of Float64, to simplify things for the user).

I meant

The API is not clearly defined yet

in the docs.

Ah OK, so the docs are maybe a bit defensive to stay on the safe side. What they mean basically is that there is no official API allowing you to define rand on your custom RNG for only one type or two (e.g. UInt64) and get the definition for other types for free. But what is in my post above is officially supported (Sampler and rand are defined for all the types), and the docs should be updated to reflect that.

Thanks so much for the code! I tried @rfourquet’s robust version, but there is still an ambiguity when running randn(rng):

ERROR: MethodError: rand(::CountingRNG{MersenneTwister}, ::Random.SamplerTrivial{Random.UInt52{UInt64},UInt64}) is ambiguous. Candidates:
rand(r::AbstractRNG, ::Random.SamplerTrivial{Random.UInt52{UInt64},E} where E) in Random at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.1/Random/src/generation.jl:119
rand(r::AbstractRNG, sp::Random.SamplerTrivial{#s623,E} where E where #s623<:Random.UniformBits{T}) where T in Random at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.1/Random/src/generation.jl:122
rand(rng::CountingRNG, sp::Sampler) in Main at /Users/michaelfairley/Git-Projects/EVSI/src/counting_rng.jl:18
Possible fix, define
rand(::CountingRNG, ::Random.SamplerTrivial{Random.UInt52{UInt64},E} where E)
Stacktrace:
[1] rand(::CountingRNG{MersenneTwister}, ::Random.UInt52{UInt64}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.1/Random/src/Random.jl:219
[2] randn at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.1/Random/src/normal.jl:38 [inlined]
[3] randn(::CountingRNG{MersenneTwister}, ::Type{Float64}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.1/Random/src/normal.jl:165
[4] randn! at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.1/Random/src/normal.jl:171 [inlined]
[5] randn at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.1/Random/src/normal.jl:182 [inlined]
[6] randn(::CountingRNG{MersenneTwister}, ::Int64) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.1/Random/src/normal.jl:186
[7] top-level scope at none:0

Did you remember to initialize rng?

julia> rng = CountingRNG(0, MersenneTwister(1234));

julia> rand(rng)
0.5908446386657102

julia> rand(rng)
0.7667970365022592

julia> rng.count
2

Yes, the error is with randn (not rand). Here’s all the input from a fresh REPL.

julia> using Random

julia> import Random: Sampler, Repetition, rand, rand!, CloseOpen01, SamplerSimple

julia> mutable struct CountingRNG{T<:AbstractRNG} <: AbstractRNG
           count::Int
           rng::T
       end

julia> Sampler(::Type{CountingRNG{T}}, X, n::Repetition) where {T} = Sampler(T, X, n)
Sampler

julia> # disambiguate
       Sampler(::Type{CountingRNG{T}}, ::Type{X}, n::Repetition) where {T,X} = Sampler(T, X, n)
Sampler

julia> Sampler(::Type{CountingRNG{T}}, ::Type{X}, n::Repetition) where {T,X<:AbstractFloat} = Sampler(T, X, n)
Sampler

julia> # intercept Float64 generation
       Sampler(::Type{CountingRNG{T}}, ::Type{Float64}, n::Repetition) where {T} =
           SamplerSimple(CloseOpen01(Float64), Sampler(T, Float64, n))
Sampler

julia> rand(rng::CountingRNG, sp::Sampler) = rand(rng.rng, sp)
rand (generic function with 62 methods)

julia> # this is to benefit from possible optimizations implemented for rng.rng,
       # instead of using the default rand! (simple loop)
       rand!(rng::CountingRNG, A::AbstractArray, sp::Sampler) = rand!(rng.rng, A, sp)
rand! (generic function with 48 methods)

julia> function rand(rng::CountingRNG, sp::SamplerSimple{CloseOpen01{Float64}})
           rng.count += 1
           rand(rng.rng, sp.data)
       end
rand (generic function with 63 methods)

julia> function rand!(rng::CountingRNG, A::AbstractArray{Float64}, sp::SamplerSimple{CloseOpen01{Float64}})
           rng.count += length(A)
           rand!(rng.rng, A, sp.data)
       end
rand! (generic function with 49 methods)

julia> rng = CountingRNG(0, MersenneTwister(1234));

julia> randn(rng)
ERROR: MethodError: rand(::CountingRNG{MersenneTwister}, ::Random.SamplerTrivial{Random.UInt52{UInt64},UInt64}) is ambiguous. Candidates:
  rand(r::AbstractRNG, ::Random.SamplerTrivial{Random.UInt52{UInt64},E} where E) in Random at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.1/Random/src/generation.jl:119
  rand(r::AbstractRNG, sp::Random.SamplerTrivial{#s623,E} where E where #s623<:Random.UniformBits{T}) where T in Random at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.1/Random/src/generation.jl:122
  rand(rng::CountingRNG, sp::Sampler) in Main at REPL[8]:1
Possible fix, define
  rand(::CountingRNG, ::Random.SamplerTrivial{Random.UInt52{UInt64},E} where E)
Stacktrace:
 [1] rand(::CountingRNG{MersenneTwister}, ::Random.UInt52{UInt64}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.1/Random/src/Random.jl:219
 [2] randn(::CountingRNG{MersenneTwister}) at /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.1/Random/src/normal.jl:38
 [3] top-level scope at none:0

julia>

You can define the method suggested in the error message:

function rand(rng::CountingRNG, sp::Random.SamplerTrivial{Random.UInt52{UInt64}})
    # rng.count += 1
    rand(rng.rng, sp)
end

You have to decide what exactly you want to count: should randn(rng) count as 1, or as whatever number of calls to rand(rng) this call consumes? Uncomment the line rng.count += 1 above for the latter (rand(UInt52()) doesn’t actually consume a Float64, but it still consumes the same number of bits as a Float64).
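To see why the two conventions can differ, here is a self-contained sketch using a rejection-based normal sampler. This is the Marsaglia polar method, chosen only for illustration; it is not how Random implements randn, and UniformCounter and polar_normal are hypothetical names.

```julia
using Random

# Hypothetical wrapper that counts scalar uniform draws only.
mutable struct UniformCounter{T<:AbstractRNG}
    count::Int
    rng::T
end

function Base.rand(c::UniformCounter)
    c.count += 1
    return rand(c.rng)
end

# Marsaglia polar method: draws pairs of uniforms until one is accepted,
# so the number of uniforms consumed per normal sample varies.
function polar_normal(c::UniformCounter)
    while true
        u = 2 * rand(c) - 1
        v = 2 * rand(c) - 1
        s = u^2 + v^2
        if 0 < s < 1
            return u * sqrt(-2 * log(s) / s)
        end
    end
end

c = UniformCounter(0, MersenneTwister(1234))
samples = [polar_normal(c) for _ in 1:1000]

length(samples)   # 1000 if each normal sample counts as one
c.count           # at least 2000 if every underlying uniform is counted
```

Counting at the level of the sampler gives exactly 1000 here, while counting the underlying uniforms gives a larger number that depends on how many pairs were rejected.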

Thanks! I count calls to rand(rng) that randn(rng) consumes. I don’t think I fully understand how the Random API works but it seems like I need to redefine all the Sampler and rand methods to work with CountingRNG. Is this approach robust in that if I forget to define a method then the call will produce an error and I won’t end up silently not counting some random number generation?

If you are only ever going to use the MersenneTwister, why not consider the Cassette approach? It did not require you to implement much, only find the lowest method that generates an actual random number and define a contextual dispatch for this. Cassette itself does not have any dependencies so it’s a rather lightweight solution.
