Improve removing members of one array from another

The following Julia code is much slower than Ruby and Python because of the slow way it removes the members of one array from another.

"""
    prime_pairs_lohi(n)

Print three lines for an even integer `n > 2`: `[n, count]`, then the first and
the last pair `[r, n - r]` left after the sieve below.

The candidates are the odd `r` in `3:n/2` with `gcd(r, n) == 1`. A candidate is
discarded when `r` or `n - r` equals a product of two candidates (a square
included) that does not exceed `n - 2`.

For odd `n` or `n < 4` only `Input not even n > 2` is printed; for `n <= 6` the
pair `[n/2, n/2]` is printed directly. Returns `nothing` in every case.
"""
function prime_pairs_lohi(n)
  if (isodd(n) || n < 4)
    return println("Input not even n > 2")
  end
  if (n <= 6)
    println([n, 1])
    println([div(n,2), div(n,2)])
    println([div(n,2), div(n,2)])
    return
  end

  # generate the low-half-residues (lhr) r < n/2
  ndiv2, rhi = div(n,2), n-2         # lhr:hhr midpoint, max residue limit
  lhr = []                           # array for low half residues
  for r in 3:2:ndiv2
    if (gcd(r, n) == 1) append!(lhr, r) end
  end

  # store all squares and cross-products of the lhr members <= n-2
  lhr_mults = []                     # lhr multiples, not part of a pcp
  ln = length(lhr)
  for i in 1:ln                      # every member is a reference multiplier
    r = lhr[i]
    rmax = div(rhi, r)               # largest factor f with r*f <= n-2
    # r*r <= n-2 exactly when r <= rmax (a strict `<` loses e.g. 25 for n = 28)
    if (r <= rmax) append!(lhr_mults, r*r) end
    # once r times its successor exceeds n-2, every later product does too
    if (i == ln || lhr[i+1] > rmax) break end
    for j in i+1:ln                  # all larger residues, the last one included
      if (lhr[j] > rmax) break end   # cross-product would exceed n-2
      append!(lhr_mults, r * lhr[j])
    end
  end

  # convert lhr_mults vals > n/2 to their lhr complements n-r,
  # store them, those < n/2, in lhr_del; it now holds non-pcp lhr vals
  lhr_del = []
  for m in lhr_mults
    append!(lhr_del, (m > ndiv2) ? n - m : m)
  end

  lhr = setdiff(lhr, lhr_del)        # remove lhr_del multiples from lhr
  println([n, length(lhr)])          # show n and pcp prime pairs count
  println([first(lhr), n-first(lhr)]) # show first pcp prime pair of n
  println([last(lhr), n-last(lhr)])  # show last  pcp prime pair of n
end

# Read one line from standard input and interpret it as the Int64 target value n.
num = readline(stdin)
n = parse(Int64, num)
# Run once and report the elapsed time of that single call.
@time prime_pairs_lohi(n)

In Ruby|Crystal it's: lhr -= lhr_del

In Python it's: lhr = [r for r in lhr if r not in lhr_del]

In Julia this is orders of magnitude slower: lhr = setdiff(lhr,lhr_del)

Is there a way to do this faster and, in general, to optimize all the code?

Are you sure it’s that specific line? Did you profile the code?

If you just compared the times you got from this script then these likely are heavily dominated by compilation time. Please switch to @btime from BenchmarkTools.jl for proper timing.
It’d also be great, if you could post some example times.

I think you can already gain significant speed by ensuring that

  • your vectors are not Vector{Any} but instead Vector{Int}
  • you replace append! with push! (which is the correct way to add a single element in Julia)
  • you pre-allocate the correct size (or a mild overestimate) for your vectors

Here’s a new version that runs at least 5 times faster, if my benchmarks are correct:

"""
    prime_pairs_lohi_fast(n::T) where {T<:Integer}

Print three lines for an even integer `n > 2`: `[n, count]`, then the first and
the last pair `[r, n - r]` left after the sieve below. All work vectors have
element type `T`.

The candidates are the odd `r` in `3:n/2` with `gcd(r, n) == 1`. A candidate is
discarded when `r` or `n - r` equals a product of two candidates (a square
included) that does not exceed `n - 2`.

For odd `n` or `n < 4` only `Input not even n > 2` is printed; for `n <= 6` the
pair `[n/2, n/2]` is printed directly. Returns `nothing` in every case.
"""
function prime_pairs_lohi_fast(n::T) where {T<:Integer}
    if (n & 1 == 1 || n < 4)
        return println("Input not even n > 2")
    end
    if (n <= 6)
        println([n, 1])
        println([div(n, 2), div(n, 2)])
        println([div(n, 2), div(n, 2)])
        return nothing
    end

    # generate the low-half-residues (lhr) r < n/2
    ndiv2, rhi = div(n, 2), n - 2       # lhr:hhr midpoint, max residue limit
    lhr = T[]                           # array for low half residues
    sizehint!(lhr, div(ndiv2, 2))       # at most every other value up to n/2
    for r in 3:2:ndiv2
        if (gcd(r, n) == 1)
            push!(lhr, r)
        end
    end

    # store all squares and cross-products of the lhr members <= n-2
    lhr_mults = T[]                     # lhr multiples, not part of a pcp
    # NOTE: length(lhr)^2 is a crude upper bound on the number of products. For
    # large n it reserves far more memory than is ever filled and can exhaust
    # RAM; a tighter estimate is advisable.
    sizehint!(lhr_mults, length(lhr)^2)
    ln = length(lhr)
    for i in 1:ln                       # every member is a reference multiplier
        r = lhr[i]
        rmax = div(rhi, r)              # largest factor f with r*f <= n-2
        # r*r <= n-2 exactly when r <= rmax (a strict `<` loses e.g. 25 for n = 28)
        if (r <= rmax)
            push!(lhr_mults, r * r)
        end
        # once r times its successor exceeds n-2, every later product does too
        if (i == ln || lhr[i + 1] > rmax)
            break
        end
        for j in (i + 1):ln             # all larger residues, the last one included
            if (lhr[j] > rmax)
                break                   # cross-product would exceed n-2
            end
            push!(lhr_mults, r * lhr[j])
        end
    end

    # convert lhr_mults vals > n/2 to their lhr complements n-r,
    # store them, those < n/2, in lhr_del; it now holds non-pcp lhr vals
    lhr_del = T[]
    sizehint!(lhr_del, length(lhr_mults))
    for m in lhr_mults
        push!(lhr_del, (m > ndiv2) ? n - m : m)
    end

    lhr = setdiff(lhr, lhr_del)            # remove lhr_del multiples from lhr
    println([n, length(lhr)])              # show n and pcp prime pairs count
    println([first(lhr), n - first(lhr)])  # show first pcp prime pair of n
    println([last(lhr), n - last(lhr)])    # show last  pcp prime pair of n
    return nothing
end

Note the creation of vectors with T[] instead of [] (which would be your Pythonic reflex), and the use of sizehint!

Benchmark results (with printing commented out):

julia> using BenchmarkTools

julia> @btime prime_pairs_lohi(1_000_000);
  114.282 ms (6624625 allocations: 147.57 MiB)

julia> @btime prime_pairs_lohi_fast(1_000_000);
  18.351 ms (27 allocations: 298.03 GiB)

It seems my size estimates are excessive, probably because the break condition is hit fairly often. So you may want to use a better heuristic (I don’t understand what the code does so I picked numbers that looked about right). Even if it’s not an overestimate, that won’t impact the correctness of your code.

2 Likes

I know you intended to reduce resizing allocations, but it’s resulting in a massive increase in allocated memory. sizehint!(lhr_mults, length(lhr)^2) in particular crashes on my system with OutOfMemoryError. FWIW, the original function culminates in length.((lhr, lhr_mults, lhr_del)) = (199999, 893066, 893066), so the estimate can be improved.

Are you heavily using println and readline because you’re working from the command line? There are better ways of doing IO.

I didn’t read the entirety of your code, but I suppose you’d rather want setdiff!(lhr, lhr_del), where setdiff! is like setdiff but mutates the first argument instead of returning a result.

NB: the setdiff doc string points to setdiff! under “see also”. I try to read all relevant doc strings.

It might be more efficient to avoid constructing lhr_del at all and just delete elements one-by-one? Not sure.

I took the code suggestions and played with them, and this is the fastest version I’ve been able to create so far.

  1. I got rid of creating the lhr_mults array.
  2. Using setdiff is way faster than setdiff!.

To make the code operationally the same as Ruby|Python, it needs to accept inputs with underscores (e.g. 123_456) and take inputs directly from the CLI:
$ julia prime_pairs_lohi 123_456

Does Julia have something like n.Args[0] to take inputs directly from the CLI?

This implementation is now faster than Ruby|Crystal|Python for most inputs I tested, and faster than D for some (smaller) values. I’ll do more timing tests when I get time and post some results.

"""
    prime_pairs_lohi(n::T) where {T<:Integer}

Print three lines for an even integer `n > 2`: `[n, count]`, then the first and
the last pair `[r, n - r]` left after the sieve below.

The candidates are the odd `r` in `3:n/2` with `gcd(r, n) == 1`. A candidate is
discarded when `r` or `n - r` equals a product of two candidates (a square
included) that does not exceed `n - 2`. Products are folded into the low half
(`p > n/2` becomes `n - p`) as they are generated, so no separate array of raw
products is built.

For odd `n` or `n < 4` only `Input not even n > 2` is printed; for `n <= 6` the
pair `[n/2, n/2]` is printed directly. Returns `nothing` in every case.
"""
function prime_pairs_lohi(n::T) where {T<:Integer}
  if (n & 1 == 1 || n < 4)
    return println("Input not even n > 2")
  end
  if (n <= 6)
    println([n, 1])
    println([div(n, 2), div(n, 2)])
    println([div(n, 2), div(n, 2)])
    return
  end

  # generate the low-half-residues (lhr) r < n/2
  ndiv2, rhi = div(n, 2), n - 2        # lhr:hhr midpoint, max residue limit
  lhr = T[]                            # array for low half residues
  sizehint!(lhr, div(ndiv2, 2))        # at most every other value up to n/2
  for r in 3:2:ndiv2
    if (gcd(r, n) == 1) push!(lhr, r) end
  end

  # collect the low-half images of all squares and cross-products <= n-2
  lhr_del = T[]                        # lhr values that are not part of a pcp
  sizehint!(lhr_del, length(lhr))
  ln = length(lhr)
  for i in 1:ln                        # every member is a reference multiplier
    r = lhr[i]
    rmax = div(rhi, r)                 # largest factor f with r*f <= n-2
    # r*r <= n-2 exactly when r <= rmax (a strict `<` loses e.g. 25 for n = 28)
    if (r <= rmax)
      r2 = r * r
      push!(lhr_del, (r2 > ndiv2) ? n - r2 : r2)
    end
    # once r times its successor exceeds n-2, every later product does too
    if (i == ln || lhr[i + 1] > rmax) break end
    for j in (i + 1):ln                # all larger residues, the last one included
      ri = lhr[j]
      if (ri > rmax) break end         # cross-product would exceed n-2
      r3 = r * ri
      push!(lhr_del, (r3 > ndiv2) ? n - r3 : r3)
    end
  end

  lhr = setdiff(lhr, lhr_del)          # remove lhr_del multiples from lhr

  println([n, length(lhr)])            # show n and pcp prime pairs count
  println([first(lhr), n-first(lhr)])  # show first pcp prime pair of n
  println([last(lhr),  n-last(lhr)])   # show last  pcp prime pair of n
  return
end

# Read one line from standard input and interpret it as the Int64 target value n.
num = readline(stdin)
n = parse(Int64, num)
# Run once and report the elapsed time of that single call.
@time prime_pairs_lohi(n)

Yeah the estimate can definitely be improved, that’s what I meant by

I suspect there are some number-theoretic arguments that would provide decent approximations [coughs in n \log n]. But of course, getting rid of that big array is even better.

1 Like

Right at the top of the Command-line Interface docs:

When running a script using julia, you can pass additional arguments to your script:

$ julia script.jl arg1 arg2...

These additional command-line arguments are passed in the global constant ARGS.

Just process ARGS into your inputs and pass them into your function calls, or process ARGS in your function, whichever works.

However, starting up a Julia process, reevaluating your whole script, and recompiling your main function only to execute it once is eating up time. The internal @time begin prime_pairs_lohi(n) end would measure the compilation and execution of the main function, but it’s missing everything beforehand.

If you plan to work from the command line practically, it’s better to use a language that compiles ahead of time to an executable AND doesn’t load much of a runtime. There’s nothing wrong with doing impractical work like reimplementing a program in half a dozen languages to see how fast it can compile and run, but it’s not optimal.

I don’t think you’re aware that this timing is including the compilation time?