Reactant.jl + Enzyme: ~10 min compilation overhead triggered by hidden-to-hidden Dense layers

I’ve been investigating compilation times with Reactant.jl + Lux + AutoEnzyme() on an NVIDIA GPU and found a reproducible jump in time-to-first-training-step (TTFT) when a model includes hidden-to-hidden Dense layers.

Minimal reproducer

```julia
using Reactant, Lux, Random, Optimisers, Enzyme

model = Chain(Dense(1=>16, relu), Dense(16=>16, relu), Dense(16=>1))  # 3 layers — slow
# model = Chain(Dense(1=>16, relu), Dense(16=>1))                    # 2 layers — fast

rng = Xoshiro(999)
dev = reactant_device()

ps, st = Lux.setup(rng, model) |> dev
x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
x, y = dev((x, y))

tstate = Training.TrainState(model, ps, st, Adam(0.03f0))

@elapsed _, loss, _, tstate = Training.single_train_step!(
    AutoEnzyme(), MSELoss(), (x, y), tstate
)
```

Results (each run in a fresh Julia session)

| Model | Hidden→hidden layers | Parameters | First compile | Second call |
|---|---|---|---|---|
| 1→16→1 | 0 | 49 | 73s | 0.10s |
| 1→256→1 | 0 | 769 | 78s | – |
| 1→16→16→1 | 1 | 321 | 710s | 0.19s |
| 1→16→16→16→1 | 2 | 593 | 712s | 0.21s |
| 1→16→16→16→16→1 | 3 | 865 | 730s | – |
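For reference, the parameter counts above can be checked with Lux.parameterlength (Dense(in => out) carries out×in weights plus out biases):

```julia
using Lux

# Two of the configurations from the table above.
wide = Chain(Dense(1 => 256, relu), Dense(256 => 1))                      # 1→256→1
deep = Chain(Dense(1 => 16, relu), Dense(16 => 16, relu), Dense(16 => 1)) # 1→16→16→1

Lux.parameterlength(wide)  # (256*1 + 256) + (1*256 + 1) = 769
Lux.parameterlength(deep)  # (16 + 16) + (256 + 16) + (16 + 1) = 321
```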

Key observations

  1. The overhead is not proportional to parameter count. A 2-layer model with hidden size 256 (769 params) compiles in 78s, while a 3-layer model with hidden size 16 (321 params) takes 710s. The trigger is the presence of hidden-to-hidden layers, not the model size.

  2. Adding more hidden layers beyond the first costs almost nothing. Going from 3 to 5 layers only increases compile time from 710s to 730s. The ~10 min overhead is essentially a fixed cost once any hidden-to-hidden layer exists.

  3. The overhead is in Enzyme, not XLA. Compiling just the forward pass with Reactant.@compile Lux.apply(model, x, ps, st) takes only ~40s for the 3-layer model. The remaining ~10 min comes from Enzyme’s backward pass generation.

  4. It’s not dot-merger. The XLA log shows dot_merger.cc being triggered for ≥3-layer models. However, disabling it with XLA_FLAGS="--xla_disable_hlo_passes=dot-merger" has zero effect on compile time. The dot-merger log message just happens to appear near the end of compilation.

  5. Session-level caching exists. Within the same Julia session, compiling a second model with hidden-to-hidden layers is fast (~22-25s), even if the architecture is different. This suggests Enzyme (or the MLIR pipeline) caches intermediate results from the first compilation. However, a fresh session always pays the full ~10 min cost.

  6. XLA kernel cache doesn’t help. Enabling XLA’s persistent kernel cache with nvlink did not reduce the compilation time for the 3-layer case.
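The forward-pass measurement in point 3 comes from timing compilation of just the inference path, with no Enzyme involvement (using model, x, ps, st from the reproducer above):

```julia
# Forward pass only: Reactant traces Lux.apply and XLA compiles it; no AD.
fwd_time = @elapsed compiled_fwd = Reactant.@compile Lux.apply(model, x, ps, st)
# ~40s for the 3-layer model, vs ~710s for the full train step,
# so the bulk of the overhead sits in the backward-pass generation.
```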
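The dot-merger check in point 4 amounts to rerunning the reproducer with the pass disabled (repro.jl is a placeholder name for the script above):

```shell
# Disable XLA's dot-merger HLO pass for the whole Julia process.
XLA_FLAGS="--xla_disable_hlo_passes=dot-merger" julia --project repro.jl
```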
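The session-level caching in point 5 can be demonstrated by compiling a second, architecturally different model in the same session (a sketch continuing from the reproducer; the layer sizes here are arbitrary):

```julia
# Second model with a hidden→hidden layer, same Julia session.
model2 = Chain(Dense(1 => 32, tanh), Dense(32 => 32, tanh), Dense(32 => 1))
ps2, st2 = Lux.setup(rng, model2) |> dev
tstate2 = Training.TrainState(model2, ps2, st2, Adam(0.03f0))
@elapsed Training.single_train_step!(AutoEnzyme(), MSELoss(), (x, y), tstate2)
# ~22-25s here, vs the full ~10 min in a fresh session.
```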

I found a related issue (Enzyme.jl #2283) that reports excessive compile times with a similar Lux + AutoEnzyme setup, though without isolating hidden-to-hidden layers as the trigger.

Environment

  • Julia 1.12.5
  • NVIDIA A100-PCIE-40GB
  • Reactant v0.2.236, Lux v1.31.3, Enzyme v0.13.132, Optimisers v0.4.7

CC @wsmoses @avikpal