Efficient filter for product iterator

While solving a Kakuro puzzle, I decided to brute-force a section of the puzzle (see image with 8 squares marked x1…x8).

I wrote a succinct iterator-based approach (f1() in the code below), which is not very efficient compared to 8 nested for loops (f2() in the code below).
Runtime comparison:

julia> @time x = f1();
  0.306097 seconds (14.73 M allocations: 1.243 GiB, 18.50% gc time)

julia> @time x = f2();
  0.000110 seconds (1.86 k allocations: 58.250 KiB)

The main issue is that the iterator-based code (f1()) checks the constraints only after a full tuple has been generated, and on failure it keeps cycling the product iterator in its natural order rather than skipping ahead on the iterator that caused the failure.

Is there a syntax for combining filter and product that is efficient, succinct, and readable?

Code:

# true if all entries of s are distinct and they sum to n
ts(s, n) = length(s) == length(unique(s)) && sum(s) == n

function f1()
    for x ∈ Iterators.product((1:9 for i in 1:8)...)
        ts((x[1], x[2]), 6) || continue
        ts((x[3], x[4], x[5]), 20) || continue
        ts((x[6], x[7], x[8]), 10) || continue
        ts((x[3], x[6]), 11) || continue
        ts((9, 7, x[1], x[4], x[7]), 33) || continue
        ts((x[2], x[5], x[8]), 8) || continue

        return x
    end
end


function f2()
    x = zeros(Int, 8)
    for x[1] ∈ vcat(1:6,8)
        for x[2] ∈ 1:9
            x[2] == x[1] && continue
            x[1] + x[2] != 6 && continue
            for x[3] ∈ 1:9
                for x[4] ∈ vcat(1:6,8)
                    x[4] == x[1] && continue
                    for x[5] ∈ 1:9
                        x[5] ∈ (x[2], x[3], x[4]) && continue
                        x[3] + x[4] + x[5] != 20 && continue
                        for x[6] ∈ 1:9
                            x[6] == x[3] && continue
                            x[3] + x[6] != 11 && continue
                            for x[7] ∈ vcat(1:6,8)
                                x[7] ∈ (x[1], x[4], x[6]) && continue
                                x[1] + x[4] + x[7] != (33-9-7) && continue
                                for x[8] ∈ 1:9
                                    x[8] ∈ (x[2], x[5], x[6], x[7]) && continue
                                    x[2] + x[5] + x[8] != 8 && continue
                                    x[6] + x[7] + x[8] != 10 && continue
                                    return x
                                end
                            end
                        end
                    end
                end
            end
        end
    end
end

This doesn’t answer your question, but writing ts(s, n) = allunique(s) && sum(s) == n should be much faster, about 200x here. allunique has a special method for tuples, whereas unique((1,2,3)) returns a Vector, which has to be allocated on every call. Edit: This may need Julia 1.9, I forgot.

That is still 100x away from f2(), for the reason you state.
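
A minimal sketch of the two variants side by side, to show that they give the same answers. The names ts_unique and ts_allunique are made up for this comparison and are not from the original code.

```julia
# Original check: unique(s) builds a Vector on every call.
ts_unique(s, n) = length(s) == length(unique(s)) && sum(s) == n

# Suggested check: allunique(s) returns a Bool directly.
ts_allunique(s, n) = allunique(s) && sum(s) == n

# Compare both checks on every pair of digits with target sum 6.
agree = all(ts_unique((a, b), 6) == ts_allunique((a, b), 6) for a in 1:9, b in 1:9)
```

Only the speed should differ, and how much depends on the Julia version.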


The problem is that generation and filtering have to be interleaved to avoid generating all possible combinations. This is not possible with the product iterator, as it would need to be lazy and track which variables depend on which filter conditions.
Your best bet for a readable and efficient solution is probably a comprehension:

function f3()
    [(x₁, x₂, x₃, x₄, x₅, x₆, x₇, x₈)
     for x₁ ∈ 1:9
     for x₂ ∈ 1:9
     if ts((x₁, x₂), 6)
     for x₃ ∈ 1:9
     for x₄ ∈ 1:9
     for x₅ ∈ 1:9
     if ts((x₃, x₄, x₅), 20)
     for x₆ ∈ 1:9
     if ts((x₃, x₆), 11)
     for x₇ ∈ 1:9
     if ts((9, 7, x₁, x₄, x₇), 33)
     for x₈ ∈ 1:9
     if ts((x₆, x₇, x₈), 10) && ts((x₂, x₅, x₈), 8)]
end
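
To make the interleaving idea concrete, here is a small sketch on a made-up 3-variable problem (the variables a, b, c and the targets 6 and 10 are chosen for illustration only, and ts is assumed to be the allunique version). Both comprehensions return the same result, but the second one only enumerates c for pairs (a, b) that already pass the first constraint.

```julia
ts(s, n) = allunique(s) && sum(s) == n

# Filter only at the end: every one of the 9^3 = 729 tuples is generated.
late = [(a, b, c)
        for a in 1:9
        for b in 1:9
        for c in 1:9
        if ts((a, b), 6) && ts((a, b, c), 10)]

# Filter as soon as (a, b) is known: c is enumerated only for valid pairs.
early = [(a, b, c)
         for a in 1:9
         for b in 1:9
         if ts((a, b), 6)
         for c in 1:9
         if ts((a, b, c), 10)]
```

In f3 the same placement rule applies: each `if` sits right after the last `for` whose variable it needs.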

Some benchmarking on my system (Julia 1.9.0-alpha1):

OP:

julia> @btime f1();
  350.724 ms (14725068 allocations: 1.24 GiB)

julia> @btime f2();
  99.500 μs (1861 allocations: 58.25 KiB)

With @bertschi’s recommendation for f3:

julia> @btime f3();
  615.700 μs (29588 allocations: 2.56 MiB)

Adding @mcabbott’s recommendation for ts to use allunique:

julia> @btime f1();
  2.122 ms (0 allocations: 0 bytes)

julia> @btime f2(); # doesn't use `ts`
  99.500 μs (1861 allocations: 58.25 KiB)

julia> @btime f3();
  10.600 μs (470 allocations: 40.36 KiB)

so many orders of magnitude :sweat_smile:

Here’s a recursive solution which encodes your constraints at each level:

function f4()
    x = zeros(Int, 8)
    level(x, Val(0)) ? x : nothing
end

level(x, v) = next(x, v)
function next(x, ::Val{N}) where N
    for i in 1:9
        x[N+1] = i
        level(x, Val(N+1)) && return true
    end
    false
end

# ts((x[1], x[2]), 6) || continue
level(x, v::Val{2}) = ts((x[1], x[2]), 6) && next(x, v)

# ts((x[3], x[4], x[5]), 20) || continue
level(x, v::Val{5}) = ts((x[3], x[4], x[5]), 20) && next(x, v)

# ts((x[3], x[6]), 11) || continue
level(x, v::Val{6}) = ts((x[3], x[6]), 11) && next(x, v)

# ts((9, 7, x[1], x[4], x[7]), 33) || continue
level(x, v::Val{7}) = ts((9, 7, x[1], x[4], x[7]), 33) && next(x, v)

# ts((x[6], x[7], x[8]), 10) || continue
# ts((x[2], x[5], x[8]), 8) || continue
level(x, v::Val{8}) = ts((x[6], x[7], x[8]), 10) && ts((x[2], x[5], x[8]), 8)

ts(s, n) = _allunique(s) && sum(s) == n

function _allunique(t::Tuple)  # from 1.9
    a = Base.afoldl(true, Base.tail(t)...) do b, x
        b & !isequal(first(t), x)
    end
    return a && _allunique(Base.tail(t))
end
_allunique(t::Tuple{}) = true

#=

julia> f2() ==  @btime f4()
  min 7.031 μs, mean 7.234 μs (1 allocation, 128 bytes)
true

julia> @btime f2();
  min 51.334 μs, mean 54.409 μs (621 allocations, 27.25 KiB)

julia> @btime f2cse();  # using @cse to pull out vcat
  min 4.006 μs, mean 4.142 μs (11 allocations, 576 bytes)
true

julia> @btime f3()  # from @bertschi above
  min 16.166 μs, mean 24.642 μs (470 allocations, 40.36 KiB)
1-element Vector{NTuple{8, Int64}}:
 (4, 2, 7, 8, 5, 4, 5, 1)

=#
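
For comparison, the same level-by-level pruning can be sketched without Val dispatch. This is only an illustration, not a replacement for f4(): the names CHECKS, backtrack! and f5 are hypothetical, and looking up closures in a Dict is likely slower than the dispatch-based version above.

```julia
ts(s, n) = allunique(s) && sum(s) == n

# Hypothetical table: the check stored under key k runs as soon as x[1:k] is filled.
const CHECKS = Dict(
    2 => x -> ts((x[1], x[2]), 6),
    5 => x -> ts((x[3], x[4], x[5]), 20),
    6 => x -> ts((x[3], x[6]), 11),
    7 => x -> ts((9, 7, x[1], x[4], x[7]), 33),
    8 => x -> ts((x[6], x[7], x[8]), 10) && ts((x[2], x[5], x[8]), 8),
)

# Fill position k, prune if the check for level k fails, otherwise recurse.
function backtrack!(x, k=1)
    k > length(x) && return true
    for v in 1:9
        x[k] = v
        check = get(CHECKS, k, nothing)
        (check === nothing || check(x)) || continue
        backtrack!(x, k + 1) && return true
    end
    return false
end

function f5()
    x = zeros(Int, 8)
    backtrack!(x) ? x : nothing
end
```

Calling f5() should return the same solution vector as f4().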

Note that f2() spends most of its time doing vcat(1:6,8). If you pull this out, it becomes a little quicker than f4().
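
A rough sketch of what pulling out the vcat can look like when done by hand (the benchmark above used the @cse macro instead). The names CANDIDATES, with_vcat and hoisted are made up, and the loop body is a stand-in for the real constraint checks.

```julia
# Built once; a tuple of digits does not need a fresh Vector per loop.
const CANDIDATES = (1, 2, 3, 4, 5, 6, 8)

function with_vcat()
    total = 0
    for a in 1:9
        for b in vcat(1:6, 8)   # builds a new Vector on every outer iteration
            total += a * b
        end
    end
    return total
end

function hoisted()
    total = 0
    for a in 1:9, b in CANDIDATES   # same values, no per-iteration vcat
        total += a * b
    end
    return total
end
```

Both functions compute the same total; only the repeated construction of the candidate list differs.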

I have not thought about this much; there may still be much better ways of doing this.
