Can we have inferable fetch(task)?

Continuing the discussion in ANN: Parallel `for` loops in FLoops.jl with composable and extensible fold-based API - #18 by c42f, I wonder if we can make fetch(task) inferable using invoke_in_world added in
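For context, here is a quick check of the baseline (a sketch, not from the original post). `fetch` on a plain `Task` is inferred as `Any`, because a `Task` does not record the type of its result. `plain_fetch` is a hypothetical name used only for illustration.

```julia
# Fetch the result of a plain Task created inside a function.
plain_fetch() = fetch(@async 1 + 1)

# The value is fine, but the compiler cannot see its type:
# Base.return_types(plain_fetch, ()) reports Any, not Int.
result = plain_fetch()
inferred = Base.return_types(plain_fetch, ())
```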

Here is a quick POC:

module InferableTasks

export @iasync, @ispawn

struct InferableTask{T}
    task::Task
end

Base.fetch(t::InferableTask{T}) where {T} = fetch(t.task)::T
Base.wait(t::InferableTask) = wait(t.task)

macro ispawn(ex)
    inferrable(ex) do ex
        :($Threads.@spawn $ex)
    end |> esc
end

macro iasync(ex)
    inferrable(ex) do ex
        :($Base.@async $ex)
    end |> esc
end

function inferrable(spawn_macro, ex)
    @gensym f T world
    quote
        local $f, $T, $world
        $f() = $ex
        $T = $Core.Compiler.return_type($f, $Tuple{})
        $world = $Base.get_world_counter()
        $InferableTask{$T}($(spawn_macro(:($Base.invoke_in_world($world, $f)))))
    end
end

end

It seems to work:

julia> f() = fetch(@ispawn 1+1)
f (generic function with 1 method)

julia> @code_warntype f()
...

Body::Int64
...

@c42f Do you think the above use of invoke_in_world (_apply_in_world) is unsound? Am I wrong to assume that the return_type call above uses the same world age that get_world_counter returns? If I am wrong, would passing world to return_type fix it? Also, this use of invoke_in_world does not require inference to understand _apply_in_world, right?


Oh right, I see what you had in mind now. I see that you combined return_type with invoke_in_world to get something similar to "what world should Tasks run in?" (JuliaLang/julia issue #35690) immediately. Cool!

It looks like the result of return_type is only used as a hint for efficiency and shouldn't otherwise affect the semantics of the program. So that's good.
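One way to see why the inferred type is only a hint: return_type gives an over-approximation of what the body can return, so the `::T` assertion in `fetch` narrows the static type but, assuming inference is sound, never fails at run time. A minimal sketch with a hypothetical type-unstable body `unstable_body`:

```julia
# Hypothetical body whose return type inference can only narrow to a Union.
unstable_body() = rand(Bool) ? 1 : "one"

# The same call the POC makes: T covers every value the body can return.
T = Core.Compiler.return_type(unstable_body, Tuple{})

# Run the body in a task and assert against T, as InferableTask's fetch does.
task = @async unstable_body()
value = fetch(task)::T
```

If inference had produced something wider (even `Any`), the assertion would still pass; only the caller's inferred type would get less precise.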

Looking at the implementation of return_type, we can see that it fetches the current (dynamically scoped) world counter using jl_get_tls_world_age.

On the other hand, Base.get_world_counter() fetches the latest global world counter, which is not what you want. I think you need to replace get_world_counter with a call to jl_get_tls_world_age.
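A small sketch of the difference, assuming `ccall(:jl_get_tls_world_age, UInt, ())` is the C entry point mentioned above (`tls_world_age`, `compare_worlds`, and `newly_defined_method` are hypothetical names). Defining a method bumps the global counter, but a function that is already running keeps its own world:

```julia
# World age of the currently running task; this is what return_type consults.
tls_world_age() = ccall(:jl_get_tls_world_age, UInt, ())

function compare_worlds()
    before = tls_world_age()
    # Defining a new method bumps the global (latest) world counter...
    @eval newly_defined_method() = 42
    # ...but this call keeps running in the world it started in.
    after = tls_world_age()
    latest = Base.get_world_counter()
    return before, after, latest
end

before, after, latest = compare_worlds()
```

Here `before == after` while `latest > after`, so in the POC the world passed to invoke_in_world would come from `tls_world_age()` rather than `Base.get_world_counter()`, matching the world return_type inferred in.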


Thanks for reviewing the code! Yeah, #35690 is exactly what I had in mind. I didn't know that there are different kinds of world age.

By the way, Jameson mentioned on Slack that this has a problem because inference can't add the correct edges (?).

I wonder if Jameson's concern can be worked around with a hack like this:

diff --git a/src/InferableTasks.jl b/src/InferableTasks.jl
index 0a5705a..06dc89d 100644
--- a/src/InferableTasks.jl
+++ b/src/InferableTasks.jl
@@ -21,11 +21,16 @@ macro iasync(ex)
     end |> esc
 end

+const NEVER = Ref(false)
+
 function inferable(spawn_macro, ex)
     @gensym f T world
     quote
         local $f, $T, $world
-        $f() = $ex
+        $Base.@noinline $f() = $ex
+        if $NEVER[]
+            $f()
+        end
         $T = $Core.Compiler.return_type($f, $Tuple{})
         $world = $Base.get_world_counter()
         $InferableTask{$T}($(spawn_macro(:($Base.invoke_in_world($world, $f)))))

I'm hoping this makes sure that the caller of @ispawn is invalidated whenever the function $f has to be invalidated. If the caller is invalidated, it'll get a new world age. This, in turn, invalidates the $f called in the task, because it is now called from a new world age (via invoke_in_world).
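Putting the pieces together, here is one possible sketch combining the NEVER hack with c42f's suggestion to use the task-local world age. It is not necessarily identical to the code in the repository linked below; the module name `InferableTasksSketch` and the helper `tls_world_age` are hypothetical. Since the macro output is `esc`'d, `NEVER` is interpolated as `$NEVER` so that it refers to this module's constant rather than to a name in the caller's module.

```julia
module InferableTasksSketch

export @iasync, @ispawn

struct InferableTask{T}
    task::Task
end

# T comes from inference, so the assertion only narrows the static type.
Base.fetch(t::InferableTask{T}) where {T} = fetch(t.task)::T
Base.wait(t::InferableTask) = wait(t.task)

# Always false; read at run time so the guarded call is never executed.
const NEVER = Ref(false)

# Task-local world age (the one return_type uses), not the global counter.
tls_world_age() = ccall(:jl_get_tls_world_age, UInt, ())

macro ispawn(ex)
    inferable(ex) do ex
        :($Threads.@spawn $ex)
    end |> esc
end

macro iasync(ex)
    inferable(ex) do ex
        :($Base.@async $ex)
    end |> esc
end

function inferable(spawn_macro, ex)
    @gensym f T world
    quote
        local $f, $T, $world
        $Base.@noinline $f() = $ex
        # Dead at run time, but gives the caller a backedge to $f.
        if $NEVER[]
            $f()
        end
        $T = $Core.Compiler.return_type($f, $Tuple{})
        $world = $tls_world_age()
        $InferableTask{$T}($(spawn_macro(:($Base.invoke_in_world($world, $f)))))
    end
end

end # module
```

Usage mirrors the POC: inside a function such as `h() = fetch(@iasync 1 + 1)` (after `using .InferableTasksSketch`), `@code_warntype h()` should report `Int64`.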

Ah yes, excellent point about back edges and invalidation. The NEVER hack looks like a clever workaround for that.

Thanks!

So I put things together in https://github.com/tkf/InferableTasks.jl/blob/master/src/InferableTasks.jl just in case someone wants to try it out later.
