Iterate over the integer gridpoints in a simplex

I would like to iterate over all `x::NTuple{N,Int}` such that `all(x .>= 0) && sum(x) <= K` (i.e. the integer gridpoints in a simplex). For small problems, a brute-force approach

```julia
function simplex_integer_grid(::Val{N}, K) where N
    z = 0:K
    # Enumerate all of (0:K)^N and keep only the tuples whose sum is at most K.
    [ι for ι in Iterators.product(ntuple(_ -> z, Val(N))...) if sum(ι) ≤ K]
end
```

works, but I am wondering whether there is something more elegant: at least something that does not iterate over values I then throw away, and ideally something that does not have to `collect` at all but still has a well-defined `length`.
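To put a number on the waste: the brute-force version visits every tuple in `(0:K)^N` but keeps only those whose sum is at most `K`. The kept count below is the standard stars-and-bars count (add a slack variable so the sum is exactly `K`), included here for illustration:

$$
\underbrace{(K+1)^N}_{\text{tuples visited}} \quad\text{vs.}\quad \underbrace{\binom{N+K}{N}}_{\text{tuples kept}}
$$

For `N = 3`, `K = 2` that is 27 visited versus 10 kept; for `N = 10`, `K = 2` it is already 59049 visited versus 66 kept.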

Example output is

```julia
julia> simplex_integer_grid(Val(3), 2)
10-element Array{Tuple{Int64,Int64,Int64},1}:
 (0, 0, 0)
 (1, 0, 0)
 (2, 0, 0)
 (0, 1, 0)
 (1, 1, 0)
 (0, 2, 0)
 (0, 0, 1)
 (1, 0, 1)
 (0, 1, 1)
 (0, 0, 2)
```

The set of all `x::NTuple{N,Int}` with `all(x .>= 0) && sum(x) <= K` is in bijection with the `N`-element subsets of `{1, ..., N + K}`, via `x -> cumsum(x .+ 1)` (unfortunately, at the time of writing, `cumsum` for tuples was only available on Julia master).
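Spelled out, with $v$ = `cumsum(x .+ 1)`, the map and its inverse look like this (a short derivation added for illustration):

$$
v_i = \sum_{j=1}^{i} (x_j + 1), \qquad i = 1, \dots, N
$$

$$
1 \le v_1 < v_2 < \cdots < v_N = N + \sum_{j=1}^{N} x_j \le N + K
$$

$$
x_i = v_i - v_{i-1} - 1 \ge 0, \qquad v_0 = 0
$$

Each step adds at least 1, so the $v_i$ are strictly increasing and lie in $\{1, \dots, N + K\}$. Conversely, any strictly increasing $v$ in that range gives nonnegative $x_i$ with $\sum_i x_i = v_N - N \le K$. Hence there are exactly $\binom{N+K}{N}$ gridpoints, which is the length a subsets iterator can report.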

So you could use `IterTools.subsets(1:(K + N), Val{N}())` to iterate over the subsets (the iterator also has `length` defined, via `binomial`) and then translate each subset to a simplex gridpoint, which should be as simple as `v -> v .- (0, Base.front(v)...) .- 1`.
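A minimal Base-only sketch of the two maps, handy for checking the bijection without installing IterTools. The names `gridpoint_to_subset` and `subset_to_gridpoint` are hypothetical helpers introduced here; the forward map is built with `ntuple` rather than `cumsum`, so it does not depend on tuple `cumsum` support.

```julia
# Hypothetical helper: forward map x -> cumsum(x .+ 1), built with ntuple so it
# also works where cumsum is not defined for tuples.
gridpoint_to_subset(x::NTuple{N,Int}) where {N} =
    ntuple(i -> sum(j -> x[j] + 1, 1:i), Val(N))

# Hypothetical helper: the inverse map from the answer, i.e. consecutive
# differences of the increasing subset (with a leading 0), minus one.
subset_to_gridpoint(v::NTuple{N,Int}) where {N} = v .- (0, Base.front(v)...) .- 1

# Round trip on one point: (1, 0, 1) <-> (2, 3, 5)
@show gridpoint_to_subset((1, 0, 1))   # (2, 3, 5)
@show subset_to_gridpoint((2, 3, 5))   # (1, 0, 1)
```

Applied to every gridpoint of the brute-force version, `gridpoint_to_subset` should give distinct increasing tuples in `1:(N + K)`, and `subset_to_gridpoint` should undo it.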

In practice, with `N = 3` and `K = 2` (so the range is `1:5`):

```julia
julia> using IterTools

julia> s = subsets(1:5, Val{3}());

julia> itr = (v .- (0, Base.front(v)...) .- 1 for v in s);

julia> collect(itr)
10-element Array{Tuple{Int64,Int64,Int64},1}:
 (0, 0, 0)
 (0, 0, 1)
 (0, 0, 2)
 (0, 1, 0)
 (0, 1, 1)
 (0, 2, 0)
 (1, 0, 0)
 (1, 0, 1)
 (1, 1, 0)
 (2, 0, 0)
```
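If adding a dependency is undesirable, the same goal (lazy iteration with a well-defined `length`) can be sketched in Base Julia with a small custom iterator. This is a different technique from the subsets approach: a hypothetical `SimplexGrid` type walks the gridpoints directly like an odometer, resetting a coordinate whenever incrementing it would push the sum above `K`, and reports `length` as `binomial(N + K, N)` from the bijection. It yields the points in the same order as the brute-force version in the question (first coordinate varying fastest). `SimplexGrid` and `next_gridpoint` are illustrative names, not an existing API.

```julia
# Hypothetical lazy iterator over all x::NTuple{N,Int} with all(x .>= 0) && sum(x) <= K.
struct SimplexGrid{N}
    K::Int
    function SimplexGrid{N}(K::Integer) where {N}
        (N isa Int && N >= 1) || throw(ArgumentError("N must be a positive Int"))
        K >= 0 || throw(ArgumentError("K must be nonnegative"))
        new{N}(K)
    end
end
SimplexGrid(::Val{N}, K::Integer) where {N} = SimplexGrid{N}(K)

# The number of gridpoints equals the number of N-element subsets of 1:(N + K).
Base.length(g::SimplexGrid{N}) where {N} = binomial(N + g.K, N)
Base.eltype(::Type{SimplexGrid{N}}) where {N} = NTuple{N,Int}

# Odometer step, first coordinate fastest: increment the first coordinate that
# can grow without the sum exceeding K, zeroing every coordinate before it.
# Returns `nothing` once all gridpoints have been produced.
function next_gridpoint(x::NTuple{N,Int}, K::Int) where {N}
    s = sum(x)                  # sum of coordinates i..N once 1..i-1 are zeroed
    for i in 1:N
        if s < K
            return ntuple(j -> j < i ? 0 : (j == i ? x[j] + 1 : x[j]), Val(N))
        end
        s -= x[i]               # coordinate i will be reset to 0
    end
    return nothing
end

function Base.iterate(g::SimplexGrid{N}) where {N}
    z = ntuple(_ -> 0, Val(N))  # the origin is always a gridpoint since K >= 0
    return (z, z)
end

function Base.iterate(g::SimplexGrid{N}, x::NTuple{N,Int}) where {N}
    y = next_gridpoint(x, g.K)
    return y === nothing ? nothing : (y, y)
end

g = SimplexGrid(Val(3), 2)
@show length(g)   # 10, known without iterating
@show first(g)    # (0, 0, 0)
```

Because `length` and `eltype` are defined, `collect(g)` preallocates the result, and a plain `for x in g` loop never materializes the grid.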