Distributed for loop with timeout

I’m trying to run a function, consisting of a nested for loop in which each iteration writes to a SharedArray and where the runtime per iteration varies unpredictably, with a timeout: when the timeout expires, the partially filled SharedArray should be returned and all running computations terminated. A simple MWE would be

using Distributed, SharedArrays

addprocs(3)

@everywhere using SharedArrays  # workers also need SharedArrays to receive the array

@everywhere function my_slow_function!(sharedarray, i, j)
    wait = rand(1:10)
    println("Wait for $wait s")
    sleep(wait)
    sharedarray[i, j] = randn()
end

@everywhere function my_wrapper_function(outer_size, inner_size)
    sharedarray = SharedArray{Float64}(outer_size, inner_size)
    fill!(sharedarray, NaN)

    @sync @distributed for i in 1:outer_size  # @sync: block until every iteration has run
        for j in 1:inner_size
            my_slow_function!(sharedarray, i, j)
        end
    end

    return sharedarray
end
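
For example, a call like the following should only return once every entry has been written (the check is purely illustrative):

A = my_wrapper_function(4, 3)
all(isfinite, A)  # true: every NaN has been overwritten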

I am running this on a cluster with limited walltime, and I would therefore like to save all progress and terminate Julia before I run out of walltime.

Naively, this would look something like

@everywhere using Dates 

@everywhere function my_wrapper_function(outer_size, inner_size; start = Dates.now(), time_limit = Dates.Second(15))
    sharedarray = SharedArray{Float64}(outer_size, inner_size)
    fill!(sharedarray, NaN)

    @async begin
        while Dates.now() - start < time_limit
            sleep(10)  # check every 10 seconds
        end
        ## Kill all running processes of the @distributed loop
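        ## e.g. interrupt(workers()) or rmprocs(workers())? Neither seems quite right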
    end

    @sync @distributed for i in 1:outer_size
        if Dates.now() - start < time_limit ## Only start if we are within time_limit
            for j in 1:inner_size
                my_slow_function!(sharedarray, i, j)
            end
        end ## else: skip the remaining iterations as they come up
    end

    return sharedarray
end

However, I am not sure what I would put into my @async block. This seems like such a simple and ubiquitous problem, but I was not able to find anything suitable. Any help is much appreciated.

At the moment, I simply run my distributed loop and asynchronously check for the timeout. If it has been reached, I save the partially filled SharedArray and throw an InterruptException(), which does not itself terminate the running worker processes, but propagates an error so that my cluster manager puts my program out of its misery. Something along the lines of

using Distributed, SharedArrays, Serialization

addprocs(4)

@everywhere begin
    using SharedArrays

    function my_slow_function!(sharedarray, i, j, args...)
        wait = rand(1:10)
        println("Wait for $wait s")
        sleep(wait)
        sharedarray[i, j] = randn()
    end
end

function monitor(start_time, timeout, output, filename)
    while true
        if time() - start_time > timeout
            println("Timeout occurred, saving progress")
            open(filename, "w") do f
                serialize(f, output)
            end
            throw(InterruptException())  # does not stop the workers, but errors out the job
        end
        sleep(1) # adjust the sleep time as needed
    end
end

function my_wrapper_function(outer, inner, timeout, filename)
    output = SharedArray{Float64}(length(outer), length(inner))
    fill!(output, NaN)
    start_time = time()

    @async monitor(start_time, timeout, output, filename)

    @sync @distributed for i in eachindex(outer)  # @sync: block until the loop has finished
        for j in eachindex(inner)
            my_slow_function!(output, i, j, outer, inner)
        end
    end

    return output
end
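
For reference, I call it along these lines, where the third argument is the timeout in seconds (sizes and filename are arbitrary):

result = my_wrapper_function(1:10, 1:5, 30, "partial.jls")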

But surely there is a better way to do things?

Does Distributed.interrupt() help? Maybe something like:

tsk = @async @sync @distributed for ...  # @sync, so the task only finishes when the loop does

sleep(timeout)

if !istaskdone(tsk)
    Distributed.interrupt()
    println("Timeout occurred, saving progress")
end

open(filename, "w") do f
    serialize(f, output)
end
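
Fleshed out, that might look something like the sketch below. It assumes that wrapping the loop in @async @sync yields a Task that finishes exactly when the loop does, and that interrupting the workers mid-iteration is acceptable; entries that never ran simply keep their NaN fill.

using Distributed, SharedArrays, Serialization

addprocs(3)

@everywhere using SharedArrays

@everywhere function my_slow_function!(sharedarray, i, j)
    sleep(rand(1:10))  # stand-in for an unpredictable runtime
    sharedarray[i, j] = randn()
end

function my_wrapper_function(outer_size, inner_size, timeout, filename)
    output = SharedArray{Float64}(outer_size, inner_size)
    fill!(output, NaN)

    # @sync blocks until the loop finishes; @async turns that into a pollable Task
    tsk = @async @sync @distributed for i in 1:outer_size
        for j in 1:inner_size
            my_slow_function!(output, i, j)
        end
    end

    # poll until the loop finishes or the time budget runs out
    start_time = time()
    while !istaskdone(tsk) && time() - start_time < timeout
        sleep(1)
    end

    if !istaskdone(tsk)
        println("Timeout occurred, interrupting workers")
        Distributed.interrupt()  # equivalent to Ctrl-C on every worker
        try
            wait(tsk)  # the loop task now fails with a RemoteException
        catch
        end
    end

    # save whatever made it into the array; unfinished entries are still NaN
    open(filename, "w") do f
        serialize(f, output)
    end

    return output
end

Since interrupt() makes the loop task fail rather than complete, the wait(tsk) is wrapped in a try/catch; the partially filled array is then saved and returned as usual.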