Does Julia have any utility to "split" RNGs for later use in parallel?

I want to generate some random numbers across a series of parallel processes, with a single input RNG provided by the user. Schematically I want something like:

using Random, Distributed

rng = Xoshiro(1) # provided by user

# generate 100 new unique RNGs deterministically from the
# input rng, using a hypothetical `split_rng` function
rngs = split_rng(rng, 100)

# run results in parallel
results = pmap(rngs) do rng
    x = rand(rng)
    # ...
end

This is pretty much how it works in JAX, which is where my inspiration comes from, but maybe there’s a different intended way to do this in Julia? Or is there already something like split_rng?

I suppose one possible implementation of split_rng is something like:

split_rng(rng, N) = Xoshiro.(rand(rng, UInt32, N))

but I don’t know much about random number generation, or how “principled” this is.
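On the “principled” question, one concrete weakness of seeding from `UInt32` draws is the size of the seed space. With only 2^32 possible seeds, the birthday bound gives roughly even odds that two of about 77,000 children share a seed and therefore produce identical streams. A minimal variant, assuming 64-bit integer seeds are accepted by the `Xoshiro` constructor (they are in the Julia versions I know of), pushes that to about 5 billion children. `split_rng` is still the hypothetical helper from the question, not something in Random.

```julia
using Random

# Hypothetical helper (not part of Random): one 64-bit seed per child.
# As far as I know, recent Julia versions hash integer seeds (SHA-256) before
# filling Xoshiro's state, so children seeded from one parent stream still
# get unrelated-looking states.
split_rng(rng::AbstractRNG, N::Integer) = [Xoshiro(rand(rng, UInt64)) for _ in 1:N]

# Same parent seed, same children: the split is deterministic.
rngs = split_rng(Xoshiro(1), 4)
xs = [rand(r) for r in rngs]
```

The resulting vector can be handed to `pmap` exactly as in the schematic above.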

The closest equivalent I’ve seen to what JAX does is the virtual Brownian tree interface in SciML/DiffEqNoiseProcess.jl (virtual_brownian_tree_interface.jl at commit c48cdce099cece1edbd8f99da960bc67e3c2c4ca on GitHub), which I learned of through the Lux.jl issue “Taking PRNGs seriously” (avik-pal/Lux.jl #9). I’m not sure whether a similar approach could be applied to the Base RNG types, but Random123 at least has implementations of the Threefry and Philox RNGs that JAX uses. The default CUDA.jl RNG is also a Philox variant, so it may well be splittable too.
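To make “JAX-style splitting” concrete, here is a minimal sketch of a key tree. JAX derives child keys by hashing the parent key with Threefry; this sketch swaps in the much simpler SplitMix64 finalizer as the mixing function, so it shows the structure only, not JAX’s actual algorithm or its statistical guarantees. `mix64`, `split_key` and `key_rng` are hypothetical names, not a package API.

```julia
using Random

# SplitMix64 finalizer, standing in for Threefry purely for illustration.
# It is a bijection on UInt64 (xorshifts and odd multipliers are invertible).
function mix64(z::UInt64)
    z = (z ⊻ (z >> 30)) * 0xbf58476d1ce4e5b9
    z = (z ⊻ (z >> 27)) * 0x94d049bb133111eb
    return z ⊻ (z >> 31)
end

# Child i depends only on (key, i), so any child can be split again,
# giving a tree of keys rather than one shared sequential stream.
split_key(key::UInt64, n::Integer) = [mix64(key ⊻ mix64(UInt64(i))) for i in 1:n]

# Turn a (leaf) key into an ordinary stream for actual sampling.
key_rng(key::UInt64) = Xoshiro(key)

root = mix64(UInt64(1))            # loosely analogous to jax.random.PRNGKey(1)
k1, k2 = split_key(root, 2)
grandchildren = split_key(k1, 3)
x = rand(key_rng(k2))
```

Because `mix64` is a bijection, the children of any single key are guaranteed distinct. As in JAX, a key that has been split should not also be used directly for sampling.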

1 Like

The C++ implementation of Xoshiro has a jump method specifically for generating non-overlapping subsequences for parallel computations. I’m not sure offhand whether the Julia implementation has something similar.
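The jump idea can be sketched in Julia by re-implementing the xoshiro256 state transition and the `jump()` routine from Blackman and Vigna’s reference code, which advances a state by 2^128 steps. This assumes (a) that Julia’s `Xoshiro` is xoshiro256++, which shares its linear engine and jump polynomial with the reference code, and (b) the undocumented four-word `Xoshiro(s0, s1, s2, s3)` constructor, the same one the forkRand snippet in this thread relies on. `step256`, `jump256` and `split_by_jump` are hypothetical names.

```julia
using Random

# Jump polynomial for xoshiro256 (2^128 steps), as in the reference C `jump()`.
const JUMP = (0x180ec6d33cfd0aba, 0xd5a61266f0c9392c,
              0xa9582618e03fc9aa, 0x39abdc4529b1661c)

# One step of the xoshiro256 linear engine (shared by the ++ and ** variants).
function step256(s::NTuple{4,UInt64})
    s0, s1, s2, s3 = s
    t = s1 << 17
    s2 ⊻= s0
    s3 ⊻= s1
    s1 ⊻= s2
    s0 ⊻= s3
    s2 ⊻= t
    s3 = bitrotate(s3, 45)
    return (s0, s1, s2, s3)
end

# Port of the reference `jump()`: returns the state 2^128 steps after `s`.
function jump256(s::NTuple{4,UInt64})
    acc = ntuple(_ -> zero(UInt64), 4)
    for word in JUMP, b in 0:63
        if (word >> b) & 1 == 1
            acc = acc .⊻ s
        end
        s = step256(s)
    end
    return acc
end

# Hypothetical split: child k starts k jumps (k * 2^128 steps) after `state`.
# Relies on the undocumented 4-word Xoshiro constructor.
function split_by_jump(state::NTuple{4,UInt64}, N::Integer)
    rngs = Xoshiro[]
    for _ in 1:N
        push!(rngs, Xoshiro(state...))
        state = jump256(state)
    end
    return rngs
end

# Any non-zero 256-bit state works; in practice it could come from rand(rng, UInt64).
state = (0x0123456789abcdef, 0xfedcba9876543210,
         0x0f1e2d3c4b5a6978, 0x8796a5b4c3d2e1f0)
rngs = split_by_jump(state, 4)
```

Each child starts 2^128 draws after the previous one, so the streams cannot overlap unless a single child draws more than 2^128 numbers. A jump costs 256 engine steps, which is negligible next to any real workload.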


Thanks, I looked at the Julia source and didn’t find “jump” per se, but I did stumble on forkRand, which seems to do something fairly simple and which I suppose could be used to implement my split_rng as:

using Random

# Relies on the internal, undocumented Random.XoshiroSimd.forkRand
function split_rng(rng::Xoshiro, N)
    map(Random.XoshiroSimd.forkRand(rng, Val(N))...) do si...
        Xoshiro((s.value for s in si)...)
    end
end
And thanks for the links, @ToucheSir. I couldn’t quickly get an implementation of exactly what I wanted out of them, but if those packages have the same RNGs JAX uses, then it should surely be possible to split them the same way.

Anyway, I’m taking all this to mean that there’s no unified API for splitting (which would be nice to have!), but that it’s possible to hack together something specific as needed.
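One related built-in, for what it’s worth: since Julia 1.7 each task has its own RNG state, forked deterministically from its parent’s when the task is created, so re-seeding the parent and spawning the same tasks in the same order reproduces every child’s draws regardless of thread scheduling. This only covers the task-local default RNG, not a user-supplied `rng` as in the question; `spawn_draws` is a hypothetical helper.

```julia
using Random

# Hypothetical helper: seed the current task's RNG, then draw once in each of
# `ntasks` spawned tasks. Each child's RNG is forked from the parent's at task
# creation, so the results do not depend on which thread runs which task.
function spawn_draws(seed::Integer, ntasks::Integer)
    Random.seed!(seed)                       # seeds the current task's RNG
    tasks = [Threads.@spawn rand() for _ in 1:ntasks]
    return fetch.(tasks)
end

a = spawn_draws(1, 8)
b = spawn_draws(1, 8)
```

Note that `Random.seed!` here resets the calling task’s global RNG state as a side effect.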

1 Like

Random123 definitely has a method for what you want. I believe it’s called set_counter!. You have to decide how far to advance the counter for each split, but otherwise it is straightforward.
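The counter idea can be shown without any package. In a counter-based generator the n-th output is a pure function of (key, n), so splitting amounts to giving worker i its own block of counters starting at i × stride, which is what the counter-setting method mentioned above does for Random123’s generators. This minimal sketch swaps in the SplitMix64 finalizer for Philox/Threefry; `CounterRNG`, `next_u64!` and `block_rngs` are hypothetical and are not Random123’s API.

```julia
# SplitMix64 finalizer, standing in for a real counter-based cipher.
function mix64(z::UInt64)
    z = (z ⊻ (z >> 30)) * 0xbf58476d1ce4e5b9
    z = (z ⊻ (z >> 27)) * 0x94d049bb133111eb
    return z ⊻ (z >> 31)
end

# Hypothetical counter-based generator: each output is f(key, counter).
mutable struct CounterRNG
    key::UInt64
    counter::UInt64
end

function next_u64!(r::CounterRNG)
    x = mix64(r.key ⊻ mix64(r.counter))
    r.counter += 1
    return x
end

# Worker i (0-based) starts at counter i * stride; the streams cannot overlap
# as long as no worker draws more than `stride` numbers.
block_rngs(key::UInt64, nworkers::Integer, stride::UInt64) =
    [CounterRNG(key, UInt64(i) * stride) for i in 0:nworkers-1]

key = 0x9e3779b97f4a7c15
whole = CounterRNG(key, 0)
reference = [next_u64!(whole) for _ in 1:300]
workers = block_rngs(key, 3, UInt64(100))
second = [next_u64!(workers[2]) for _ in 1:5]   # same as reference[101:105]
```

Because each number is computed directly from its counter, a worker can start at any position in constant time instead of stepping through the sequence.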