Linear relationship between Xoshiro tasks

I’ve spent some time working on and thinking about this. First, whatever we do, we really want some analogue of the dotmix / splitmix proof to work. Yes, it’s not sufficient, but it is necessary, and it’s something we can actually prove, whereas proving linear independence of arbitrarily many task states seems quite challenging. I believe I’ve managed to whittle down the requirements on a “dot product” accumulation function to the bare minimum.

One of the major problems with the dotmix proof (see p.5 section 3 here) is that it seems, on the face of it, to require associativity in order to make the algebraic tricks work. But when it comes to RNGs, a sequence of associative operations is pretty suspect. Fortunately, I believe you can generalize the dotmix construction to any accumulator function satisfying some very general requirements, which I’ll give next.

We’ll use the same domain for our state and our weights, call it S. The state here is an accumulator for the “dotmix” construction, but in our application, it also happens to be the state of a single 64-bit register of the main Xoshiro256 RNG. We want to apply an accumulation function to the parent state (which doesn’t change) to combine it with a pseudo-random weight, and get a new accumulator state (which is also used as part of the main RNG state). In our simplified binary dotmix/splitmix, the accumulator function is just +, which makes the state a dot product. As @foobar_lv2 has pointed out, this causes problems with correlated states when the result is also used as Xoshiro state. I’m going to show that we can change the accumulator to be something non-commutative and non-associative.

Let’s write u: S \times S \to S for the update function, with the state as first argument and the weight as second: u(s, w). The requirement is that u must be a bijection in each argument when the other argument is held fixed. Written out:

  • For all w \in S: s \mapsto u(s, w) is a bijection on S
  • For all s \in S: w \mapsto u(s, w) is a bijection on S.
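Together, the two conditions say that the table of u(s, w) is a Latin square: every row and every column is a permutation of S. For a reduced stand-in S = UInt8 they can be checked exhaustively. A minimal sketch; the helper name is hypothetical and UInt8 is only a small model of the real 64-bit state:

```julia
# Exhaustively check both bijection conditions for u on S = UInt8.
# `satisfies_bijection_conditions` is a hypothetical helper, not an existing API.
function satisfies_bijection_conditions(u)
    vals = 0x00:0xff
    # Table of u(s, w): rows indexed by s, columns by w.
    F = [u(s, w) for s in vals, w in vals]
    # Outputs must stay inside S = UInt8, otherwise "bijection on S" is meaningless.
    eltype(F) == UInt8 || return false
    # Fixed w (a column): s ↦ u(s, w) must hit each of the 256 values exactly once.
    # Fixed s (a row):    w ↦ u(s, w) must hit each of the 256 values exactly once.
    return all(allunique, eachcol(F)) && all(allunique, eachrow(F))
end

satisfies_bijection_conditions(+)                           # true
satisfies_bijection_conditions(xor)                         # true
satisfies_bijection_conditions(*)                           # false: w = 0 maps every s to 0
satisfies_bijection_conditions((s, w) -> 0x02*s*w + s + w)  # true
satisfies_bijection_conditions((s, w) -> 2s*w + s + w)      # false: results are promoted to Int
```

The last line is a Julia promotion pitfall rather than a flaw in the function: a plain literal 2 is an Int, so the arithmetic is done in Int, never wraps mod 2^8, and the eltype check rejects it.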

I’ll write the “pedigree” or task coordinates of a task as an infinite binary sequence where all but finitely many coordinates are zero, written as t = (t_1, t_2, ...) \in 2^\mathbb{N}; denote the subset of 2^\mathbb{N} with finitely many ones as T. The root task is (0, 0, 0, ...), its first child is (1, 0, 0, ...), its second child is (0, 1, 0, ...), and that task’s first child is (0, 1, 1, ...). A child’s coordinates differ from its parent’s in only one place: the coordinate that indicates which child of the parent it is. The number of coordinates two tasks differ by is how far apart they are in the task tree (as a graph).

In order to define the “compression function” we need a bit more notation. We have pseudo-random weights w = (w_1, w_2, ...) \in S^\mathbb{N}, which are common across tasks. We also have a function s_0 : 2^\mathbb{N} \to S assigning an initial state to each task. Perhaps surprisingly, s_0 can be completely arbitrary. Define C: T \to S as:

  • C_0(t) = s_0(t)
  • C_i(t) = \mathrm{ifelse}(t_i = 0, C_{i-1}(t), u(C_{i-1}(t), w_i))
  • C(t) = \lim_{i \to \infty} C_i(t)

This gives the state assigned to each task. In words:

  • Each task starts with its own arbitrary state value
  • For each i whether we apply the update function or not depends on t_i:
    • if t_i = 0 then we leave the state alone
    • if t_i = 1 then we apply the update function with w_i as second argument
  • Since tasks only have finitely many non-zero coordinates, this becomes constant after the last non-zero coordinate, so the limit is well-defined.
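The definition of C translates almost line for line into code. A minimal sketch, assuming a task is given by the finite prefix of its coordinates as a Vector{Bool} (everything past the end is zero) and the weights as a vector at least that long; `compress` is a hypothetical name:

```julia
# C(t) for a task given by a finite prefix of its coordinates; coordinates past
# the end of `t` are zero, so the limit is reached after length(t) steps.
# `compress` is a hypothetical name, not an existing API.
function compress(u, s0, t::AbstractVector{Bool}, w::AbstractVector)
    length(w) >= length(t) || throw(ArgumentError("need a weight w_i for every coordinate t_i"))
    s = s0                      # C_0(t) = s_0(t)
    for i in 1:length(t)
        if t[i]
            s = u(s, w[i])      # t_i = 1: apply the update with weight w_i
        end                     # t_i = 0: leave the state alone
    end
    return s                    # C(t)
end

# With u = +, C(t) is s_0 plus the dot product of t and w (wrapping mod 2^64).
w = UInt64[0x1111, 0x2222, 0x4444]
t = [true, false, true]
compress(+, UInt64(7), t, w) == UInt64(7) + sum(t .* w)   # true
```

The usage line shows why the simplified binary dotmix is a dot product: with + as the accumulator, C(t) = s_0 + \sum_i t_i w_i.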

Ok, we’re done with notation stuff. Now to the meat. We want to show that for two distinct tasks t \ne t', the chance of C(t) = C(t') is 1/|S|. Let i be the last coordinate where t and t' differ. For every j > i we have t_j = t'_j, so after step i both states go through the same sequence of maps s \mapsto u(s, w_j). Each of these is a bijection, so u(s, w) = u(s', w) implies s = s', and therefore C(t) = C(t') implies C_i(t) = C_i(t'). That leaves us with C_i(t) = C_i(t') and t_i \ne t'_i. Without loss of generality, let t_i = 0 and t'_i = 1. Applying the definition of C_i we have

C_{i-1}(t) = u(C_{i-1}(t'), w_i).

Now, disregard how C_{i-1} is produced and just let s = C_{i-1}(t) and s' = C_{i-1}(t'). This simplifies the above equation to

s = u(s', w_i).

We want this to have a 1/|S| chance, which is precisely what we get. The values s and s' depend only on s_0 and w_1, \dots, w_{i-1}, not on w_i, and since w \mapsto u(s', w) is a bijection, the target value s is produced by exactly one weight w, which has a 1/|S| chance of being the value of w_i (assuming the weights are independent and uniformly distributed). This probability is the same regardless of the first argument, s', although which weight is required changes if s' does.
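The argument predicts that, with everything else held fixed, exactly one value of w_i makes the two tasks collide. That can be checked exhaustively on an 8-bit model. The tasks, starting states, and weights below are arbitrary values made up for illustration, and `u8`, `task_state` and `count_collisions` are hypothetical names:

```julia
# 8-bit stand-in for the proposed update (the bswap is omitted since it is the
# identity on a single byte). The UInt8 literal 0x02 keeps arithmetic mod 2^8.
u8(s::UInt8, w::UInt8) = 0x02*s*w + s + w

# C(t): apply u(·, w_i) at every coordinate with t_i = 1.
function task_state(u, s0, t, w)
    s = s0
    for i in eachindex(t)
        t[i] && (s = u(s, w[i]))
    end
    return s
end

# Two tasks whose last differing coordinate is i = 3 (t_3 = 0, t'_3 = 1), with
# different, arbitrary starting states. All weights except w_3 are held fixed
# while w_3 is swept over every value in S; count how often the states collide.
function count_collisions(u)
    t, tp = [true, false, false, true], [false, true, true, true]
    s0_t, s0_tp = 0x5a, 0xc3
    w = UInt8[0x17, 0x9e, 0x00, 0x42]
    n = 0
    for w3 in 0x00:0xff
        w[3] = w3
        n += task_state(u, s0_t, t, w) == task_state(u, s0_tp, tp, w)
    end
    return n
end

count_collisions(u8)   # 1
count_collisions(+)    # 1
```

Any u that is bijective in each argument gives exactly one collision here, which is the 1/|S| probability from the proof.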

What’s somewhat surprising about this result is that most of the computation of C(t) doesn’t matter at all. It starts from a completely arbitrary initial value, and all we really need is this: at the last coordinate where the two tasks differ, they have some state values (same, different, it doesn’t matter!), and then we leave one alone but apply the update function to the other with a random weight. After that they are unlikely to end up equal; that only happens if the random weight is one specific value that depends on the two incoming states, whatever they may be.

With that proof out of the way, what would be a good update function? We are no longer burdened with any need for commutativity, associativity, etc. The only requirement is that it’s a two-argument function S \times S \to S that is a bijection with respect to each argument. At this point I’m just going to say what I think is a good function and then justify it:

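```julia
function update(s::UInt64, w::UInt64)
    s = bswap(s)     # reverse the byte order of the state
    2s*w + s + w     # same as s*(2w + 1) + w; all arithmetic wraps mod 2^64
end
```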

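The first step is to show that this function is a bijection with respect to both s and w. The bswap is its own inverse, hence a bijection, so it is enough to show this for 2sw + s + w, with s now standing for the byte-swapped state. Note: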

(2s + 1)(2w + 1) \div 2 = \\ (4sw + 2s + 2w + 1) \div 2 = \\ 2sw + s + w

Odd numbers like 2s+1 and 2w+1 have multiplicative inverses modulo any power of two, and their product is always odd, so shifting away the last bit doesn’t discard any information. (Since 2s+1 needs 65 bits for a 64-bit s, the products here are taken modulo 2^{65}; the shift brings the result back to 64 bits.) Explicitly, the inverse of s \mapsto 2sw + s + w is

s' \mapsto (2s'+1)(2w+1)^{-1} \div 2

Substituting s' = u(s, w) to check:

(2\,u(s, w) + 1)(2w + 1)^{-1} \div 2 = \\ (2(2sw + s + w) + 1)(2w + 1)^{-1} \div 2 = \\ (4sw + 2s + 2w + 1)(2w + 1)^{-1} \div 2 = \\ (2s + 1)(2w + 1)(2w + 1)^{-1} \div 2 = \\ (2s + 1) \div 2 = s
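The identity behind this derivation can be spot-checked numerically. One detail matters when doing so: 2s + 1 needs 65 bits when s is a 64-bit value, so the product has to be reduced modulo 2^{65}, not 2^{64}, before the right shift. A minimal sketch using UInt128 for the wider arithmetic; the helper names are hypothetical:

```julia
using Random

# Mask for reducing a UInt128 modulo 2^65.
const MASK65 = (UInt128(1) << 65) - 1

# Left side: ((2s + 1)(2w + 1) mod 2^65) ÷ 2, computed in 128-bit arithmetic.
via_odd_product(s::UInt64, w::UInt64) =
    ((((2 * UInt128(s) + 1) * (2 * UInt128(w) + 1)) & MASK65) >> 1) % UInt64

# Right side: 2sw + s + w in ordinary wrapping 64-bit arithmetic.
update_core(s::UInt64, w::UInt64) = 2s*w + s + w

# Compare both sides on n random pairs.
function check_identity(n)
    for _ in 1:n
        s, w = rand(UInt64), rand(UInt64)
        via_odd_product(s, w) == update_core(s, w) || return false
    end
    return true
end

check_identity(10_000)   # true

# Doing the shift in 64-bit arithmetic instead loses the top bit of 2s + 1:
s = typemax(UInt64)
((2s + 1) >> 1) == update_core(s, UInt64(0))   # false
```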

Similarly for w, since 2sw + s + w is symmetric in s and w. Of course, this is worth testing with a reduced state space:

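```julia
julia> f(a::UInt8, b::UInt8) = 0x02*a*b + a + b  # UInt8 literal: a plain 2 would promote to Int, so nothing would wrap mod 2^8
f (generic function with 1 method)

julia> F = [f(a, b) for a=0x0:0xff, b=0x0:0xff];

julia> all(allunique, eachcol(F)) && all(allunique, eachrow(F))  # bijective in a (columns) and in b (rows)
true
```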
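The exhaustive test only covers an 8-bit model. At full width, bijectivity can instead be checked by writing down explicit inverses and round-tripping random inputs. In 64-bit wrapping arithmetic this is easiest via 2sw + s + w = s(2w + 1) + w, an affine map in s with an odd multiplier (and symmetrically in w). A sketch under those assumptions: `inv_odd`, `inv_update_s` and `inv_update_w` are hypothetical names, and the Newton iteration used for the modular inverse is a standard technique swapped in here, not something taken from the PR:

```julia
using Random

# The proposed update, repeated here so the block runs on its own.
function update(s::UInt64, w::UInt64)
    s = bswap(s)
    2s*w + s + w
end

# Inverse of an odd a modulo 2^64 by Newton iteration, x ← x(2 − a·x).
# x = a is already correct to 3 bits (a² ≡ 1 mod 8 for odd a), and each step
# doubles the number of correct low bits: 3 → 6 → 12 → 24 → 48 → 96 ≥ 64.
function inv_odd(a::UInt64)
    isodd(a) || throw(DomainError(a, "only odd numbers are invertible mod 2^64"))
    x = a
    for _ in 1:5
        x *= 2 - a * x
    end
    return x
end

# Undo update in its first argument: with b = bswap(s), y = b(2w + 1) + w,
# so b = (y − w)(2w + 1)⁻¹, and bswap is its own inverse.
inv_update_s(y::UInt64, w::UInt64) = bswap((y - w) * inv_odd(2w + 1))

# Undo update in its second argument: y = w(2b + 1) + b, so w = (y − b)(2b + 1)⁻¹.
function inv_update_w(y::UInt64, s::UInt64)
    b = bswap(s)
    return (y - b) * inv_odd(2b + 1)
end

# Round trip on random inputs.
s, w = rand(UInt64), rand(UInt64)
inv_update_s(update(s, w), w) == s && inv_update_w(update(s, w), s) == w   # true
```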

Why include the bswap in update? It’s cheap and doesn’t commute nicely with other algebraic operations, so it’s likely to help break up any clean algebraic relationships. Multiplication also mixes unevenly: the low bits of a product depend only on the low bits of its factors, so it’s the high bits that are most thoroughly scrambled, and bswapping between “rounds” moves those well-mixed bits into the low positions, where they influence every bit of the next product. The bswap is also applied to the state in parallel with the PCG output computation of the weight, before the two are combined, so it’s likely to be essentially free.
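Both halves of the multiplication argument can be checked directly. Without the bswap, the low byte of 2sw + s + w is a function of the low bytes of s and w only; with it, the top byte of s lands in the low byte of the result. A small sketch; `core` and the two predicate names are hypothetical, and `update` is re-declared so the block is self-contained:

```julia
using Random

# 2sw + s + w without the byte swap; arithmetic wraps mod 2^64.
core(s::UInt64, w::UInt64) = 2s*w + s + w

# The proposed update: byte-swap the state, then apply the core.
update(s::UInt64, w::UInt64) = core(bswap(s), w)

# The low byte of core(s, w) is determined by the low bytes of s and w alone,
# because + and * mod 2^64 reduce consistently mod 2^8.
low_byte_is_local(s, w) = core(s, w) % UInt8 == core(s & 0xff, w & 0xff) % UInt8

# bswap moves the top byte of s into the low byte, so flipping bit 63 of s
# always changes the low byte of update(s, w): for fixed w, that low byte is
# a bijective function of the low byte of bswap(s).
top_bit_reaches_low_byte(s, w) =
    update(s, w) % UInt8 != update(s ⊻ (UInt64(1) << 63), w) % UInt8

s, w = rand(UInt64), rand(UInt64)
low_byte_is_local(s, w), top_bit_reaches_low_byte(s, w)   # (true, true)
```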

I’ve implemented this update function here (a commit in this PR):

Yes, none of this proves linear independence of N task states, but I rather doubt we can prove that analytically. We should probably test that for N ≤ 256 tasks, the N states are linearly independent over GF(2). Why 256? Because a Xoshiro256 state has only 256 bits, so beyond that you’re guaranteed that some subset of the states XORs to zero.
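For such a test, a set of states is linearly dependent over GF(2) exactly when some nonempty subset XORs to zero, so the check is a rank computation by Gaussian elimination. A minimal sketch, assuming each task’s state has already been extracted as four UInt64 words (how to read them out of a task is left open here); `top_bit` and `gf2_rank` are hypothetical helpers:

```julia
# GF(2) rank of a set of 256-bit states, each stored as four UInt64 words
# (most significant word first), e.g. the four registers of a Xoshiro256 state.
# Vectors are inserted into an XOR basis keyed by their highest set bit.

# 1-based position of the highest set bit (1..256), or 0 for the zero vector.
function top_bit(v::NTuple{4,UInt64})
    for (i, x) in enumerate(v)
        x != 0 && return (4 - i) * 64 + (64 - leading_zeros(x))
    end
    return 0
end

function gf2_rank(states::AbstractVector{NTuple{4,UInt64}})
    # basis[p] holds a vector whose highest set bit is p, or nothing
    basis = Vector{Union{Nothing,NTuple{4,UInt64}}}(nothing, 256)
    rank = 0
    for v in states
        while true
            p = top_bit(v)
            p == 0 && break                 # reduced to zero: v depends on earlier states
            b = basis[p]
            if b === nothing
                basis[p] = v                # new pivot: v is independent of earlier states
                rank += 1
                break
            end
            v = map(xor, v, b)              # clear bit p; the top bit strictly decreases
        end
    end
    return rank
end
```

With N task states collected into a vector `states`, the proposed test would then be `gf2_rank(states) == length(states)` for every N ≤ 256.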
