Compilation latency with constant propagation and unreachable branches

With

julia> function f(a, b)
           if a == 1
               return b * b
           elseif a == 2
               return b * b'
           elseif a == 3
               return b' * b
           else
               return b' * b'
           end
       end
f (generic function with 1 method)

julia> g(f::F, b) where {F} = f(1, b)
g (generic function with 1 method)

julia> f2(a, b) = b * b
f2 (generic function with 1 method)

julia> using LinearAlgebra

julia> A = rand(2,2);

if we run in separate sessions:

julia> @time g(f, A);
  1.482260 seconds (3.43 M allocations: 225.566 MiB, 12.40% gc time, 100.00% compilation time)

julia> @time g(f2, A);
  1.321378 seconds (2.34 M allocations: 152.849 MiB, 14.46% gc time, 100.00% compilation time)

julia> VERSION
v"1.10.3"

we find that the latter has lower latency and significantly fewer allocations. Both are ultimately compiled to the same code and the dead branches are eliminated in g(f, A), but the constant propagation appears to add quite a bit of overhead. I wonder if there’s a way to make the TTFX for the first case comparable to the second?

Ideally, I wouldn’t want to change the signatures or use static numbers.

1 Like

Constant propagation is an additional analysis in inference, so it adds overhead to the compiler. However, constant propagation can also provide better return types and effects, which can ultimately reduce the overall compile latency of the call graph.
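
As a small illustration of the return-type point (h and k are toy functions made up for this example, not part of the code above), constant propagation of a literal argument can narrow a return type in a way that regular inference alone cannot:

h(a) = a == 1 ? 1.0 : "two"
k() = h(1)

Base.return_types(h, (Int,))  # without constant information: Union{Float64, String}
Base.return_types(k, ())      # the literal 1 is propagated into h, narrowing this to Float64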

1 Like

I wonder if this is a case of us not doing enough constant propagation and inferring the dead branches instead of disregarding them.

But we would need to check the inference graph with SnoopCompile.

1 Like

Are you asking for improvements to the compiler, or do you want something else?

Tangentially: is there a simple way of profiling what the compiler does, to get information about which stage takes time, causes allocations, etc.?

Yes, I am looking for improvements to the compiler if possible, and I suspect exactly what Valentin suggested above: that the dead branches are being inferred. This is because, in Split generic_matmul for strided matrices into two halves by jishnub · Pull Request #54552 · JuliaLang/julia · GitHub, I found that removing a dead branch significantly reduced latency, which I had not expected.

Not really; for this particular question, one could use Snooping on inference: @snoopi_deep · SnoopCompile to get the cost of the inference graph.
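
A minimal sketch of that workflow, run in a fresh session with f, g, and A defined as in the original post (@snoopi_deep is the macro named in the linked docs; newer SnoopCompile versions may name it differently):

using SnoopCompileCore
tinf = @snoopi_deep g(f, A)   # record inference timings for this first call
using SnoopCompile            # load the analysis tools only after snooping
flatten(tinf)                 # per-frame inference times; check whether callees reached
                              # only through the a != 1 branches of f show up here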

1 Like

For benchmarking, you can use:

using BaseBenchmarks
using BenchmarkTools   # provides @benchmark
BaseBenchmarks.load!("inference")
@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call f(args...)

However, please note that this does not necessarily correspond to real-world compile latency. @inf_call only compiles the statically analyzable call graph, so it cannot measure the cost of compilation associated with dynamic calls that occur in reality.

2 Likes

(I posted this on GitHub as well, but I’ll reiterate it here just in case)

To explain briefly, Julia’s inference works as follows:

For a method m:

  1. Regular inference: First, inference is performed without constant information.
  2. Constant inference: Next, inference is performed with constant information for the same m.

These steps are repeated recursively for calls within m.

If a method m contains a branch that becomes dead when its arguments are specific constant values, regular inference in step 1 cannot use that constant information, so it has to infer the dead branch as well. Constant inference in step 2 then identifies the branch as dead and the final generated code is optimized, but the compilation latency from step 1 has already been paid.
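
For example (with f and g defined as in the original post), one way to see the second half of this, that the final optimized code for g no longer contains the dead branches, is to inspect the typed IR; the exact output depends on the Julia version and inlining heuristics:

using LinearAlgebra

# f's own typed code still contains all four branches:
code_typed(f, (Int, Matrix{Float64}))

# inside g the constant a == 1 is propagated, so after optimization typically only
# the b * b path remains:
code_typed(g, (typeof(f), Matrix{Float64}))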

Now you might wonder why we perform regular inference at all. The reason is that its results are valid for any constant values, and even when no constant information is available, which makes them highly reusable and cache-friendly.

Therefore, it is difficult to reduce the latency for code that can only be optimized by constant inference. A solution, as implemented in this PR, is to refactor the code so that regular inference can already recognize the dead branches: by splitting the method into cases and using method dispatch to prevent the inference of dead cases, latency can be improved.
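
To make that concrete, here is a hypothetical sketch of the "split into cases and use dispatch" idea (the names fcase and g2 are made up, this is not the refactoring from the linked PR, and it encodes the case in the argument type, i.e. the static-number style the original post wanted to avoid; it only illustrates the mechanism):

using LinearAlgebra

# Each case is its own method, so inferring a caller that statically selects one case
# never requires inferring the others -- no constant propagation needed.
fcase(::Val{1}, b) = b * b
fcase(::Val{2}, b) = b * b'
fcase(::Val{3}, b) = b' * b
fcase(::Val,    b) = b' * b'

# Val(1) has the concrete type Val{1}, so regular inference alone resolves the dispatch
# and only fcase(::Val{1}, ::typeof(b)) gets inferred and compiled:
g2(b) = fcase(Val(1), b)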

2 Likes