# Distributed Arrays: Using multiprocessing increases memory allocations by a factor of 200 and slows down

I have 4 worker processes, each working on a separate part of an array in parallel. Following a previous thread, I use the `DistributedArrays` package so that every process gets its own part of the array.

In the MWE below, I run the function (a) sequentially on the first process only and (b) on 5 processes in parallel. Case (a) allocates 2.223 MiB of memory, while case (b) allocates 456.376 MiB, about 200x as much. Note in particular that the measurement does not include distributing the array into a distributed array.

What is going wrong? Why is so much memory allocated?

Note that although in version (b) every process works on only 1/5 of the array length, the for loop takes almost the same amount of time in cases (a) and (b).

(In my actual use case, I also have a for loop over each element of the array, but the work per element is much larger, hence I use a for loop and don't just map a function like `(+)` over the elements.)
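For reference, one can check with `@allocated` that the sequential loop itself is essentially allocation-free on a plain `Array`. A minimal standalone sketch, independent of `DistributedArrays` (`sumloop` is a hypothetical helper name, not part of the MWE):

```julia
# Sum a plain Array element by element, mirroring the loop in the MWE.
function sumloop(a)
    s = 0.0
    for i in eachindex(a)
        s += a[i]
    end
    return s
end

a = rand(10^6)
sumloop(a)                      # warm up so compilation is not measured
bytes = @allocated sumloop(a)   # bytes allocated by the call itself
println(bytes)                  # essentially zero for a plain Array
```

This suggests the extra ~450 MiB in case (b) comes from how the distributed array is accessed, not from the loop pattern as such.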

``````
Time it took to start function (worker): 0.05010485649108887
Time to perform for loop in function (worker):0.8686740398406982
0.882606 seconds (31.67 k allocations: 2.223 MiB, 3.99% compilation time)
Overall time in sequential computation: 0.9330568313598633
=====================================
Time to distribute array: 2.2283201217651367
Time it took to start function (worker): 0.3570849895477295
My indices are (1:2000000, 1:10)
Time to perform for loop in function (worker):0.8875761032104492
From worker 2:    Time it took to start function (worker): 0.7719669342041016
From worker 2:    My indices are (2000001:4000000, 1:10)
From worker 4:    Time it took to start function (worker): 0.8938190937042236
From worker 4:    My indices are (6000001:8000000, 1:10)
From worker 3:    Time it took to start function (worker): 0.9035379886627197
From worker 3:    My indices are (4000001:6000000, 1:10)
From worker 5:    Time it took to start function (worker): 0.9050970077514648
From worker 5:    My indices are (8000001:10000000, 1:10)
From worker 2:    Time to perform for loop in function (worker):0.7745239734649658
From worker 5:    Time to perform for loop in function (worker):0.7382068634033203
From worker 4:    Time to perform for loop in function (worker):0.7611057758331299
From worker 3:    Time to perform for loop in function (worker):0.8856871128082275
1.906232 seconds (18.88 M allocations: 456.376 MiB, 4.62% gc time, 26.32% compilation time)
Overall time in parallel computation: 1.9498538970947266
``````

The MWE:

``````julia
using Distributed

@everywhere using DistributedArrays

arr = rand(10^7,10)

# this is the function that each worker executes in parallel
# input: the (distributed) array, and the start time of the parallel call
@everywhere function worker_function(myarr, start_time; do_parallel=true)
    worker_start_time = time()
    println("Time it took to start function (worker): ", worker_start_time - start_time)
    flush(stdout)
    retval = 0
    if do_parallel
        println("My indices are $(localindices(myarr))")
        for i in localindices(myarr)[1]
            retval += myarr[i]
        end
    else
        for i in 1:length(myarr)
            retval += myarr[i]
        end
    end
    println("Time to perform for loop in function (worker):", time() - worker_start_time)
    flush(stdout)
    return retval
end

start_time = time()
@time worker_function(arr,start_time;do_parallel=false)
end_time = time()
println("Overall time in sequential computation: ", end_time - start_time)

println("=====================================")
# we now distribute the array into a distributed array
start_time = time()
darr = distribute(arr; procs = procs(), dist = [nprocs(), 1])
end_time = time()
println("Time to distribute array: ", end_time - start_time)

start_time = time()
@time begin
    ftrs = [@spawnat p worker_function(darr, start_time; do_parallel=true) for p in procs()]
    for ftr in ftrs
        fetch(ftr)
    end
end
end_time = time()
println("Overall time in parallel computation: ", end_time - start_time)
``````
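One candidate explanation worth ruling out: every `myarr[i]` in the parallel branch is a scalar `getindex` on a `DArray`, which goes through the distributed indexing machinery rather than plain array access and can allocate on each call. A common workaround is to fetch the process-local chunk once with `localpart` and loop over that. A hedged sketch, assuming `DistributedArrays` is installed and workers have been added; `worker_function_local` is a hypothetical name, not part of the package:

```julia
using Distributed
# assumes workers exist, e.g. addprocs(4), and DistributedArrays is installed
@everywhere using DistributedArrays

# Variant of worker_function that avoids scalar indexing on the DArray:
# localpart returns the plain Array owned by the calling process, so the
# inner loop indexes an ordinary Array and should not allocate per element.
@everywhere function worker_function_local(myarr)
    lp = localpart(myarr)
    retval = 0.0
    for i in eachindex(lp)
        retval += lp[i]
    end
    return retval
end

# usage sketch, given a DArray `darr` as in the MWE:
# total = sum(fetch(@spawnat p worker_function_local(darr)) for p in procs(darr))
```

If this variant brings the allocations back down to the sequential level, the per-element `DArray` indexing is the culprit.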