ANN: New package SnoopPrecompile

Having looked into this before, I can say that most of that 11s is fixed overhead from having to precompile + codegen + run the AD transform code. Trying out SnoopPrecompile with the following MWE:

using Zygote, SnoopCompileCore

f(x) = (2x + 1) / 3 - 4

tinf = @snoopi_deep gradient(f, 1.)

using SnoopCompile

@show tinf

I get:

tinf = InferenceTimingNode: 6.153391/11.348825 on Core.Compiler.Timings.ROOT() with 306 direct children # master
tinf = InferenceTimingNode: 6.098876/6.680918 on Core.Compiler.Timings.ROOT() with 87 direct children # https://github.com/FluxML/Zygote.jl/pull/1281

For a larger MWE (TTFG on Metalhead.ViT):

InferenceTimingNode: 18.848347/43.728851 on Core.Compiler.Timings.ROOT() with 906 direct children # master
InferenceTimingNode: 18.814999/37.941746 on Core.Compiler.Timings.ROOT() with 532 direct children # PR 1281

So the savings don’t scale with problem size, but slashing that constant overhead (and essentially all of the type-inference-related overhead) by ~40% is pretty good! If anyone has tips on how to reduce the ~300ms load-time increase, please leave them on the PR so we can merge this asap.
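
For reference, the usage pattern is roughly the following (a hypothetical sketch based on the MWE above, not the exact workload in the PR):

using Zygote, SnoopPrecompile

# Everything called inside @precompile_all_calls runs at precompile time, and the
# resulting type-inferred code is cached in the package's precompile file.
@precompile_all_calls begin
    Zygote.gradient(x -> (2x + 1) / 3 - 4, 1.0)
end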

3 Likes

Are you sure it’s quite that? To pick a random rrule, it does look like it’s the same gensym-named type every time:

$ julia -e "using ChainRules; println(typeof(ChainRules.rrule(Array, [1,2,3])[2]))"
ChainRules.var"#Array_pullback#1301"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}
$ julia -e "using ChainRules; println(typeof(ChainRules.rrule(Array, [1,2,3])[2]))"
ChainRules.var"#Array_pullback#1301"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}

Or does it need to be more than that to count as “the same” for precompilation purposes?

In one session, run the rule directly. In the next session, create an anonymous function first and then run the rule. The #1301 means 1301 anonymous functions were created before it (holy hell). If you create one more and then run the code, that counter gets bumped by one first and the name will be different.
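
You can watch the counter move within a single session, e.g. (hypothetical REPL output; the exact numbers depend on what has already been evaluated in Main):

julia> typeof(x -> x + 1)
var"#1#2"

julia> typeof(x -> x + 1)   # syntactically identical, but it gets the next counter values
var"#3#4"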

2 Likes

Is it possible to write a macro that automatically converts closures to callable structs? My confusion is that the struct definition would need to be evaluated at top level rather than inside the parent function, and I don’t know whether macros can easily do that. Maybe the macro could emit an eval to do exactly this?
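
For concreteness, the manual version of that transformation, i.e. what such a macro would have to emit at top level, looks roughly like this (made-up names, just a sketch):

# Callable-struct ("functor") replacement for the closure  x -> x + c
struct AddConst{T}
    c::T
end
(a::AddConst)(x) = x + a.c   # the type name is stable and declared at top level

add2 = AddConst(2)
add2(3)   # == 5, same behavior as (x -> x + 2)(3)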

I always thought the counter was module-specific; is that not the case? If it isn’t, it would not be breaking to make it so, right?

I’m not sure about that.

But note that wouldn’t solve the issue. There’s still the possibility that someone writes something like if time() < 500000000; (x) -> x+2; end, so in theory even inside a module you couldn’t be absolutely certain that the anonymous-function counter would be the same every single time you did using. And if you cannot guarantee that, then you cannot cache the precompiled code for the type, which is why anonymous functions are essentially an opt-out of precompilation. The inner functions they call are usually the “meat” of it, so it’s usually okay (what they call inside can still precompile, of course, so you end up with a non-precompiled closure over a precompiled core if you snoop-precompile the function the closure wraps), but ChainRules is definitely an exception here.
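
Schematically, the “non-precompiled thing over a precompiled thing” situation looks like this (hypothetical names, just to show the structure):

heavy(x) = sum(abs2, x)                # an ordinary function whose compiled code caches normally
make_callback(y) = x -> heavy(x) + y   # the returned closure has a gensym-named type, wrapping
                                       # the heavy(x) call that can still be precompiled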

I think the names are given at parse/lowering time, so it doesn’t matter if they’re in unreachable branches. For example:

function test()
    if false
        x = map(y -> y + 1, 1:10)
    else
        x = map(y -> y + 2, 1:10)
    end
end

@code_lowered test()

CodeInfo(
1 ─       Core.NewvarNode(:(#14))
│         Core.NewvarNode(:(#13))
│         Core.NewvarNode(:(x))
└──       goto #3 if not false
2 ─       #13 = %new(Main.:(var"#13#15"))
│   %6  = #13
│   %7  = 1:10
│   %8  = Main.map(%6, %7)
│         x = %8
└──       return %8
3 ─       #14 = %new(Main.:(var"#14#16"))
│   %12 = #14
│   %13 = 1:10
│   %14 = Main.map(%12, %13)
│         x = %14
└──       return %14
)

Also anecdotally, we precompile the (usually anonymous) listener functions of observables in Makie.jl, and that seems to work fine.

If you do that at the top level of a module will it lower both?

This is what I thought as well. This seems to confirm it?

julia -e "using ChainRules; println(typeof(ChainRules.rrule(Array, [1,2,3])[2]))"
julia -e "using ChainRules; (()->1)(); println(typeof(ChainRules.rrule(Array, [1,2,3])[2]))"
julia -e "using ChainRules; @eval ChainRules (()->1)(); println(typeof(ChainRules.rrule(Array, [1,2,3])[2]))"
# all print same type

So I’m still not sure it isn’t something else, more subtle, that’s breaking precompilation with Zygote/ChainRules’ closures.

1 Like

Ok, this test is a bit weird, but I think it shows the behavior. I define a module with a dead branch containing an anonymous function. I can’t call @code_lowered on the dead branch itself, but I can on another function defined after it. The anonymous function that function uses is then named var"#3#4":

julia> module M
           if false
               map(y -> y + 1, 1:10)
           end
           f(xs) = map(x -> x + 1, xs)
       end
Main.M

julia> @code_lowered M.f(1:3)
CodeInfo(
1 ─      #3 = %new(Main.M.:(var"#3#4"))
│   %2 = #3
│   %3 = Main.M.map(%2, xs)
└──      return %3
)

If I run this again and replace the module, the anonymous function is still named var"#3#4", so it doesn’t depend on global state. The first anonymous function, in the dead branch, should therefore be var"#1#2", although I don’t know what the two numbers refer to.

julia> module M
           if false
               map(y -> y + 1, 1:10)
           end
           f(xs) = map(x -> x + 1, xs)
       end
WARNING: replacing module M.
Main.M

julia> @code_lowered M.f(1:3)
CodeInfo(
1 ─      #3 = %new(Main.M.:(var"#3#4"))
│   %2 = #3
│   %3 = Main.M.map(%2, xs)
└──      return %3
)

If I add another anonymous function in the dead branch, the name is shifted further down to var"#5#6".

julia> module M
           if false
               map(y -> y + 1, 1:10)
               map(y -> y + 1, 1:10)
           end
           f(xs) = map(x -> x + 1, xs)
       end
WARNING: replacing module M.
Main.M

julia> @code_lowered M.f(1:3)
CodeInfo(
1 ─      #5 = %new(Main.M.:(var"#5#6"))
│   %2 = #5
│   %3 = Main.M.map(%2, xs)
└──      return %3
)

If I do one more, it’s var"#7#8" and so on.

2 Likes

Maybe I’m not following your discussion, but aren’t changes to the module irrelevant here, since those will trigger precompilation of that module as well as of the dependent modules that do the snoop-precompiling? The question for me seems to be whether the types of closures in a precompiled module ever change, and the answer there seems to be no? (Such that Chris’ original explanation for the inefficacy of SnoopPrecompile on Zygote doesn’t seem right.)

This was about whether things like the loading order of code outside the module affect anonymous function names inside it (it seems they don’t), and whether logic inside the module could affect the naming (it doesn’t). I think the only way might be to @eval stuff conditionally?
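
For what it’s worth, the per-module part can be checked directly, e.g. (hypothetical REPL session, consistent with the experiments above; fresh modules appear to start their counters from scratch):

julia> module A; f = x -> x + 1; end;

julia> module B; f = x -> x + 1; end;

julia> typeof(A.f), typeof(B.f)
(Main.A.var"#1#2", Main.B.var"#1#2")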

I just tried this package in MRIReco.jl and my TTFIR (time to first image reconstruction) dropped from 20 seconds to 3.6 seconds. using time went up about 3 seconds. Hence a clear net win. Impressive.

10 Likes
