`using Combinatorics: permutations` is much slower than the copied source code

When trying to solve an alphametics problem with permutations from Combinatorics.jl, I came across something I don’t understand about the performance.

If I introduce `permutations` into the namespace via `using Combinatorics: permutations`, performance on large problems is terribly slow compared to when I copy and paste the source code directly from Combinatorics.jl/src/permutations.jl.

An example algorithm for solving an alphametics puzzle:

"""
    solve(puzzle)

Brute-force an alphametics `puzzle` such as `"SEND + MORE == MONEY"`.

Every assignment of distinct digits `0:9` to the puzzle's letters is tried. The
first `Dict{Char,Int}` accepted by `trial` is returned, or `nothing` if none is.
The first letter of every word may not map to zero.
"""
function solve(puzzle)
    # Words are maximal runs of letters, so any spacing around `+`/`==`
    # (including none) is accepted. Each word is reversed so that its n-th
    # character is the 10^(n-1) digit.
    terms = reverse.(split(puzzle, r"[^A-Za-z]+"; keepempty = false))
    # The leading letter of a word is the last character of its reversed form.
    leading, letters = Set(last.(terms)), union(Set(terms)...)
    for perm in permutations(0:9, length(letters))
        solution = Dict(zip(letters, perm))
        all(lead -> !iszero(solution[lead]), leading) || continue
        trial(solution, terms) && return solution
    end
    return nothing
end

"""
    trial(dict, terms)

Check whether the digit assignment `dict` (letter => digit) satisfies the puzzle.

Each entry of `terms` is a reversed word, units digit first. The last entry is
the right-hand side; all earlier entries are summed on the left.
"""
function trial(dict, terms)
    # Value of a reversed word: its n-th character is the 10^(n-1) digit.
    termvalue(term) = sum(dict[t] * 10^(n - 1) for (n, t) in enumerate(term); init = 0)
    # `@view` avoids copying the left-hand terms. `init = 0` keeps a one-term
    # puzzle (empty left-hand side) from throwing on an empty reduction.
    return sum(termvalue, @view(terms[1:end-1]); init = 0) == termvalue(last(terms))
end

Shows something like:

julia> using Combinatorics: permutations;

julia> @time solve("AND + A + STRONG + OFFENSE + AS + A + GOOD == DEFENSE");
90.300626 seconds (27.97 M allocations: 1.320 GiB, 0.17% gc time, 0.09% compilation time)

However, when I cut and paste the file from source (after restarting the kernel), I see something like:

julia> @time solve("AND + A + STRONG + OFFENSE + AS + A + GOOD == DEFENSE");
1.310297 seconds (37.97 M allocations: 2.005 GiB, 14.97% gc time, 7.91% compilation time)

This becomes much worse with a larger example, so I'm wondering why this happens (and potentially where). Since there is a significant difference not only in execution time but also, in the opposite direction, in memory allocation, I would assume some sort of optimization is happening somewhere that prioritizes space over time, but I don't know what that process could be or whether there is some way to get around it. For my case, low run time is more important, but I would like to be able to use the package (via using Combinatorics) rather than explicitly copying in all the source code.

Final note: I've noticed this happening both in my local environment and in a Docker image run on AWS.

This doesn’t explain the poor performance of permutations, but if you use Arrangements from Combinat.jl, then it’s fast:

julia> using Combinatorics, Combinat, Chairmarks

julia> @b sum(first(a) for a in Combinatorics.permutations(1:10, 8))
4.178 s (5443202 allocs: 276.856 MiB, 0.87% gc time, without a warmup)

julia> @b sum(first(a) for a in Combinat.Arrangements(1:10, 8))
65.762 ms (3628898 allocs: 221.490 MiB, 25.23% gc time)

Same for SmallCombinatorics.jl:

julia> using SmallCollections, SmallCombinatorics, Chairmarks

julia> @b sum(first(a) for s in subsets(10, 8) for a in permutations(SmallVector{16,Int8}(s)))
10.767 ms (134 allocs: 2.797 KiB)

I cannot reproduce your observation. Your example takes close to 4min on my laptop, no matter if I use Combinatorics.permutations or the copied source code.

I haven’t looked to see what the source for permutations looks like, but the implementation can be made faster.

"""
    solve_fast(puzzle::AbstractString)

Solve an alphametics `puzzle` such as `"SEND + MORE == MONEY"` by depth-first
search over digit assignments.

Every run of `A–Z` is a term. The last term is the right-hand side; all earlier
terms are summed on the left. Returns a `Dict{Char,Int}` that gives each letter
a distinct digit, with no term starting in `0`. Returns `nothing` if no
assignment satisfies the equation.

Throws `ArgumentError` if the puzzle has no terms or more than 10 distinct letters.
"""
function solve_fast(puzzle::AbstractString)
    # Extract terms as words of A–Z
    raw_terms = [m.match for m in eachmatch(r"[A-Z]+", puzzle)]
    isempty(raw_terms) && throw(ArgumentError("puzzle contains no A–Z terms"))
    nterms = length(raw_terms)

    # Reverse each term so index 1 is units, index 2 is tens, etc.
    terms = reverse.(raw_terms)

    # Leading letters are the first character of each original term,
    # i.e. the last character of each reversed term.
    leading = Set(last.(terms))

    # Distinct letters, in order of first appearance.
    letters_vec = unique(Iterators.flatten(raw_terms))
    L = length(letters_vec)
    # Validate with an exception, not `@assert`, which may be disabled.
    L <= 10 || throw(ArgumentError("More than 10 distinct letters; impossible to assign digits 0–9."))
    letter_index = Dict(c => i for (i, c) in pairs(letters_vec))

    # Coefficient of each letter in the equation sum(coeff[l] * digit[l]) == 0.
    # Left-hand terms add +10^(pos-1); the last (right-hand) term adds -10^(pos-1).
    coeffs = zeros(Int, L)
    for (i, term) in pairs(terms)
        side = i == nterms ? -1 : 1
        for (pos, ch) in enumerate(term)
            coeffs[letter_index[ch]] += side * 10^(pos - 1)
        end
    end

    # Assign letters with the largest |coeff| first, so the bound below can
    # reject hopeless partial assignments as early as possible.
    perm_order    = sortperm(abs.(coeffs); rev = true)
    coeffs_sorted = coeffs[perm_order]
    leading_mask  = [letters_vec[j] in leading for j in perm_order]

    # bound[pos] is the largest |contribution| that letters pos:L can still add
    # (each digit is at most 9); bound[L + 1] == 0.
    bound = zeros(Int, L + 1)
    for pos in L:-1:1
        bound[pos] = bound[pos + 1] + 9 * abs(coeffs_sorted[pos])
    end

    # assignment[i] is the digit for letter letters_vec[perm_order[i]].
    assignment = fill(-1, L)
    found = _alphametic_dfs!(assignment, coeffs_sorted, leading_mask, bound, 1, 0, UInt16(0))
    found || return nothing

    # Map the sorted-order assignment back to the original letter order.
    digits = similar(assignment)
    digits[perm_order] = assignment
    return Dict(letters_vec .=> digits)
end

# Depth-first search used by `solve_fast`. It tries every unused digit for
# letter `pos` (in sorted order) and recurses. It returns `true` as soon as all
# letters are assigned and the weighted digit sum is zero, leaving the solution
# in `assignment`. `used` has bit d set when digit d is already taken.
function _alphametic_dfs!(assignment::Vector{Int}, coeffs::Vector{Int},
                          leading_mask::Vector{Bool}, bound::Vector{Int},
                          pos::Int, sum_so_far::Int, used::UInt16)::Bool
    # Prune when even the largest remaining contribution cannot bring the sum
    # back to zero. Once pos > length(coeffs), bound[pos] == 0, so this is
    # exactly the final `sum_so_far == 0` check.
    abs(sum_so_far) > bound[pos] && return false
    pos > length(coeffs) && return true

    c = coeffs[pos]
    for d in 0:9
        bit = UInt16(1) << d
        used & bit != 0 && continue               # digit already used
        d == 0 && leading_mask[pos] && continue   # leading letter cannot be zero
        assignment[pos] = d
        if _alphametic_dfs!(assignment, coeffs, leading_mask, bound,
                            pos + 1, sum_so_far + c * d, used | bit)
            return true
        end
    end
    return false
end

# Your original trial function still works with the resulting Dict
"""
    trial(dict, terms)

Check whether the digit assignment `dict` (letter => digit) satisfies the puzzle.

Each entry of `terms` is a reversed word, units digit first. The last entry is
the right-hand side; all earlier entries are summed on the left.
"""
function trial(dict, terms)
    # Value of a reversed word: its n-th character is the 10^(n-1) digit.
    termvalue(term) = sum(dict[t] * 10^(n - 1) for (n, t) in enumerate(term); init = 0)
    # `@view` avoids copying the left-hand terms. `init = 0` keeps a one-term
    # puzzle (empty left-hand side) from throwing on an empty reduction.
    return sum(termvalue, @view(terms[1:end-1]); init = 0) == termvalue(last(terms))
end

# Example: time one call (the first call in a session also includes compilation time).
@time solve_fast("AND + A + STRONG + OFFENSE + AS + A + GOOD == DEFENSE")
  0.030308 seconds (704.72 k allocations: 11.308 MiB, 23.16% gc time, 23.04% compilation time) [slightly faster the second time]
Dict{Char, Int64} with 10 entries:
  'O' => 2
  'F' => 7
  'D' => 3
  'A' => 5
  'E' => 4
  'R' => 1
  'G' => 8
  'S' => 6
  'N' => 0
  'T' => 9

Here is an adaptation of @depial’s original function to SmallCombinatorics.jl:

using SmallCollections: SmallVector, SmallDict
using SmallCombinatorics: subsets, permutations

"""
    small_solve(puzzle)

Brute-force an alphametics `puzzle` using SmallCombinatorics iterators.

Returns the first `SmallDict` (letter => digit) accepted by `trial`, or
`nothing` if none is. The first letter of every word may not map to zero.
"""
function small_solve(puzzle)
    # Words are maximal runs of letters, so any spacing around `+`/`==`
    # (including none) is accepted. Each word is reversed so that its n-th
    # character is the 10^(n-1) digit.
    terms = reverse.(split(puzzle, r"[^A-Za-z]+"; keepempty = false))
    leading, letters = unique!(last.(terms)), union(Set(terms)...)
    # `k - 1` shifts each chosen element of `s` down by one, presumably turning
    # subsets of 1:10 into digit sets within 0:9 (TODO confirm `subsets` range).
    for s in subsets(10, length(letters)), perm in permutations(SmallVector{16,Int8}(k-1 for k in s))
        solution = SmallDict{16,Char,Int8}(l => p for (l, p) in zip(letters, perm); unique = true)
        all(lead -> !iszero(solution[lead]), leading) || continue
        trial(solution, terms) && return solution
    end
    return nothing
end

"""
    trial(dict, terms)

Check whether the digit assignment `dict` (letter => digit) satisfies the puzzle.

Each entry of `terms` is a reversed word, units digit first. The last entry is
the right-hand side; all earlier entries are summed on the left.
"""
function trial(dict, terms)
    # Value of a reversed word: its n-th character is the 10^(n-1) digit.
    termvalue(term) = sum(dict[t] * 10^(n - 1) for (n, t) in enumerate(term); init = 0)
    # `@view` avoids copying the left-hand terms. `init = 0` keeps a one-term
    # puzzle (empty left-hand side) from throwing on an empty reduction.
    return sum(termvalue, @view(terms[1:end-1]); init = 0) == termvalue(last(terms))
end

It is not as fast as @technocrat’s solution, but not too far away:

julia> using Chairmarks

julia> @b small_solve("AND + A + STRONG + OFFENSE + AS + A + GOOD == DEFENSE")
53.532 ms (1175539 allocs: 62.854 MiB, 6.06% gc time)

julia> @b solve_fast("AND + A + STRONG + OFFENSE + AS + A + GOOD == DEFENSE")
28.341 ms (689345 allocs: 10.523 MiB)
1 Like

Thank you for your replies! However, the algorithm itself is of no importance. I’ve only included this one because it was concise and would be potentially easier to use.

The main issue is the discrepancy between using the Combinatorics package vs using the source code from permutations.jl. I’m unsure why @matthias314 couldn’t reproduce the issue since I’m able to consistently in two different environments, with different algorithms.

Here is the code I was using to compare the timings.
if false
    @info "Combinatorics"
    using Combinatorics: permutations
else
    @info "no Combinatorics"

# Lazy iterator over the length-`length` permutations (ordered selections
# without repeated positions) of the elements of `data`. `permutations(a, t)`
# returns one of these for 2 <= t <= length(a).
struct Permutations{T}
    data::T      # collection whose elements are permuted
    length::Int  # number of elements in each yielded permutation
end

# Return `true` if any value occurs more than once in `state`.
function has_repeats(state::Vector{Int})
    # `@inbounds` is safe: `state` is a plain `Vector{Int}` (enforced by the
    # signature), so every index from `eachindex`/`lastindex` is valid. If the
    # signature is ever loosened, re-check this.
    @inbounds for i in eachindex(state), j in (i + 1):lastindex(state)
        state[i] == state[j] && return true
    end
    return false
end

# Treat `state` as a multi-digit counter over min:max (last entry least
# significant) and add one in place. The first entry is never wrapped, so it
# overflows past `max` once every value has been produced.
function increment!(state::Vector{Int}, min::Int, max::Int)
    state[end] += 1
    # Propagate carries from the right, stopping before the first position.
    for pos in lastindex(state):-1:(firstindex(state) + 1)
        state[pos] > max || continue
        state[pos] = min
        state[pos - 1] += 1
    end
    return nothing
end

# Advance `state` to the next counter value (see `increment!`) that contains
# no repeated index. At least one increment always happens.
function next_permutation!(state::Vector{Int}, min::Int, max::Int)
    increment!(state, min, max)
    while has_repeats(state)
        increment!(state, min, max)
    end
    return nothing
end

# Iterate over `p`. The state is a vector of indices into `p.data`; it starts
# with every index at `firstindex` and is mutated in place on each step.
function Base.iterate(p::Permutations, state::Vector{Int}=fill(firstindex(p.data), p.length))
    lo, hi = firstindex(p.data), lastindex(p.data)
    next_permutation!(state, lo, hi)
    # The leading index only runs past the last valid index once every
    # index tuple has been visited.
    first(state) > hi && return nothing
    return [p.data[k] for k in state], state
end

# Number of length-k permutations of n elements: n * (n-1) * ... * (n-k+1),
# or 0 when k exceeds n.
function Base.length(p::Permutations)
    n, k = length(p.data), p.length
    n < k && return 0
    return Int(prod((n - k + 1):n))
end

# Element type is declared on the iterator *type*, as Julia's iteration
# interface expects. Type-level queries such as `eltype(typeof(p))` (used e.g.
# when collecting `Iterators.filter(f, p)`) then report `Vector{eltype(T)}`
# instead of `Any`. `eltype(p)` still works through Base's
# `eltype(x) = eltype(typeof(x))` fallback.
Base.eltype(::Type{Permutations{T}}) where {T} = Vector{eltype(T)}

# `Base.length` is defined for `Permutations`, so advertise a known length.
Base.IteratorSize(::Type{<:Permutations}) = Base.HasLength()

# All full-length permutations of `a`.
permutations(a) = permutations(a, length(a))

# Return the length-`t` permutations of `a`. Degenerate lengths are answered
# with plain vectors; otherwise a lazy `Permutations` iterator is returned.
# The checks are order-sensitive: t == 0 and t == 1 are handled before the
# out-of-range test.
function permutations(a, t::Integer)
    t == 0 && return [Vector{eltype(a)}()]   # exactly one permutation: the empty one
    t == 1 && return [[ai] for ai in a]     # each element on its own
    if t < 0 || t > length(a)
        return Vector{Vector{eltype(a)}}()  # no permutations of this length (typed empty vector)
    end
    return Permutations(a, t)
end

end

# benchmark

using Chairmarks

"""
    solve(puzzle)

Brute-force an alphametics `puzzle` such as `"SEND + MORE == MONEY"`.

Every assignment of distinct digits `0:9` to the puzzle's letters is tried. The
first `Dict{Char,Int}` accepted by `trial` is returned, or `nothing` if none is.
The first letter of every word may not map to zero.
"""
function solve(puzzle)
    # Words are maximal runs of letters, so any spacing around `+`/`==`
    # (including none) is accepted. Each word is reversed so that its n-th
    # character is the 10^(n-1) digit.
    terms = reverse.(split(puzzle, r"[^A-Za-z]+"; keepempty = false))
    # The leading letter of a word is the last character of its reversed form.
    leading, letters = Set(last.(terms)), union(Set(terms)...)
    for perm in permutations(0:9, length(letters))
        solution = Dict(zip(letters, perm))
        all(lead -> !iszero(solution[lead]), leading) || continue
        trial(solution, terms) && return solution
    end
    return nothing
end

"""
    trial(dict, terms)

Check whether the digit assignment `dict` (letter => digit) satisfies the puzzle.

Each entry of `terms` is a reversed word, units digit first. The last entry is
the right-hand side; all earlier entries are summed on the left.
"""
function trial(dict, terms)
    # Value of a reversed word: its n-th character is the 10^(n-1) digit.
    termvalue(term) = sum(dict[t] * 10^(n - 1) for (n, t) in enumerate(term); init = 0)
    # `@view` avoids copying the left-hand terms. `init = 0` keeps a one-term
    # puzzle (empty left-hand side) from throwing on an empty reduction.
    return sum(termvalue, @view(terms[1:end-1]); init = 0) == termvalue(last(terms))
end

# Benchmark one solve of the example puzzle with Chairmarks' `@b` and show the result.
display(@b solve("AND + A + STRONG + OFFENSE + AS + A + GOOD == DEFENSE"))

I don’t recognize that iterate method, but I do recognize the comments in has_repeats from a PR I came across when looking for information on my issue… We may be looking at different code. Do you have a link?

I’ve been using this, which is copied from the link I’ve provided above (minus the docstrings):

Code from link in description

# Lazy iterator over the length-`length` permutations of the elements of
# `data`. Iteration permutes the indices of `data` via `multiset_permutations`.
# `permutations(a, t)` returns one of these for 2 <= t <= length(a).
struct Permutations{T}
    data::T      # collection whose elements are permuted
    length::Int  # number of elements in each yielded permutation
end

# Iterate over `p` by running a multiset-permutation iterator over the indices
# of `p.data` (each index occurring once) and mapping the chosen indices back
# to elements. The state carries that inner iterator together with its own state.
function Base.iterate(p::Permutations, state=nothing)
    if state === nothing
        mp = multiset_permutations(collect(eachindex(p.data)), p.length)
        res = iterate(mp)
    else
        mp, prev = state
        res = iterate(mp, prev)
    end
    res === nothing && return nothing
    indices, mp_state = res
    return [p.data[i] for i in indices], (mp=mp, mp_state=mp_state)
end

# Number of length-k permutations of n elements: n * (n-1) * ... * (n-k+1),
# or 0 when k exceeds n.
function Base.length(p::Permutations)
    n, k = length(p.data), p.length
    n < k && return 0
    return Int(prod((n - k + 1):n))
end

# Element type is declared on the iterator *type*, as Julia's iteration
# interface expects (and as `MultiSetPermutations` already does). Type-level
# queries such as `eltype(typeof(p))` (used e.g. when collecting
# `Iterators.filter(f, p)`) then report `Vector{eltype(T)}` instead of `Any`.
# `eltype(p)` still works through Base's `eltype(x) = eltype(typeof(x))` fallback.
Base.eltype(::Type{Permutations{T}}) where {T} = Vector{eltype(T)}

# `Base.length` is defined for `Permutations`, so advertise a known length.
Base.IteratorSize(::Type{<:Permutations}) = Base.HasLength()

# All full-length permutations of `a`.
permutations(a) = permutations(a, length(a))

# Return the length-`t` permutations of `a`. Degenerate lengths are answered
# with plain vectors; otherwise a lazy `Permutations` iterator is returned.
# The checks are order-sensitive: t == 0 and t == 1 are handled before the
# out-of-range test.
function permutations(a, t::Integer)
    t == 0 && return [Vector{eltype(a)}()]   # exactly one permutation: the empty one
    t == 1 && return [[ai] for ai in a]     # each element on its own
    if t < 0 || t > length(a)
        return Vector{Vector{eltype(a)}}()  # no permutations of this length (typed empty vector)
    end
    return Permutations(a, t)
end

# Lazily yield the full-length permutations of the multiset `a` in which no
# position keeps the value it has in `a`.
derangements(a) = (perm for perm in multiset_permutations(a, length(a))
                   if all(((x, y),) -> x != y, zip(a, perm)))

# One step of the multiset-permutation iterator.
#
# `m` holds the distinct values, `t` is the permutation length, and `state` is
# a vector of indices into `m`. Returns `(perm, s)`:
#   - `perm` is the current permutation, i.e. `m` applied to the first `t`
#     entries of `state`;
#   - `s` is the state for the next call.
# Exhaustion is signalled by setting `s[1] = n + 1` (n = length(state)).
# NOTE(review): the iterator's `s[1] > length(ref)` check presumably relies on
# `state` always having length(ref) entries — confirm.
function nextpermutation(m, t, state)
    perm = [m[state[i]] for i in 1:t]
    n = length(state)
    if t <= 0
        # Only the empty permutation exists; return an already-exhausted state.
        return (perm, [n + 1])
    end
    s = copy(state)  # the caller's state is left untouched
    if t < n
        # First position after t whose entry is larger than s[t].
        # NOTE(review): presumably the tail s[t+1:n] is kept ascending, making
        # this the smallest larger entry — confirm.
        j = t + 1
        while j <= n && s[t] >= s[j]
            j += 1
        end
    end
    if t < n && j <= n
        # Found one: swap it into position t.
        s[t], s[j] = s[j], s[t]
    else
        # No larger entry in the tail. Reverse the tail, then apply the classic
        # next-lexicographic-permutation step to the whole vector.
        if t < n
            reverse!(s, t + 1)
        end
        # Largest i <= t-1 with s[i] < s[i+1].
        i = t - 1
        while i >= 1 && s[i] >= s[i+1]
            i -= 1
        end
        if i > 0
            # Largest j > i with s[j] > s[i]; swap, then reverse everything after i.
            j = n
            while j > i && s[i] >= s[j]
                j -= 1
            end
            s[i], s[j] = s[j], s[i]
            reverse!(s, i + 1)
        else
            # No such i: this was the last permutation; mark the state exhausted.
            s[1] = n + 1
        end
    end
    return (perm, s)
end

# Iterator over the length-`t` permutations of a multiset, built by
# `multiset_permutations`.
struct MultiSetPermutations{T}
    m::T              # distinct values (`unique(a)` when built from a collection `a`)
    f::Vector{Int}    # multiplicity of each value in `m`
    t::Int            # length of each yielded permutation
    ref::Vector{Int}  # initial state: index i into `m` repeated f[i] times, ascending
end

# Each iteration step yields a `Vector` of elements of the value collection `m::T`.
function Base.eltype(::Type{MultiSetPermutations{T}}) where {T}
    return Vector{eltype(T)}
end

# Number of length-`t` permutations of the multiset described by `c`
# (values `c.m` with multiplicities `c.f`).
#
# Uses exponential generating functions. After processing the i-th value,
# `p[j+1]` holds t! times the coefficient of x^j in
#     prod over processed values of  sum_{k=0}^{f} x^k / k!,
# so the answer is p[t+1].
#
# NOTE(review): for t <= 20 the factorials are `Int` but `p` is `Float64`, so
# counts above 2^53 may come out inexact after `round`. For t > 20, `g` holds
# BigInts and `p` is promoted to BigFloat. `round(Int, …)` will throw if the
# count exceeds typemax(Int). Confirm whether exact counts are required.
function Base.length(c::MultiSetPermutations)
    t = c.t
    if t > length(c.ref)
        # More positions than elements in the multiset.
        return 0
    end
    # g[k+1] == k!; factorial(21) overflows Int, so larger t uses BigInt.
    if t > 20
        g = [factorial(big(i)) for i in 0:t]
    else
        g = [factorial(i) for i in 0:t]
    end
    p = [g[t+1]; zeros(Float64, t)]
    for i in 1:length(c.f)
        f = c.f[i]
        if i == 1
            # First value: t! * x^j / j! for j up to its multiplicity.
            for j in 1:min(f, t)
                p[j+1] = g[t+1] / g[j+1]
            end
        else
            # Multiply by sum_{k=0}^{f} x^k / k!. Descending j means every p[k]
            # read below is still the previous round's value.
            for j in t:-1:1
                q = 0  # starts as Int, becomes floating point after the first `+=`
                for k in (j+1):-1:max(1, j + 1 - f)
                    q += p[k] / g[j+2-k]
                end
                p[j+1] = q
            end
        end
    end
    return round(Int, p[t+1])
end

# All full-length permutations of the multiset `a`.
multiset_permutations(a) = multiset_permutations(a, length(a))

# Length-`t` permutations of the multiset `a`.
function multiset_permutations(a, t::Integer)
    m = unique(a)
    # Count with `isequal`, the same equality `unique` uses, so multiplicities
    # agree with `m`. With `==`, 0.0 and -0.0 (kept apart by `unique`) were
    # each counted twice, NaN got a count of 0, and `missing` threw.
    f = [count(isequal(x), a) for x in m]
    return multiset_permutations(m, f, t)
end

# Length-`t` permutations of the multiset whose distinct values are `m`, with
# `f[i]` copies of `m[i]`.
function multiset_permutations(m, f::Vector{<:Integer}, t::Integer)
    length(m) == length(f) || error("Lengths of m and f are not the same.")
    # Initial state: value index i repeated f[i] times, in ascending order.
    ref = Int[i for i in eachindex(f) for _ in 1:f[i]]
    # A negative length is mapped past the multiset size, so iteration yields nothing.
    t < 0 && (t = length(ref) + 1)
    return MultiSetPermutations(m, f, t, ref)
end

# Iterate over `p`, starting from the sorted index vector `p.ref`. Iteration
# ends when the state is exhausted (its first entry is past the multiset size)
# or when `p.t` is larger than the multiset.
function Base.iterate(p::MultiSetPermutations, s=p.ref)
    n = length(p.ref)
    if isempty(s)
        # An empty multiset has only the zero-length permutation.
        p.t > 0 && return nothing
    elseif max(first(s), p.t) > n
        return nothing
    end
    return nextpermutation(p.m, p.t, s)
end

# Rearrange `a` in place into its `k`-th permutation in lexicographic order,
# taken relative to the current order of `a` (k = 1 leaves `a` unchanged).
# Returns `a`.
#
# Throws `ArgumentError` unless 0 < k <= n!, where n = length(a).
function nthperm!(a::AbstractVector, k::Integer)
    n = length(a)
    n == 0 && return a
    f = factorial(oftype(k, n))
    0 < k <= f || throw(ArgumentError("permutation k must satisfy 0 < k ≤ $f, got $k"))
    k -= 1 # make k 0-indexed (rank 0 is the current order of `a`)
    for i = 1:n-1
        # f becomes (n-i)!; the quotient k ÷ f picks which remaining element
        # moves to position i (offset from i).
        f ÷= n - i + 1
        j = k ÷ f
        k -= j * f
        j += i
        # Rotate a[i:j] right by one so the chosen a[j] lands at position i.
        elt = a[j]
        for d = j:-1:i+1
            a[d] = a[d-1]
        end
        a[i] = elt
    end
    a
end

# Return the `k`-th permutation of a copy of `a`; `a` itself is not modified.
function nthperm(a::AbstractVector, k::Integer)
    return nthperm!(collect(a), k)
end

# Return the lexicographic rank (starting at 1) of the permutation vector `p`.
function nthperm(p::AbstractVector{<:Integer})
    isperm(p) || throw(ArgumentError("argument is not a permutation"))
    n = length(p)
    rank = 1
    for i in 1:n-1
        # Every smaller value to the right of position i skips (n-i)! permutations.
        below = count(j -> p[j] < p[i], i+1:n)
        rank += below * factorial(n - i)
    end
    return rank
end

const levicivita_lut = cat([0 0 0; 0 0 1; 0 -1 0],
                           [0 0 -1; 0 0 0; 1 0 0],
                           [0 1 0; -1 0 0; 0 0 0];
                           dims=3)

"""
    levicivita(p::AbstractVector{<:Integer})

Return the Levi-Civita symbol of `p`:

- `1` if `p` is an even permutation of `1:length(p)`;
- `-1` if it is odd;
- `0` if `p` is not a permutation of `1:length(p)`.
"""
function levicivita(p::AbstractVector{<:Integer})
    n = length(p)

    if n == 3
        # Fast path through the precomputed table. Indexing stays bounds-checked:
        # `p` is any AbstractVector, so 1, 2, 3 are not guaranteed to be valid
        # indices (e.g. for an offset-indexed vector), and `@inbounds` was unsafe.
        valid = (0 < p[1] <= 3) & (0 < p[2] <= 3) & (0 < p[3] <= 3)
        return valid ? levicivita_lut[p[1], p[2], p[3]] : 0
    end

    # Walk the cycles of `p`, toggling `todo` for every visited position. A
    # position visited twice, or a value outside 1:n, means `p` is not a
    # permutation. A cycle of length ℓ contributes ℓ - 1 transpositions (`flips`).
    todo = trues(n)
    start = 1
    cycles = flips = 0

    while cycles + flips < n
        # cycles + flips counts visited positions, all of which lie before the
        # next unvisited one, so `findnext` always succeeds here. (`coalesce`
        # only replaces `missing`, not `nothing`, so the old fallback was a
        # no-op; the assertion states the invariant and keeps `start::Int`.)
        start = findnext(todo, start)::Int
        (todo[start] = !todo[start]) && return 0
        j = p[start]
        (0 < j <= n) || return 0
        cycles += 1
        while j ≠ start
            (todo[j] = !todo[j]) && return 0
            j = p[j]
            (0 < j <= n) || return 0
            flips += 1
        end
    end

    return iseven(flips) ? 1 : -1
end

# Return 0 for an even permutation and 1 for an odd one. Throws `ArgumentError`
# if `p` is not a permutation.
function parity(p::AbstractVector{<:Integer})
    lc = levicivita(p)   # ±1 for permutations, 0 otherwise
    iszero(lc) && throw(ArgumentError("Not a permutation"))
    return lc > 0 ? 0 : 1
end

There’s my problem. The code in my link above is what is up-to-date in the repo, but not what is actually in the current release. What I find on my machine does match with what @matthias314 provided.

Thanks for the help!

1 Like