Best way to change the RNG used in inner functions?

Given code in which rand is called in an inner function, as in this simplified example:

module Test

import Random
export outer

RNG = Random.Xoshiro()

function inner(x)
    x += rand(RNG)
end

function outer(x)
    x += inner(x)
end

end

Which do you think is the most idiomatic way to allow RNG choice without causing type instabilities?

One possibility, of course, is to pass the RNG as a parameter on every interface. Do you see a better alternative? (Without declaring RNG as const, the call to outer is obviously type-unstable.)

The purpose of this is to be able to use StableRNGs to test some functions with reproducible random states.

I have used an rng keyword argument for exactly this purpose:

module Test

import Random
export outer

RNG = Random.Xoshiro()

function inner(x; rng=RNG)
    x += rand(rng)
end

function outer(x; rng=RNG)
    inner(x; rng)
end

end

which I guess is what you mean by “pass RNG as a parameter on every interface”?
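For testing, the keyword version can be driven with a seeded generator. Here is a minimal sketch, assuming the default generator is declared const (so the default argument is type-stable) and using a seeded Random.Xoshiro as a stand-in for a StableRNG, which would be passed in the same position. The name DEFAULT_RNG is only for illustration.

```julia
import Random

# Default generator; `const` keeps the default keyword value type-stable.
const DEFAULT_RNG = Random.Xoshiro()

# The generator is threaded through every call as a keyword argument.
inner(x; rng=DEFAULT_RNG) = x + rand(rng)
outer(x; rng=DEFAULT_RNG) = x + inner(x; rng)

# Two generators seeded identically produce identical results.
a = outer(1.0; rng=Random.Xoshiro(42))
b = outer(1.0; rng=Random.Xoshiro(42))
println(a == b)  # true
```

Because the generator is an ordinary argument, each method is compiled for the concrete generator type it receives.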

I did this more as a matter of avoiding global state than a worry about type stability. I’m surprised, though, that having a non-const RNG would cause type instability. Doesn’t rand have a known output type? What if you wrote your own rand wrapper; something like

RNG = Random.Xoshiro()

function set_rng(rng)
    global RNG
    RNG = rng
end

function myrand()::Float64
    global RNG
    return rand(RNG)
end
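If the goal is only reproducible tests rather than swapping the generator type, another option might be to keep the global const and reset its state in place with Random.seed!. This is a sketch under that assumption; reseed! is a hypothetical helper name, and this does not allow plugging in a StableRNG, so the stream for a given seed is not guaranteed to stay the same across Julia versions.

```julia
import Random

# A const global of concrete type: reading it is type-stable.
const RNG = Random.Xoshiro()

myrand() = rand(RNG)

# Hypothetical helper: reset the shared generator's state in place.
reseed!(seed) = Random.seed!(RNG, seed)

reseed!(1234); a = myrand()
reseed!(1234); b = myrand()
println(a == b)  # true
```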

Indeed, that is what I meant.

rand does have a known output type, and you can even specify it:

julia> using Random, BenchmarkTools

julia> rand(Random.Xoshiro(), Int)
5009175639301388561

But neither that nor annotating the output solves the problem (the annotation only prevents the instability from propagating):

julia> RNG = Random.Xoshiro()
       function test(x::T) where {T}
           x + rand(RNG, T)::T
       end
test (generic function with 1 method)

julia> @btime test(0.0)
  116.872 ns (1 allocation: 16 bytes)
0.2989532433595429

julia> function test2(x::T, RNG) where {T}
           x + rand(RNG, T)
       end
test2 (generic function with 1 method)

julia> @btime test2(0.0, $RNG)
  3.010 ns (0 allocations: 0 bytes)
0.6171378061283105

I can live with passing RNG on every interface, but this is a case where I would be fine with a global state.
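A possible middle ground is a function barrier: keep the mutable global, but read it only once in the public entry point and pass it down as an argument. The sketch below assumes that a single dynamic dispatch per call to outer is acceptable; CURRENT_RNG and set_rng! are hypothetical names, with set_rng! playing the role of set_rng above.

```julia
import Random

# Non-const global: its type may change, so reading it is type-unstable.
CURRENT_RNG = Random.Xoshiro()

# Hypothetical setter, analogous to set_rng above.
function set_rng!(rng::Random.AbstractRNG)
    global CURRENT_RNG = rng
    return rng
end

# Kernels receive the generator as an argument and are type-stable.
inner(x, rng) = x + rand(rng)
outer(x, rng) = x + inner(x, rng)

# Public entry point: one dynamic dispatch here, specialized code below it.
outer(x) = outer(x, CURRENT_RNG)

set_rng!(Random.MersenneTwister(1)); a = outer(0.0)
set_rng!(Random.MersenneTwister(1)); b = outer(0.0)
println(a == b)  # true
```

The instability is then confined to the one-argument method of outer; everything called from the two-argument methods is compiled for the concrete generator type.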