Interrupt `Threads.@threads` computation

I sometimes deal with very lengthy computations that I would like to interrupt before they finish.
I can achieve this by wrapping the function call with `@async` and then throwing an `InterruptException()` into the task when I want it to end.
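For reference, here is a minimal sketch of that single-task pattern. `slow_sum` is a hypothetical stand-in for the real computation. The pattern relies on the task reaching a yield point such as `sleep`, because the scheduled exception is only raised when the task is switched back in:

```julia
# Hypothetical single-threaded long computation; sleep() is a yield point
# at which an exception scheduled into this task gets thrown.
function slow_sum(n)
    s = 0
    for i in 1:n
        s += i
        sleep(0.01)
    end
    return s
end

job = @async slow_sum(10_000)        # would take ~100 s if left alone
sleep(0.1)                           # let it run for a moment
schedule(job, InterruptException(); error = true)

# The task now ends with the InterruptException; wait() reports it
# wrapped in a TaskFailedException.
outcome = try
    wait(job)
    :finished
catch err
    err
end
println(istaskfailed(job))                               # true
println(outcome isa TaskFailedException)                 # true
println(outcome.task.exception isa InterruptException)   # true
```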

Unfortunately, this does not seem to work if my function uses multithreading internally via `Threads.@threads`, as I believe `@threads` spawns other tasks that keep running after I interrupt the originating task.

Is there a way to use a multithreaded `for` loop while keeping the ability to kill the computation altogether?

For anyone who ends up in the same situation, I managed to hack together a quick fix.

I was trying a simple toy example like the following:

```julia
function testfunc()
    portion = zeros(Int, Threads.nthreads())
    Threads.@threads for i ∈ 1:400
        rand(1000, 1000) * rand(1000, 1000)   # some heavy work
        portion[Threads.threadid()] += 1      # per-thread counter
        iter = sum(portion)                   # rough progress (unsynchronized reads)
        mod(iter, 10) == 0 && (println("iter = $iter, Thread = $(Threads.threadid())"); sleep(0.1))
    end
end

tt = @async testfunc()

# later, when the computation should stop:
schedule(tt, InterruptException(); error=true)
```

You'll see that after the `InterruptException` is scheduled, the threads keep printing their status via `println`.

I looked at the code of `Threads.@threads` and managed to get the spawned tasks interrupted by slightly modifying the `threading_run` function in julia/threadingconstructs.jl at commit d279aede19db29c5c31696fb213e3101e2230944 (JuliaLang/julia on GitHub).

I overwrote that method as follows, adding the interruption of the spawned tasks in a `catch` block:

```julia
@eval Threads function threading_run(func)
    ccall(:jl_enter_threaded_region, Cvoid, ())
    n = nthreads()
    tasks = Vector{Task}(undef, n)
    for i = 1:n
        t = Task(func)
        t.sticky = true
        ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, i-1)
        tasks[i] = t
        schedule(t)
    end
    try
        for i = 1:n
            wait(tasks[i])
        end
    catch ex
        if isa(ex, InterruptException)
            println("InterruptException received, stopping @threads tasks!")
            for t in tasks
                # scheduling into an already finished task throws, so skip those
                istaskdone(t) || schedule(t, InterruptException(); error = true)
            end
        end
        rethrow()  # propagate the interrupt (or a failure from a task) to the caller
    finally
        ccall(:jl_exit_threaded_region, Cvoid, ())
    end
end
```

With this change, the `@threads` tasks are properly shut down when the exception is sent.

This overwrites an internal `Base.Threads` method, which is even worse than type piracy, so in a real use case I suppose I should instead create a separate macro that is essentially a copy of the `Threads.@threads` code with the extra `catch` modification above.
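A different approach, which avoids both overwriting `Base` internals and injecting exceptions into tasks that may be running on other threads, is cooperative cancellation: the loop body polls a shared atomic flag and skips its work once the flag is set. This is only a sketch; `cancellable_work` is a hypothetical function and `sleep(0.001)` stands in for one slice of the real computation:

```julia
using Base.Threads: @threads, Atomic, atomic_add!

# Hypothetical cancellable loop: each iteration checks `stop` and does
# nothing once it has been set from outside.
function cancellable_work(stop::Atomic{Bool}; n::Int = 400)
    done = Atomic{Int}(0)        # race-free count of iterations that did work
    @threads for i in 1:n
        if !stop[]
            sleep(0.001)         # stand-in for one slice of the real computation
            atomic_add!(done, 1)
        end
    end
    return done[]
end

n = 100_000
stop = Atomic{Bool}(false)
worker = @async cancellable_work(stop; n = n)
sleep(0.2)
stop[] = true                    # request cancellation; running iterations finish normally
completed = fetch(worker)
println("iterations completed before cancellation: $completed of $n")
```

The trade-off is that the loop body has to be written to check the flag, and cancellation only takes effect between iterations: an iteration stuck in one long call (e.g. a single huge matrix product) will not notice the request until that call returns.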

Is there a specific reason why `@threads` does not already try to shut down the tasks it spawned when it is interrupted?

Edit: I had some more discussion on Zulip (this thread), where some drawbacks of the method above were pointed out.