This code looks like it needs to go in `__init__`. Right now it's at the top level and is probably being executed before `__init__`.
That's a perceptive comment. I hadn't thought about putting the precompiles in `__init__`, but it seems like it should work. In such cases `@precompile_setup` will have no impact, but `@precompile_all_calls` should work as intended. (It's a bit subtle because `__init__` will be compiled before it runs, but since it's a method in your package all the backedges should be fine for anything that's not runtime-dispatched, and once the snooping is on it should pick up the runtime-dispatched calls.)
I've been messing around with this and getting some pretty awesome results. However, one place it doesn't seem to do much is for precompiling Zygote gradients. It doesn't seem to improve TTFG (time to first gradient) much at all. For some random thing I was just trying, it goes from 13s → 11s, whereas for some of my other non-gradient code I'm getting like 15s → 2s. Is this a known limitation, or is there anything to do about this? (This is with 1.8, btw.)
This is most likely because Zygote gradients use so many closures, and closures are a different type per session, so you cannot precompile them. If those closures were turned into callable structs everywhere in ChainRules.jl, then we'd probably be able to precompile a lot more (@chriselrod we should add this as another reason in SciMLStyle not to use closures). I don't know how the ChainRules.jl devs would feel about that kind of style change though: it would be drastic, but it could also improve error messages.
Brought up here: Remove closures for callable types · Issue #657 · JuliaDiff/ChainRules.jl · GitHub We’ll see if we go through with it.
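For readers unfamiliar with the distinction, here's a minimal sketch of the conversion being proposed (hypothetical code, not actual ChainRules.jl rules): a closure's type gets a fresh gensym name that depends on how many anonymous functions were lowered before it, while a callable struct has a stable, explicit name that precompiled method caches can refer to.

```julia
# Hypothetical example -- not actual ChainRules.jl code.

# Closure style: `pullback`'s type gets a gensym name like var"#pullback#123",
# which depends on how many anonymous functions were defined before it.
function rrule_mul_closure(x, y)
    pullback(dz) = (dz * y, x * dz)   # new anonymous, gensym-named type
    return x * y, pullback
end

# Callable-struct style: `MulPullback` has a stable, explicit name,
# so methods specialized on it can be precompiled and cached.
struct MulPullback{X,Y}
    x::X
    y::Y
end
(pb::MulPullback)(dz) = (dz * pb.y, pb.x * dz)

function rrule_mul_struct(x, y)
    return x * y, MulPullback(x, y)
end
```

Both versions compute the same thing; only the name (and hence cacheability) of the pullback's type differs.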
I've found a few cases where inference failed on v1.6 and I had to use `ntuple(i -> foo(x, i), Val{N}())` instead of `ntuple(Base.Fix1(foo, x), Val{N}())` to overcome this.
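A self-contained illustration of the two spellings, with a stand-in `foo` since the original code isn't shown (both produce identical results; the difference reported above was only in how well v1.6 inferred them):

```julia
# Hypothetical stand-in for the real function in the thread.
foo(x, i) = x + i

# Explicit closure: the spelling that inferred fine on v1.6.
with_closure(x, ::Val{N}) where {N} = ntuple(i -> foo(x, i), Val{N}())

# Base.Fix1 version: the spelling that reportedly hit inference failures.
with_fix1(x, ::Val{N}) where {N} = ntuple(Base.Fix1(foo, x), Val{N}())
```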
Having looked into this before, most of that 11s is fixed overhead from having to precompile + codegen + run the AD transform code. Trying out SnoopPrecompile with the following MWE:
```julia
using Zygote, SnoopCompileCore
f(x) = (2x + 1) / 3 - 4
tinf = @snoopi_deep gradient(f, 1.0)
using SnoopCompile
@show tinf
```
I get:
```
tinf = InferenceTimingNode: 6.153391/11.348825 on Core.Compiler.Timings.ROOT() with 306 direct children # master
tinf = InferenceTimingNode: 6.098876/6.680918 on Core.Compiler.Timings.ROOT() with 87 direct children # https://github.com/FluxML/Zygote.jl/pull/1281
```
For a larger MWE (TTFG on `Metalhead.ViT`):

```
InferenceTimingNode: 18.848347/43.728851 on Core.Compiler.Timings.ROOT() with 906 direct children # master
InferenceTimingNode: 18.814999/37.941746 on Core.Compiler.Timings.ROOT() with 532 direct children # PR 1281
```
So it doesn’t scale, but slashing that constant overhead (and basically all the type inference-related overhead) by 40% is pretty good! If anyone has tips on how to reduce the ~300ms load time increase, please leave them on the PR so we can merge this asap.
Are you sure it's quite that? To pick a random `rrule`, it does look like it's the same gensym-named type every time:
```shell
$ julia -e "using ChainRules; println(typeof(ChainRules.rrule(Array, [1,2,3])[2]))"
ChainRules.var"#Array_pullback#1301"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}
$ julia -e "using ChainRules; println(typeof(ChainRules.rrule(Array, [1,2,3])[2]))"
ChainRules.var"#Array_pullback#1301"{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}
```
Or does it need to be more than that to qualify as "the same" for precompilation purposes?
In one session, run the rule directly. In the next session, create an anonymous function first and then run the rule. The `#1301` means 1301 anonymous functions were created first (holy hell). But if you create one more and then run the code, you bump that counter by one first and the name will be different.
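The counter behavior is visible even within one session; this small sketch just shows that each anonymous function gets its own gensym-named type, with the numbers assigned in definition order (the exact numbers depend on how many closures were lowered before them):

```julia
# Each anonymous function gets a distinct gensym-named type, e.g.
# something like var"#1#2" and var"#3#4"; the exact numbers depend
# on how many anonymous functions were lowered before these.
g = x -> x + 1
h = x -> x + 2

# Different anonymous functions: different types, different names.
@show nameof(typeof(g))
@show nameof(typeof(h))
```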
Is it possible to write a macro to automatically convert closures to callable structs? My confusion is that the struct definition would need to be evaluated at top-level scope rather than inside a parent function, and I don't know if macros can easily do that. Maybe the macro could emit an `eval` call to do exactly this?
I always thought the counter was module-specific; is that not the case? If not, it would not be breaking to make it so, right?
I’m not sure about that.
But note that this wouldn't solve the issue. There's still the possibility that someone writes

```julia
if time() < 500000000; (x) -> x+2; end
```

and thus, even within a module, you couldn't in theory be absolutely certain that the anonymous-function counter would be the same every single time you did `using`. And if you cannot guarantee that, then you cannot cache the precompile result for the type, which is why anonymous functions are essentially an opt-out of precompilation. The inner functions that they call are usually the "meat" of it, so it's usually okay (what they call inside can still precompile of course, so a closure will be a non-precompiled thing over a precompiled thing if you snoop-precompile the function that gets closed over), but ChainRules is definitely an exception here.
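A minimal sketch of that "non-precompiled thing over a precompiled thing" shape (names are made up for illustration):

```julia
# `inner` has a stable name, so precompile directives for it can be
# cached across sessions.
inner(v) = sum(abs2, v)

# The closure returned here has a gensym-named type that cannot be
# reliably cached, but the `inner` call it wraps -- the "meat" --
# still benefits from `inner` being precompiled.
make_scaled_norm(w) = v -> inner(w .* v)
```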
I think the names are given at parse/lowering time, so it doesn’t matter if they’re in unreachable branches. For example:
```julia
julia> function test()
           if false
               x = map(y -> y + 1, 1:10)
           else
               x = map(y -> y + 2, 1:10)
           end
       end
test (generic function with 1 method)

julia> @code_lowered test()
CodeInfo(
1 ─       Core.NewvarNode(:(#14))
│         Core.NewvarNode(:(#13))
│         Core.NewvarNode(:(x))
└──       goto #3 if not false
2 ─       #13 = %new(Main.:(var"#13#15"))
│   %6  = #13
│   %7  = 1:10
│   %8  = Main.map(%6, %7)
│         x = %8
└──       return %8
3 ─       #14 = %new(Main.:(var"#14#16"))
│   %12 = #14
│   %13 = 1:10
│   %14 = Main.map(%12, %13)
│         x = %14
└──       return %14
)
```
Also, anecdotally, we precompile the (usually anonymous) listener functions of observables in Makie.jl, and that seems to work fine.
If you do that at the top level of a module, will it lower both?
This is what I thought as well. This seems to confirm it?

```shell
julia -e "using ChainRules; println(typeof(ChainRules.rrule(Array, [1,2,3])[2]))"
julia -e "using ChainRules; (()->1)(); println(typeof(ChainRules.rrule(Array, [1,2,3])[2]))"
julia -e "using ChainRules; @eval ChainRules (()->1)(); println(typeof(ChainRules.rrule(Array, [1,2,3])[2]))"
# all print the same type
```

So I'm still not sure it's not something else, more subtle, breaking precompilation with Zygote/ChainRules' closures.
OK, this test is a bit weird, but I think it shows the behavior. I define a module with a dead branch containing an anonymous function. I can't call `@code_lowered` on that, but I can on another function that I define after it. The anonymous function it uses is then called `var"#3#4"`:
```julia
julia> module M
           if false
               map(y -> y + 1, 1:10)
           end
           f(xs) = map(x -> x + 1, xs)
       end
Main.M

julia> @code_lowered M.f(1:3)
CodeInfo(
1 ─      #3 = %new(Main.M.:(var"#3#4"))
│   %2 = #3
│   %3 = Main.M.map(%2, xs)
└──      return %3
)
```
If I run this again and replace the module, the anonymous function is still named `var"#3#4"`, so it doesn't depend on global state. The first anonymous function, in the dead branch, should therefore be `var"#1#2"`, although I don't know what the two numbers refer to.
```julia
julia> module M
           if false
               map(y -> y + 1, 1:10)
           end
           f(xs) = map(x -> x + 1, xs)
       end
WARNING: replacing module M.
Main.M

julia> @code_lowered M.f(1:3)
CodeInfo(
1 ─      #3 = %new(Main.M.:(var"#3#4"))
│   %2 = #3
│   %3 = Main.M.map(%2, xs)
└──      return %3
)
```
If I add another anonymous function in the dead branch, the name is shifted further down to `var"#5#6"`.
```julia
julia> module M
           if false
               map(y -> y + 1, 1:10)
               map(y -> y + 1, 1:10)
           end
           f(xs) = map(x -> x + 1, xs)
       end
WARNING: replacing module M.
Main.M

julia> @code_lowered M.f(1:3)
CodeInfo(
1 ─      #5 = %new(Main.M.:(var"#5#6"))
│   %2 = #5
│   %3 = Main.M.map(%2, xs)
└──      return %3
)
```
If I add one more, it's `var"#7#8"`, and so on.
Maybe I'm not following your discussion, but aren't changes to the module irrelevant here, since those will trigger precompilation of that module as well as of the dependent, snoop-precompiling modules? The question for me seems to be whether the types of closures in a precompiled module ever change, and the answer there seems to be no (such that Chris' original explanation for the inefficacy of SnoopPrecompile on Zygote doesn't seem right).
This was about whether the loading order of things outside modules affects anonymous-function names inside modules (it seems they don't), and whether logic inside the module could affect the naming (it doesn't). I think the only way might be to `@eval` stuff conditionally?
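A sketch of what such conditional `@eval` could look like (a contrived example; whether anyone does this in practice is another question). Code inside `@eval` is only lowered when the surrounding branch actually runs at load time, unlike the parse-time dead branches shown above, so the counter seen by later definitions could then differ between loads:

```julia
module CondEval
    # This closure is only lowered -- and only consumes counter values --
    # if the branch actually runs when the module loads, because @eval
    # defers lowering of its argument to run time.
    if get(ENV, "EXTRA_CLOSURE", "0") == "1"
        @eval extra = x -> x
    end
    g = x -> x + 1   # its gensym numbers could depend on the branch above
end
```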
I just tried this package in MRIReco.jl and my TTFIR (time to first image reconstruction) dropped from 20 seconds to 3.6 seconds. The `using` time went up about 3 seconds. Hence a clear net win. Impressive.