Write a function that stands up and shuts down workers?

I have a utility in a package to run a batch of things. I’d like for that batch function to be able to take in an argument like n_workers, then stand those workers up, distribute the work, and then remove those new workers. Here’s an example of how things work today, without remote workers:

module MyUtility # Pretend this is a stand-alone package.

using Random
using Distributed

# The expensive function we want to run many times:
function foo(n)
    return randn(n)
end

# A convenient function to run foo for us a whole bunch of times, reducing the result with
# whatever the user provides.
function batch(g, reduce)
    return pmap(g) do n
        reduce(foo(n))
    end
end

end

# Here's my script to use the above package.

import .MyUtility

# We want to run MyUtility.foo a bunch of times, reducing the results of each run with this:
function bar(draws)
    return sqrt(sum(draws.^2)/length(draws))
end

results = MyUtility.batch(1:100, bar)

@show results

This clearly works fine. Further, I’m able to add boilerplate to this to make it work with distributed workers. However, I’d like to reduce the boilerplate that users of the package will need in order for batch to complete its job. I would like for something like this to work:

function batch(g, reduce, n_workers)

    # Get help. (Named new_workers so we don't shadow Distributed.workers.)
    new_workers = Int[]
    if n_workers > 0
        new_workers = addprocs(n_workers)
    end

    # Now do the thing.
    results = pmap(g) do n
        reduce(foo(n))
    end

    # Thank our help and release them.
    if n_workers > 0
        rmprocs(new_workers)
    end

    return results
end

However, this clearly won’t work. Those workers don’t know anything about bar or anything used by bar. Further, if batch is to stand up the workers, then we can’t wrap bar in @everywhere (or rather, we can, but it won’t make it to the right workers). So I’m confused about the right way to do this.

Here’s the “lots of boilerplate example” way to do it, where all of the parallel stuff is pushed on the end user rather than being a convenient part of the utility:

using Distributed
addprocs(5, exeflags="--project=$(Base.active_project())") # Make sure workers inherit our project!

# Bring in the package everywhere. (No leading dot: the workers don't have the
# module defined in their Main, so they must load MyUtility as a package from
# the shared project.)
@everywhere begin
    import MyUtility
end

# This needs to be a separate @everywhere block for some reason.
@everywhere begin
    function bar(draws)
        return sqrt(sum(draws.^2)/length(draws))
    end
end

results = MyUtility.batch(1:100, bar)

Clearly, moving from the “regular” version to a “parallel” version implies a very big reorganization of the user’s code, when all they really intend is to “distribute what I’m doing across more cores”. In fact, the one place where they intend to have a change (“run batch on all my cores”) is the only line that doesn’t change.

Any tips for how to do this? I’ve tried to read all of the related threads but haven’t found a good way.

In some of my own code I’ve used eval with Meta.parse to call @everywhere from inside a function:

#... Inside a function
Core.eval(Main, Meta.parse("@everywhere using MyUtility"))

Edit: also make sure that addprocs activates the current project (e.g. via exeflags, as in your boilerplate example) so that the workers have access to MyUtility.
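Building on that idea, here is a hedged sketch of what batch could do internally: @everywhere (and the underlying Distributed.remotecall_eval) accept an explicit list of process ids, so batch can ship setup code to exactly the workers it just created and release them afterwards. The work argument, the setup keyword, and the n_workers = 0 fast path below are illustrative choices, not the package’s real API.

```julia
using Distributed

# Hedged sketch: a batch that stands up its own workers, runs a
# user-supplied setup expression on exactly those workers, and always
# releases them, even if pmap throws. Not the package's actual API.
function batch(work, reducefn, g, n_workers; setup = nothing)
    new_workers = n_workers > 0 ? addprocs(n_workers) : Int[]
    try
        # Evaluate the setup (e.g. :(import MyUtility)) on just the new
        # workers; plain @everywhere would also hit pre-existing ones.
        if setup !== nothing && !isempty(new_workers)
            Distributed.remotecall_eval(Main, new_workers, setup)
        end
        return pmap(g) do n
            reducefn(work(n))
        end
    finally
        # Guarantee the temporary workers are removed.
        isempty(new_workers) || rmprocs(new_workers)
    end
end
```

A caller might then write something like batch(MyUtility.foo, n -> sqrt(sum(n.^2)/length(n)), 1:100, 4; setup = :(import MyUtility)). One caveat remains: a named function like bar that exists only on the master serializes by name, so for real workers the reducer must either be anonymous (anonymous functions carry their code with them) or be defined by the setup expression.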

p.s. Is there any reason why you can’t use multithreading for your parallelism instead? Something like:

results = Vector{Any}(undef, length(g))
Threads.@threads for i in eachindex(g)
    results[i] = reduce(foo(g[i]))
end

This should have better performance and lower latency, and it will work as long as foo doesn’t mutate any global state, which would cause a race condition.
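To make that concrete, here is a self-contained threaded version of batch (batch_threaded is a hypothetical name; foo and bar are as defined earlier in the thread). No worker setup is needed: threads share memory, so bar is visible to every thread automatically.

```julia
using Random

# Same ingredients as earlier in the thread.
foo(n) = randn(n)
bar(draws) = sqrt(sum(draws .^ 2) / length(draws))

# Hypothetical threaded counterpart to batch: each iteration writes to
# its own slot of results, so there is no race on the output vector.
function batch_threaded(g, reducefn)
    results = Vector{Float64}(undef, length(g))
    Threads.@threads for i in eachindex(g)
        results[i] = reducefn(foo(g[i]))
    end
    return results
end
```

Run Julia with `-t auto` (or set JULIA_NUM_THREADS) so `batch_threaded(1:100, bar)` actually uses multiple threads.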

Thanks @jmair. Somehow I was stuck on the idea that I needed pmap! Switching to Threads.@threads is exactly what I needed.

(I wonder why pmap doesn’t use threads, or if there’s an equivalent function that does?)


Not in Base but there are ThreadsX.map or ThreadPools.tmap, for example.
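For a rough idea of what such a tmap does, here is a minimal sketch using only Base’s Threads.@spawn (the real packages add chunking and scheduler control on top; tmap_sketch is a made-up name):

```julia
# Minimal tmap-style sketch: spawn one task per element, then fetch the
# results in input order.
function tmap_sketch(f, xs)
    tasks = [Threads.@spawn f(x) for x in xs]
    return fetch.(tasks)
end
```

For example, `tmap_sketch(x -> x^2, 1:4)` returns `[1, 4, 9, 16]`.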


pmap is part of the Distributed.jl library, which is for multiprocessing specifically. As @carstenbauer said, there are packages that give you pmap-like syntax for threads.

Thanks @carstenbauer. ThreadPools.tmap looks like what I was expecting pmap to offer, and tmap even seems like an obvious candidate for Base’s Threads. It works the way I’d expect coming from other languages. Very nice suggestions, thank you!