Reactant.jl + Enzyme: ~10 min compilation overhead triggered by hidden-to-hidden Dense layers

I’ve been investigating compilation times with Reactant.jl + Lux + AutoEnzyme() on an NVIDIA GPU and found a reproducible jump in time to first training step (TTFT) when a model includes hidden-to-hidden Dense layers.

Minimal reproducer

```julia
using Reactant, Lux, Random, Optimisers, Enzyme

model = Chain(Dense(1 => 16, relu), Dense(16 => 16, relu), Dense(16 => 1))  # 3 layers — slow
# model = Chain(Dense(1 => 16, relu), Dense(16 => 1))                       # 2 layers — fast

rng = Xoshiro(999)
dev = reactant_device()

ps, st = Lux.setup(rng, model) |> dev
x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
x, y = dev((x, y))

tstate = Training.TrainState(model, ps, st, Adam(0.03f0))

@elapsed _, loss, _, tstate = Training.single_train_step!(
    AutoEnzyme(), MSELoss(), (x, y), tstate
)
```
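For reference, the “Second call” column in the results was measured by simply repeating the same training step in the same session, roughly like this (a sketch reusing the bindings from the reproducer above):

```julia
# The step above already triggered compilation; running it again reuses the
# compiled function, so this measures pure execution time.
t_second = @elapsed _, loss, _, tstate = Training.single_train_step!(
    AutoEnzyme(), MSELoss(), (x, y), tstate
)
@show t_second  # fractions of a second, vs. minutes for the first call
```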

Results (each run in a fresh Julia session)

| Model | Hidden→Hidden layers | Parameters | First compile | Second call |
|---|---|---|---|---|
| 1→16→1 | 0 | 49 | 73 s | 0.10 s |
| 1→256→1 | 0 | 513 | 78 s | — |
| 1→16→16→1 | 1 | 321 | 710 s | 0.19 s |
| 1→16→16→16→1 | 2 | 593 | 712 s | 0.21 s |
| 1→16→16→16→16→1 | 3 | 865 | 730 s | — |

Key observations

  1. The overhead is not proportional to parameter count. A 2-layer model with hidden size 256 (513 params) compiles in 78s, while a 3-layer model with hidden size 16 (321 params) takes 710s. The trigger is the presence of hidden-to-hidden layers, not the model size.

  2. Adding more hidden layers beyond the first costs almost nothing. Going from 3 to 5 layers only increases compile time from 710s to 730s. The ~10 min overhead is essentially a fixed cost once any hidden-to-hidden layer exists.

  3. The overhead is in Enzyme, not XLA. Compiling just the forward pass with `Reactant.@compile Lux.apply(model, x, ps, st)` takes only ~40s for the 3-layer model. The remaining ~10 min comes from Enzyme’s backward pass generation.

  4. It’s not dot-merger. The XLA log shows `dot_merger.cc` being triggered for ≥3-layer models. However, disabling it with `XLA_FLAGS="--xla_disable_hlo_passes=dot-merger"` has zero effect on compile time. The dot-merger log message just happens to appear near the end of compilation.

  5. Session-level caching exists. Within the same Julia session, compiling a second model with hidden-to-hidden layers is fast (~22-25s), even if the architecture is different. This suggests Enzyme (or the MLIR pipeline) caches intermediate results from the first compilation. However, a fresh session always pays the full ~10 min cost.

  6. XLA kernel cache doesn’t help. Enabling XLA’s persistent kernel cache with nvlink did not reduce the compilation time for the 3-layer case.
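For concreteness, observations 3 and 4 were measured along these lines (a sketch; same architecture as the reproducer, and `XLA_FLAGS` presumably needs to be set before Reactant initializes XLA for the disabled pass to take effect):

```julia
# Observation 4: disabling the dot-merger HLO pass — set before `using Reactant`.
# This had zero effect on compile time in my runs.
ENV["XLA_FLAGS"] = "--xla_disable_hlo_passes=dot-merger"

using Reactant, Lux, Random

model = Chain(Dense(1 => 16, relu), Dense(16 => 16, relu), Dense(16 => 1))
rng = Xoshiro(999)
dev = reactant_device()
ps, st = Lux.setup(rng, model) |> dev
x = dev(reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128)))

# Observation 3: the forward pass alone compiles in ~40 s for the 3-layer
# model, so the remaining ~10 min is spent generating the Enzyme backward pass.
t_forward = @elapsed fwd = Reactant.@compile Lux.apply(model, x, ps, st)
```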

I found this issue (Enzyme.jl #2283) that reports excessive compile times with a similar Lux + AutoEnzyme setup, though without isolating the hidden-to-hidden layer as the trigger.

Environment

  • Julia 1.12.5
  • NVIDIA A100-PCIE-40GB
  • Reactant v0.2.236, Lux v1.31.3, Enzyme v0.13.132, Optimisers v0.4.7

CC @wsmoses @avikpal


I narrowed this issue down with CodeGlass to the following function:

```
.julia/packages/EnzymeCore/RpjpI/src/rules.jl:348
_annotate_tt(TT0::Any)::Tuple{Any, Any}
Calls: 561_060   Total: 10.63 minutes   Allocated: 5.900 GB   Allocations: 91_279_627
```

This function is never even called for the 2-layer model; the entire call path below does not exist there.

Going deeper, the time is scattered across numerous Base module functions which, on first inspection, seem mostly related to map operations and resizing of their results.

This is the call path to that function. Keep in mind Julia optimizes a lot, so the path is incomplete due to inlining; if this is not enough, I can force Julia not to inline and get the full one:

```
{{ Your sample code  }}
.julia/packages/EnzymeCore/RpjpI/src/rules.jl:364
-> has_rrule_from_sig(::NamedTuple, ::#has_rrule_from_sig, TT::Any)::Bool

.julia/packages/EnzymeCore/RpjpI/src/rules.jl:364
2 call sites, differing in the "method_table" parameter:

InternalMethodTable: -> #has_rrule_from_sig#5(#has_rrule_from_sig#5::##has_rrule_from_sig#5, world::UInt64, method_table::InternalMethodTable, caller::Nothing, ::#has_rrule_from_sig, TT::Any)::Bool
Calls: 184_399   Total: 10.27 minutes   Allocated: 5.616 GB   Allocations: 83_537_838

OverlayMethodTable: -> #has_rrule_from_sig#5(#has_rrule_from_sig#5::##has_rrule_from_sig#5, world::UInt64, method_table::OverlayMethodTable, caller::Nothing, ::#has_rrule_from_sig, TT::Any)::Bool
Calls: 376_661   Total: 0.63 minutes   Allocated: 0.689 GB   Allocations: 7_741_789

.julia/packages/EnzymeCore/RpjpI/src/rules.jl:348
-> _annotate_tt(TT0::Any)::Tuple{Any, Any}
Calls: 561_060   Total: 10.63 minutes   Allocated: 5.900 GB   Allocations: 91_279_627
```

Let me know if you need more; CodeGlass has a lot more detailed information, e.g. which types the 91_279_627 allocations are and where they come from.

CodeGlass is a new instrumentation layer for Julia, built for companies that need this level of detail to solve production-related problems such as this one. So this was another great test for us :smiley:
