I’ve been trying to train a small descriptor CNN (~1M params, 64×64
grayscale patch input → 128-dim L2-normalized embedding) with a
Supervised Contrastive (SupCon) loss in pure Julia / Lux on a single
B200 GPU, and have hit what looks like a fundamental compile-time
wall in every Julia code path I’ve tried. I’d appreciate a sanity
check on whether this is expected for the stack, or whether I’m
holding something wrong.
Versions: Julia 1.12.5, Lux 1.31.3, Reactant 0.2.x, Enzyme 0.13.134,
Zygote 0.7, CUDA 5.x, cuDNN 1.4.7 (artifact 9.20). All running under
SLURM sbatch on a fresh process (no warm REPL).
Workload:
- Data: 3.6M training patches (66 GB on disk as JLD2), 56k blob-identity classes, plus 709k confuser patches.
- Batch composition (tuned for HBM3e): 384 anchor blobs × 10 views + 2048 confusers = 5888 patches per step. Patch size 64×64 Float32.
- Model: a flat Chain of ~25 layers (Conv→BatchNorm pattern × 6 with 2× downsamples + head).
- Loss: vectorized SupCon — S = embeddings' * embeddings, build (N,N) Bool masks for self/positive, log-sum-exp per row, average the log-prob over positives. No scalar indexing, no mutation, no Python-style loops.
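For concreteness, here is the shape of that loss as a minimal sketch under the same constraints (broadcast-only, no scalar indexing, no mutation). Names, the temperature default, and the zero-positive guard are illustrative, not the actual supcon_loss_mat:

```julia
using Statistics

# Sketch of a vectorized SupCon loss. Z is a (d, N) matrix of
# L2-normalized embeddings, labels a length-N vector of class ids,
# τ the temperature. All illustrative, not the author's code.
function supcon_sketch(Z::AbstractMatrix, labels::AbstractVector; τ = 0.1f0)
    N    = size(Z, 2)
    S    = (Z' * Z) ./ τ                       # (N, N) similarity matrix
    self = (1:N) .== (1:N)'                    # Bool mask of the diagonal
    pos  = (labels .== labels') .& .!self      # positives: same label, not self
    m    = maximum(S .- Inf32 .* self; dims = 2)  # row max over non-self entries
    lse  = m .+ log.(sum(exp.(S .- m) .* .!self; dims = 2))
    logp = S .- lse                            # row-wise log-softmax over non-self
    npos = max.(sum(pos; dims = 2), 1)         # guard rows with no positives
    -mean(sum(logp .* pos; dims = 2) ./ npos)  # mean log-prob over positives
end
```

Since everything here is a matmul plus elementwise broadcasts, the same code should run unchanged over CuArrays on the forward pass; the question in the post is about the cost of deriving the reverse-mode code, not about the forward formulation.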
The wall — three independent paths, all hung past the first step:
1. Reactant XLA + AutoEnzyme via single_train_step! (the path Lux's official ResNet20 tutorial uses): the driver entered single_train_step!, Reactant's XLA service initialized cleanly on the B200, BFC pre-allocated 143 GB on device 0, cuDNN 9.14 loaded. Then 45+ minutes of 100% single-thread CPU, no GPU memory ever allocated, no first-step output. A backtrace via SIGUSR1 confirmed it was inside single_train_step_impl_with_allocator_cache! — pure Reactant tracing/lowering of the SupCon backward graph.
2. Eager CUDA + Zygote, with a manual Zygote.withgradient + Optimisers.update! loop (the maintainer-suggested workaround for single_train_step! overhead in Lux #1704): same wall. 1+ hour of 100% CPU, RSS plateaued at 92 GB after data load, no GPU memory, no first step. The backtrace showed pure type inference / LLVM codegen, with no GC yields between SIGUSR1 and an observation 50 minutes later.
3. Reactant @compile model(x, ps, st) forward warm-up + single_train_step!(AutoEnzyme()) (matches the tutorial pattern more precisely): the @compile warm-up actually finished (~17 min). Then single_train_step!'s first call started compiling the backward pass and again sat at 100% CPU with no output for the next 40+ minutes. Caching the forward HLO doesn't shortcut Enzyme's backward derivation.
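For reference, path 2 above has roughly this shape — a toy stand-in model and batch for a minimal repro, not the real 25-layer Chain or the 5888-patch step, with supcon_loss_mat standing in for the loss described earlier:

```julia
using Lux, Optimisers, Random, Zygote

# Toy stand-in for the eager CUDA + Zygote path: a tiny Conv→BN→head
# chain and a 32-patch batch, not the real network or batch size.
rng   = Random.default_rng()
model = Chain(Conv((3, 3), 1 => 32, pad = 1), BatchNorm(32, relu),
              GlobalMeanPool(), FlattenLayer(), Dense(32 => 128))
ps, st    = Lux.setup(rng, model)
opt_state = Optimisers.setup(Adam(3f-4), ps)

x      = rand(rng, Float32, 64, 64, 1, 32)  # toy batch of 32 patches
labels = rand(rng, 1:4, 32)                 # toy class ids

# One manual training step: Zygote.withgradient + Optimisers.update!.
loss, grads = Zygote.withgradient(ps) do p
    z, _ = model(x, p, st)
    zn   = z ./ sqrt.(sum(abs2, z; dims = 1))  # L2-normalize columns
    supcon_loss_mat(zn, labels)                # the loss from the post
end
opt_state, ps = Optimisers.update!(opt_state, ps, grads[1])
```

The first call to withgradient is where both type inference and the reverse-mode code generation for the full model + loss closure happen, which is why the wall appears before the first step rather than during it.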
Things I've ruled out:
- It's not a hang — ps shows 99–100% CPU steadily, RSS stable; the process is making progress, just very slowly.
- Not a data-load issue — "Loading training data…" and "Model: 1057696 params, batch=5888, …" both print before the wall.
- Not the loss — supcon_loss_mat is a fully vectorized broadcast over (N,N) matrices, no scalar indexing, AD-friendly.
- Not BatchNorm specifically (the tutorial uses BN too).
- Not nested Chains specifically (I flattened them per the ResNet20 tutorial pattern — same wall).
- Not Lux.Training specifically (a manual Zygote.gradient + Optimisers.update! loop hits the same wall).
What seems to be the common factor: the cost of generating reverse-mode code for (::Function, ::TrainState{Model, Params, State, Opt, OptState}) (or the equivalent closure type for the manual loop), where the captured model is a Chain of ~25 layers and Params/State are deeply nested NamedTuples. Lux #1484 documents a 200k-param parallel-CNN model taking 11+ min via single_train_step! with Reactant + Enzyme; we're at 1M params with a similarity-matrix loss on top, so 60+ min isn't even surprising in that context. Lux #1704 explicitly acknowledges single_train_step!'s shape-specialization overhead.
Searches I've done that didn't help:
- No Julia/Flux/Lux implementation of SupCon, NT-Xent, or InfoNCE exists in any FluxML or LuxDL repo, model zoo, or tutorial. Flux only has the pairwise siamese_contrastive_loss. The contrastive-learning workload appears genuinely unimplemented in pure Julia.
- PackageCompiler + CUDA + Lux has documented historical issues (Flux #1337 — sysimage GPU access violations; a 2020 Discourse thread on Zygote/PackageCompiler incompatibility). The community has broadly moved to PrecompileTools, but those caches don't cover our specific call signature.
- Reactant #1990 flags Lux + Conv layers + Julia 1.12 as fragile (a different error mode than ours, but the same combo).
The questions I'd love community input on:
1. Is anyone actually training a contrastive / large-batch metric-learning model in Julia (Lux or Flux) end-to-end on GPU? If yes — what stack, what model size, what batch size, and how long does the first-step compile take?
2. Are there model-architecture or loss-formulation idioms that specifically avoid the compile-time wall on similarity-matrix losses, beyond "smaller batch, fewer layers"?
3. Is there a way to get Reactant to cache the compiled training step across processes (a sysimage equivalent for HLO), so we pay the wall once per machine rather than once per sbatch? @compile artifacts seem to be in-process only.
4. Is the realistic answer "use PythonCall + PyTorch for this specific workload"? It feels like it given the search evidence, but I'd love to hear if anyone has gotten contrastive training working in pure Julia and what it took.
Setup, model code, loss code, and full backtraces available on
request — happy to share a minimal repro if the symptoms sound
resolvable.
Thanks in advance.