Having looked into this before, most of that 11s is fixed overhead from having to precompile + codegen + run the AD transform code. Trying out SnoopPrecompile with the following MWE:
tinf = InferenceTimingNode: 6.153391/11.348825 on Core.Compiler.Timings.ROOT() with 306 direct children # master
tinf = InferenceTimingNode: 6.098876/6.680918 on Core.Compiler.Timings.ROOT() with 87 direct children # https://github.com/FluxML/Zygote.jl/pull/1281
For a larger MWE (TTFG on Metalhead.ViT):
InferenceTimingNode: 18.848347/43.728851 on Core.Compiler.Timings.ROOT() with 906 direct children # master
InferenceTimingNode: 18.814999/37.941746 on Core.Compiler.Timings.ROOT() with 532 direct children # PR 1281
So it doesn't scale, but slashing that constant overhead (and basically all of the type-inference-related overhead) by 40% is pretty good! If anyone has tips on how to reduce the ~300 ms increase in load time, please leave them on the PR so we can merge this ASAP.
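For reference, a precompile workload along these lines is a sketch of what the PR does (the exact MWE isn't shown above, so the particular gradient call here is an assumption, not the PR's actual workload):

```julia
using SnoopPrecompile, Zygote

# Run a representative gradient once at precompile time so the inference
# and codegen work it triggers gets cached in the package image.
@precompile_setup begin
    x = rand(Float32, 4)
    @precompile_all_calls begin
        Zygote.gradient(v -> sum(abs2, v), x)
    end
end
```

This goes in the package's top-level module body; the workload runs during precompilation, and its cost is then paid once per package install rather than once per session.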
In one session, run the rule directly. In the next session, define an anonymous function first and then run the rule. The #1301 means 1301 anonymous functions were created before it (holy hell). If you create one more anonymous function before running the code, that counter gets bumped by one and the closure's type name will be different.
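A minimal way to see that counter in action (the exact numbers in the generated names are session-dependent; var"#1#2" etc. below are just illustrative):

```julia
# Each anonymous function gets a fresh, increasing counter baked into its
# auto-generated type name, so definition order changes the type.
a = x -> x   # defined first, e.g. typeof(a) == var"#1#2"
b = x -> x   # identical body, defined second, e.g. var"#3#4"
@show typeof(a) typeof(b)
@show typeof(a) == typeof(b)   # false: same code, different types
```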
Is it possible to write a macro that automatically converts closures to callable structs? My confusion is that the struct definition needs to be evaluated at top-level scope rather than inside a parent function, and I don't know if macros can easily do that. Maybe the macro could emit an eval call to do exactly this?
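For comparison, here is the manual version of that transformation, i.e. what such a macro would want to emit (the names here are made up for illustration):

```julia
# Top-level callable struct standing in for the closure x -> x + offset.
# The captured variable becomes an explicit struct field.
struct AddOffset{T}
    offset::T
end
(f::AddOffset)(x) = x + f.offset

add2 = AddOffset(2)
add2(3)  # == 5, same result as (x -> x + 2)(3)
```

Because the struct has a stable, explicit name, it precompiles like any other type; the hard part for a macro is hoisting this definition out of the enclosing function to top level.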
But note that wouldn't solve the issue. There's still the possibility that someone writes something like if time() < 500000000; (x) -> x+2; end, so in theory even in a module you couldn't be absolutely certain that the anonymous function counter would be the same every single time you did using. And if you cannot guarantee that, then you cannot cache the precompiled code for the closure's type, which is why anonymous functions are essentially an opt-out of precompilation. The inner functions they call are usually the "meat" of it, so it's usually okay (what they call inside can still precompile, of course, so you get a non-precompiled closure wrapped around precompiled code if you snoop-precompile the function being closed over), but ChainRules is definitely an exception here.
This is what I thought as well. This seems to confirm it?
julia -e "using ChainRules; println(typeof(ChainRules.rrule(Array, [1,2,3])[2]))"
julia -e "using ChainRules; (()->1)(); println(typeof(ChainRules.rrule(Array, [1,2,3])[2]))"
julia -e "using ChainRules; @eval ChainRules (()->1)(); println(typeof(ChainRules.rrule(Array, [1,2,3])[2]))"
# all print same type
So I'm still not sure it isn't something else, more subtle, breaking precompilation with Zygote's/ChainRules' closures.
OK, this test is a bit weird, but I think it shows the behavior. I define a module with a dead branch containing an anonymous function. I can't call @code_lowered on that one, but I can on another function that I define afterwards. The anonymous function that function uses is then named var"#3#4":
julia> module M
if false
map(y -> y + 1, 1:10)
end
f(xs) = map(x -> x + 1, xs)
end
Main.M
julia> @code_lowered M.f(1:3)
CodeInfo(
1 ─ #3 = %new(Main.M.:(var"#3#4"))
│ %2 = #3
│ %3 = Main.M.map(%2, xs)
└── return %3
)
If I run this again and replace the module, the anonymous function is still named var"#3#4", so it doesn't depend on global state. The first anonymous function, in the dead branch, should therefore be var"#1#2", although I don't know what the two numbers refer to.
julia> module M
if false
map(y -> y + 1, 1:10)
end
f(xs) = map(x -> x + 1, xs)
end
WARNING: replacing module M.
Main.M
julia> @code_lowered M.f(1:3)
CodeInfo(
1 ─ #3 = %new(Main.M.:(var"#3#4"))
│ %2 = #3
│ %3 = Main.M.map(%2, xs)
└── return %3
)
If I add another anonymous function to the dead branch, the name shifts further to var"#5#6".
julia> module M
if false
map(y -> y + 1, 1:10)
map(y -> y + 1, 1:10)
end
f(xs) = map(x -> x + 1, xs)
end
WARNING: replacing module M.
Main.M
julia> @code_lowered M.f(1:3)
CodeInfo(
1 ─ #5 = %new(Main.M.:(var"#5#6"))
│ %2 = #5
│ %3 = Main.M.map(%2, xs)
└── return %3
)
Maybe I'm not following your discussion, but aren't changes to the module irrelevant here, since those will trigger precompilation of that module as well as of the dependent, snoop-precompiling modules? The question for me seems to be whether the types of closures in a precompiled module ever change, and the answer there seems to be no (such that Chris' original explanation for the inefficacy of SnoopPrecompile on Zygote doesn't seem right).
This was about whether things like the loading order of code outside a module affect anonymous function names inside it (it seems they don't), and whether logic inside the module could affect the naming (it doesn't). I think the only way might be to @eval stuff conditionally?
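A sketch of what that conditional @eval would look like (hypothetical; the environment variable is made up). Unlike the dead-branch examples above, the closure inside @eval sits in a quoted expression, so it is only lowered and numbered if the branch actually executes at load time, which could shift the counters seen by later definitions between runs:

```julia
module N
# The closure below lives in a quoted expression, so it is lowered
# (and assigned a counter) only when this branch runs during loading.
if get(ENV, "EXTRA_CLOSURE", "") == "1"
    @eval g = y -> y + 1
end
f(xs) = map(x -> x + 1, xs)
end

N.f(1:3)  # [2, 3, 4] either way; only the closure numbering could differ
```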
I just tried this package in MRIReco.jl, and my TTFIR (time to first image reconstruction) dropped from 20 seconds to 3.6 seconds, while using time went up by about 3 seconds. Hence a clear net win. Impressive.