I’ve been investigating compilation times with Reactant.jl + Lux + AutoEnzyme() on an NVIDIA GPU and found a reproducible jump in TTFT (Time-To-First-Training) when a model includes hidden-to-hidden Dense layers.
## Minimal reproducer
```julia
using Reactant, Lux, Random, Optimisers, Enzyme

model = Chain(Dense(1 => 16, relu), Dense(16 => 16, relu), Dense(16 => 1)) # 3 layers — slow
# model = Chain(Dense(1 => 16, relu), Dense(16 => 1))                      # 2 layers — fast

rng = Xoshiro(999)
dev = reactant_device()
ps, st = Lux.setup(rng, model) |> dev

x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
x, y = dev((x, y))

tstate = Training.TrainState(model, ps, st, Adam(0.03f0))
@elapsed _, loss, _, tstate = Training.single_train_step!(
    AutoEnzyme(), MSELoss(), (x, y), tstate
)
```
## Results (each run in a fresh Julia session)

| Model | Hidden→Hidden layers | Parameters | First compile | Second call |
|---|---|---|---|---|
| 1→16→1 | 0 | 49 | 73s | 0.10s |
| 1→256→1 | 0 | 513 | 78s | — |
| 1→16→16→1 | 1 | 321 | 710s | 0.19s |
| 1→16→16→16→1 | 2 | 593 | 712s | 0.21s |
| 1→16→16→16→16→1 | 3 | 865 | 730s | — |
## Key observations
- **The overhead is not proportional to parameter count.** A 2-layer model with hidden size 256 (513 params) compiles in 78s, while a 3-layer model with hidden size 16 (321 params) takes 710s. The trigger is the presence of hidden-to-hidden layers, not the model size.
- **Adding more hidden layers beyond the first costs almost nothing.** Going from 3 to 5 layers only increases compile time from 710s to 730s. The ~10 min overhead is essentially a fixed cost once any hidden-to-hidden layer exists.
- **The overhead is in Enzyme, not XLA.** Compiling just the forward pass with `Reactant.@compile Lux.apply(model, x, ps, st)` takes only ~40s for the 3-layer model. The remaining ~10 min comes from Enzyme's backward-pass generation.
- **It's not dot-merger.** The XLA log shows `dot_merger.cc` being triggered for ≥3-layer models. However, disabling it with `XLA_FLAGS="--xla_disable_hlo_passes=dot-merger"` has zero effect on compile time. The `dot-merger` log message just happens to appear near the end of compilation.
- **Session-level caching exists.** Within the same Julia session, compiling a second model with hidden-to-hidden layers is fast (~22-25s), even if the architecture is different. This suggests Enzyme (or the MLIR pipeline) caches intermediate results from the first compilation. A fresh session, however, always pays the full ~10 min cost.
- **The XLA kernel cache doesn't help.** Enabling XLA's persistent kernel cache with nvlink did not reduce the compilation time for the 3-layer case.
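The forward-only timing above can be reproduced by compiling just the inference path, with no gradient work involved. This sketch assumes the `model`, `x`, `ps`, `st` from the reproducer are already set up on the Reactant device:

```julia
# Compile only the forward pass (no Enzyme involved).
# For the 3-layer model this finished in ~40s in my runs,
# versus ~710s for the full training step.
t_fwd = @elapsed compiled_apply = Reactant.@compile Lux.apply(model, x, ps, st)
```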
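The session-level caching is easy to check from the same REPL, continuing from the reproducer; the second architecture below is an arbitrary choice of mine, picked only to differ from the first:

```julia
# A different hidden-to-hidden architecture, compiled in the same session:
# this took ~22-25s in my runs instead of the ~10 min a fresh session pays.
model2 = Chain(Dense(1 => 32, relu), Dense(32 => 32, relu), Dense(32 => 1))
ps2, st2 = Lux.setup(rng, model2) |> dev
tstate2 = Training.TrainState(model2, ps2, st2, Adam(0.03f0))
t2 = @elapsed Training.single_train_step!(AutoEnzyme(), MSELoss(), (x, y), tstate2)
```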
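The dot-merger test was done by disabling the pass for the whole session before launching Julia; this changed nothing about the ~10 min compile time:

```shell
# Disable XLA's dot-merger HLO pass, then start Julia as usual.
export XLA_FLAGS="--xla_disable_hlo_passes=dot-merger"
```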
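For the kernel-cache attempt, the flags below are what I understand XLA's persistent GPU kernel cache to use; the exact names should be double-checked against the XLA version Reactant ships:

```shell
# Persistent kernel cache (relies on nvlink-based module linking);
# had no effect on the 3-layer compile time in my runs.
export XLA_FLAGS="--xla_gpu_kernel_cache_file=/tmp/xla_kernel_cache --xla_gpu_enable_llvm_module_compilation_parallelism=true"
```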
I found a related report (Enzyme.jl #2283) of excessive compile times with a similar Lux + AutoEnzyme setup, though it does not isolate hidden-to-hidden layers as the trigger.
## Environment
- Julia 1.12.5
- NVIDIA A100-PCIE-40GB
- Reactant v0.2.236, Lux v1.31.3, Enzyme v0.13.132, Optimisers v0.4.7