Reduce compilation time by avoiding specialization

I’m trying to speed up function tracing in Umlaut.jl. According to the profiler, about 30% of the time is spent in a generic function mkcall(), which records the call, does some argument transformation, and then simply invokes the original function:

function mkcall(fn::Any, args::Vararg{Any}; kwargs...)
    fn_, args_ = ...
    return fn_(args_...)
end

Within the mkcall() invocation, most of the time is spent in abstract interpretation and type inference. If I understand it correctly, this is due to the Julia compiler specializing mkcall() for each combination of function and argument types:

julia> using MethodAnalysis

julia> methodinstances(mkcall)
7-element Vector{Core.MethodInstance}:
 MethodInstance for Umlaut.mkcall(::typeof(_getfield), ::Variable, ::Int64)
 MethodInstance for Umlaut.mkcall(::typeof(map), ::typeof(unthunk), ::Variable)
 MethodInstance for Umlaut.mkcall(::typeof(tuple), ::Variable, ::Variable)
 MethodInstance for Umlaut.mkcall(::Function, ::Variable, ::Vararg{Variable})
 MethodInstance for Umlaut.mkcall(::Function, ::Function, ::Vararg{Any})
 MethodInstance for Umlaut.mkcall(::typeof(tuple), ::Vararg{Any})
 MethodInstance for Umlaut.mkcall(::Function, ::Variable, ::Vararg{Any})

So I tried to avoid excessive compilation by using @nospecialize, as well as by turning off inlining and constant propagation, as suggested here:

@noinline Base.@constprop :none function mkcall(fn::Any, args::Vararg{Any}; kwargs...)
    @nospecialize
    fn_, args_ = ...
    return fn_(args_...)
end
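
As a toy illustration of what I expected (my own example, not Umlaut code), @nospecialize alone should be enough when the function is called directly with runtime values:

using MethodAnalysis

g(@nospecialize(x)) = x

g(1); g("a"); g(1.0)
methodinstances(g)   # expected: a single MethodInstance, g(::Any)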

If I then invoke mkcall a few times manually, everything works as expected and Julia generates only one specialization. But when I run it on a real case (specifically, tracing Metalhead.ResNet(18)), I still get a lot of method instances:

julia> methodinstances(mkcall)
6-element Vector{Core.MethodInstance}:
 MethodInstance for Umlaut.mkcall(::typeof(_getfield), ::Variable, ::Int64)
 MethodInstance for Umlaut.mkcall(::Any, ::Any)
 MethodInstance for Umlaut.mkcall(::typeof(map), ::typeof(unthunk), ::Variable)
 MethodInstance for Umlaut.mkcall(::typeof(tuple), ::Variable, ::Variable)
 MethodInstance for Umlaut.mkcall(::typeof(tuple), ::Vararg{Any})
 MethodInstance for Umlaut.mkcall(::Any, ::Any, ::Vararg{Any})

Why doesn’t @nospecialize work in this case? Is there a better way to speed up compilation?

Would the inferencebarrier trick described in https://github.com/JuliaLang/julia/pull/41931#issuecomment-902545562 and in “Why doesn’t `@nospecialize` work in this example? - #6 by tim.holy” work?

Somehow I overlooked that comment in the thread, thanks for bringing my attention to it!

I wrapped all arguments in all invocations of mkcall with Base.inferencebarrier() like this:

mkcall(Base.inferencebarrier(v_fargs)...; line=Base.inferencebarrier(line))

and this:

mkcall(
    Base.inferencebarrier(getindex),
    Base.inferencebarrier(v),
    Base.inferencebarrier(i);
    line=Base.inferencebarrier("text comment"),
)

But I don’t see any effect. I’m not sure this is the correct usage, though, as the function is undocumented.

Use inferencebarrier in the caller(s) of mkcall. The idea is to prevent inference from knowing the types of the arguments of mkcall so that it can’t infer-specialize.
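
For instance, a minimal sketch of the idea with made-up functions (not Umlaut’s): wrap the arguments at the call site, so the callee only ever sees Any:

callee(@nospecialize(x), @nospecialize(y)) = (x, y)

caller_plain(x, y)   = callee(x, y)                      # inference sees the concrete types of x and y
caller_barrier(x, y) = callee(Base.inferencebarrier(x),  # inference only sees Any, so it can't
                              Base.inferencebarrier(y))  # infer-specialize the callee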

An alternative might be to put mkcall in a separate module for which you’ve set Base.Experimental.@compiler_options, but I have not experimented with that enough to predict whether that would avoid the need for inferencebarrier in the caller. If you try it, let us know the result.
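
Roughly something like this (an untested sketch; the module name and the body of mkcall are placeholders):

module TracingCalls

# Relaxed compiler options apply only to code defined in this module.
Base.Experimental.@compiler_options optimize=0 compile=min infer=no

function mkcall(fn, args...; kwargs...)
    # the real implementation would record the call and unwrap traced arguments here
    return fn(args...; kwargs...)
end

end # module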

I changed all occurrences like this:

function do_something()
    ...
    mkcall(v_fargs...)
end

to this:

function do_something()
    ...
    mkcall(Base.inferencebarrier(v_fargs)...)
end

Is this what you mean by using Base.inferencebarrier() in the caller? If so, even after this change I still observe multiple method instances for mkcall. However, since I’m testing it on a pretty complex example, I suspect there’s some weird corner case that I overlooked despite my best efforts.


However, your second suggestion is brilliant! Adding

Base.Experimental.@compiler_options optimize=0 compile=min infer=no

to the beginning of the module immediately decreased compilation time from 54 to 44 seconds. Now I don’t see mkcall() in the flame graph at all, and there are no method instances for it after tracing.


Interestingly, I was able to further reduce the tracing time to 32 seconds by changing this:

fargs = (fn, args...)
fargs_ = map_vars(v -> v._op.val, fargs)
fn_, args_ = fargs_[1], fargs_[2:end]
val_ = fn_(args_...)

to this:

fn_ = fn isa V ? fn.op.val : fn
args_ = Any[v isa V ? v.op.val : v for v in args]
val_ = fn_(args_...)

Previously, the profiler pointed to fargs_[2:end] (i.e. getindex(::Array, ::UnitRange)) as a major bottleneck.
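
For reference, the difference should also show up in a standalone microbenchmark with simplified stand-ins for Umlaut’s types (not the real ones); the sliced version pays for an extra array on every call:

using BenchmarkTools

struct FakeOp; val::Any; end
struct FakeVar; op::FakeOp; end   # stand-in for Umlaut's Variable

# old approach: unwrap everything into an array, then slice off the tail
function unwrap_slice(fn, args...)
    fargs_ = Any[v isa FakeVar ? v.op.val : v for v in (fn, args...)]
    return fargs_[1], fargs_[2:end]   # the slice allocates a second array
end

# new approach: unwrap fn and args separately, no slicing
function unwrap_direct(fn, args...)
    fn_ = fn isa FakeVar ? fn.op.val : fn
    args_ = Any[v isa FakeVar ? v.op.val : v for v in args]
    return fn_, args_
end

vs = ntuple(_ -> FakeVar(FakeOp(1.0)), 4)
@btime unwrap_slice(+, $vs...)
@btime unwrap_direct(+, $vs...)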
