XLAServer.jl: serve more models per GPU with Reactant.jl and XLA (gauging interest)

I have been building a Julia inference server and want to find out whether others would find it useful before investing more in polishing it for general use. Feedback, criticism, and “we already have this, it is called X” are all welcome.

The short version

XLAServer.jl is a Julia inference server built to get far more models onto a GPU than its memory would normally hold. It is Julia-first and fully extensible in Julia: the compiled model is the fast core, and everything around it, from pre- and post-processing to the serving logic itself, is ordinary Julia you can read and shape. Models are compiled ahead of time through Reactant.jl’s PJRT bindings, and the server speaks the KServe V2 inference API over gRPC, so existing Triton and KServe clients connect to it without changes.

It grew out of a concrete migration. We had been serving PyTorch computer-vision models on NVIDIA Triton and decided to move off it. Exporting those models to StableHLO and serving them through XLA gives us whole-program compiler optimization and, more importantly, the ability to keep far more models on each GPU. We kept KServe V2 as the wire protocol so our existing clients did not have to change.

What pushed us was stagnation on both sides. TorchScript, which our deployment leaned on, was deprecated years ago. torch.compile is a real step forward for LLMs, but in our experience it is essentially useless on the static vision models we run. On the serving side, native support for the newer torch.export path took years to arrive. Compiling to StableHLO and serving through XLA steps around all of that: models from Lux.jl, PyTorch, and JAX become a portable compiled artifact, complete with whole-program optimizations such as kernel fusion.

Key ideas it is built around

Fit more models than GPU memory holds. GPU memory is quickly becoming a dominant cost in inference infrastructure, so serving more models per card can directly lower cost per inference. Because only one model executes at a time, the GPU does not need every model’s weights resident at once. At startup the server materializes every model’s weights in host RAM. When a request arrives, it transfers that model’s weights to the GPU on demand and keeps them resident for reuse, evicting cold models to stay within a configurable GPU memory budget. Because the weights are already in RAM, an on-demand load is a single host-to-device transfer, costing on the order of one inference rather than a reload from disk. The practical effect is that one GPU can serve far more models than its memory would normally hold, paying only a small transfer cost when a cold model is first called, so the same catalog of models needs fewer GPUs.
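The load-and-evict cycle can be sketched in a few lines of Julia. This is a minimal sketch under stated assumptions, not XLAServer's implementation: `WeightCache` and `acquire!` are hypothetical names, plain `Vector{Float32}` copies stand in for host-to-device transfers, and least-recently-used is assumed as the meaning of "cold".

```julia
# Minimal sketch of an on-demand weight cache under a device-memory budget.
# Hypothetical names; `copy` stands in for a real host-to-device transfer.
mutable struct WeightCache
    host::Dict{String,Vector{Float32}}    # all weights, materialized in host RAM at startup
    device::Dict{String,Vector{Float32}}  # weights currently resident on the "GPU"
    last_used::Dict{String,Int}           # logical clock per resident model (LRU order)
    budget_bytes::Int
    used_bytes::Int
    clock::Int
    transfers::Int                        # host-to-device copies performed so far
end

WeightCache(host, budget_bytes) = WeightCache(host, Dict{String,Vector{Float32}}(),
                                              Dict{String,Int}(), budget_bytes, 0, 0, 0)

function acquire!(c::WeightCache, name::String)
    c.clock += 1
    if haskey(c.device, name)             # warm: already resident, no transfer needed
        c.last_used[name] = c.clock
        return c.device[name]
    end
    w = c.host[name]
    need = sizeof(w)
    need <= c.budget_bytes || error("model $name is larger than the whole budget")
    while c.used_bytes + need > c.budget_bytes   # evict least-recently-used models until it fits
        victim = argmin(k -> c.last_used[k], collect(keys(c.device)))
        c.used_bytes -= sizeof(c.device[victim])
        delete!(c.device, victim)
        delete!(c.last_used, victim)
    end
    c.device[name] = copy(w)              # the single host-to-device transfer
    c.used_bytes += need
    c.last_used[name] = c.clock
    c.transfers += 1
    return c.device[name]
end

# Three 1000-byte models (250 Float32 each) under a 2000-byte budget.
host = Dict(m => zeros(Float32, 250) for m in ("a", "b", "c"))
cache = WeightCache(host, 2000)
acquire!(cache, "a"); acquire!(cache, "b"); acquire!(cache, "a")
acquire!(cache, "c")                      # evicts "b", the least recently used
```

The third call for "a" is a warm hit, so only three transfers happen in total. `argmin` with a function argument needs Julia 1.7 or later.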

Compiler-grade speed on static graphs, extended with dynamic Julia. Each model is compiled once into an optimized executable. XLA does whole-program optimization, fuses kernels, and plans layout, which is where the Julia ML stack and Reactant give a small team performance they would otherwise have to build by hand. The compiled graph is the fast static core, and you extend it with dynamic logic in plain Julia: data-dependent control flow, custom pre- and post-processing, and anything else that does not belong in a static graph run as ordinary Julia wrapped around the model. More on how that looks below.

A cost-aware scheduler ties it together. Requests land on per-model queues, and a single dispatch loop picks the next model by a deficit-weighted fair-share policy, honors optional per-model latency budgets, and coalesces concurrent requests for the same model into one batched execution, which amortizes fixed per-launch overhead.
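To make the dispatch rule concrete, here is a simplified deficit-round-robin-style selector with coalescing. It is a sketch of the general technique, not XLAServer's scheduler: `ModelQueue`, `next_batch!`, the per-execution `cost`, and the `quantum` parameter are all hypothetical, and latency budgets are left out.

```julia
# Minimal sketch of deficit-weighted fair-share dispatch with request coalescing.
# Hypothetical types and names; latency budgets are left out.
mutable struct ModelQueue
    name::String
    weight::Float64        # fair-share weight
    cost::Float64          # estimated cost of one batched execution
    deficit::Float64       # credit accumulated but not yet spent
    pending::Vector{Int}   # ids of requests waiting for this model
end

function next_batch!(queues::Vector{ModelQueue}; quantum=1.0, max_batch=6)
    ready = filter(q -> !isempty(q.pending), queues)
    isempty(ready) && return nothing
    (quantum > 0 && all(q -> q.weight > 0, ready)) ||
        throw(ArgumentError("quantum and weights must be positive"))
    # Start a new round of credit only when no waiting model can afford its next execution.
    while all(q -> q.deficit < q.cost, ready)
        for q in ready
            q.deficit += quantum * q.weight
        end
    end
    # Serve the affordable model holding the most credit, and charge it the execution cost.
    pick = argmax(q -> q.deficit, filter(q -> q.deficit >= q.cost, ready))
    pick.deficit -= pick.cost
    # Coalesce up to max_batch queued requests into one execution.
    n = min(max_batch, length(pick.pending))
    batch = pick.pending[1:n]
    deleteat!(pick.pending, 1:n)
    isempty(pick.pending) && (pick.deficit = 0.0)   # idle queues do not bank credit, as in classic DRR
    return pick.name => batch
end

a = ModelQueue("a", 2.0, 1.0, 0.0, collect(1:6))
b = ModelQueue("b", 1.0, 1.0, 0.0, collect(101:106))
order = [first(next_batch!([a, b]; max_batch=1)) for _ in 1:6]   # ["a", "a", "b", "a", "a", "b"]

c = ModelQueue("c", 1.0, 1.0, 0.0, collect(1:8))
first_batch = next_batch!([c])   # "c" => [1, 2, 3, 4, 5, 6]: six requests, one execution
```

With weights 2 and 1 and equal costs, model "a" gets two executions for every one of "b". A latency budget could be layered on top by serving any queue whose oldest request is about to miss its budget before applying the fair-share rule; that extension is not shown.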

Julia first, with a Triton-style model repository

You point the server at a model repository, a directory of model bundles, the same way you point Triton at one. Each bundle is a folder with the compiled model, its weights, a manifest describing the inputs and outputs, and an optional model.jl. The server discovers everything in the repository at startup and exposes the set over the KServe RepositoryIndex call.
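Repository discovery amounts to a directory scan. The sketch below assumes that any subdirectory containing a `manifest.yaml` counts as a bundle and records whether it ships a `model.jl`; `discover_models` is a hypothetical helper, not the server's API.

```julia
# Minimal sketch of model-repository discovery: every subdirectory holding a
# manifest.yaml is treated as a bundle. Hypothetical function, not XLAServer's API.
function discover_models(repo::AbstractString)
    bundles = NamedTuple{(:name, :has_hooks),Tuple{String,Bool}}[]
    for entry in readdir(repo)                       # readdir returns sorted names
        dir = joinpath(repo, entry)
        isdir(dir) && isfile(joinpath(dir, "manifest.yaml")) || continue
        push!(bundles, (name = entry, has_hooks = isfile(joinpath(dir, "model.jl"))))
    end
    return bundles
end

repo = mktempdir()
mkpath(joinpath(repo, "severity_grader"))
touch(joinpath(repo, "severity_grader", "manifest.yaml"))
touch(joinpath(repo, "severity_grader", "model.jl"))
mkpath(joinpath(repo, "notes"))                      # no manifest, so it is skipped
bundles = discover_models(repo)
```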

The design is Julia-first, and the clearest place that shows is pre- and post-processing. Those hooks are plain Julia in the bundle’s model.jl, registered with register_model. This earns Julia its place on the hot path: logic that is awkward to express as a static graph, such as data-dependent loops and early exits, is a few lines of ordinary Julia running right next to the model. When we migrated our PyTorch models, the parts that did not export cleanly to a static graph became Julia post-processing instead of being contorted to fit.
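One way such hooks could wrap a compiled executable is a small registry keyed by model name. This is an illustrative sketch only: `Hooks`, `HOOKS`, `register_hooks!`, and `run_request` are hypothetical, an ordinary Julia function stands in for the compiled executable, and the real register_model may work differently.

```julia
# Minimal sketch of a per-model hook registry. `Hooks`, `HOOKS`, `register_hooks!`
# and `run_request` are hypothetical; XLAServer's register_model may work differently.
struct Hooks
    preprocess::Function
    postprocess::Function
end

const HOOKS = Dict{String,Hooks}()

register_hooks!(name::String; preprocess=identity, postprocess=identity) =
    (HOOKS[name] = Hooks(preprocess, postprocess))

# One request: Julia pre-processing, the compiled executable, Julia post-processing.
function run_request(name::String, executable, inputs)
    h = get(HOOKS, name, Hooks(identity, identity))   # models without model.jl pass through
    return h.postprocess(executable(h.preprocess(inputs)))
end

# A toy model: scale bytes to [0, 1] before, take the arg-max after.
register_hooks!("toy"; preprocess = x -> Float32.(x) ./ 255f0,
                       postprocess = y -> findmax(y)[2])
run_request("toy", x -> 2f0 .* x, UInt8[10, 200, 30])   # 2, the index of the largest output
```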

A bundle looks like this:

```text
severity_grader/
  manifest.yaml          # input/output spec and compiled batch sizes
  model.b1.mlir          # compiled StableHLO, one module per batch size
  model.b3.mlir
  model.b6.mlir
  weights.safetensors    # shared across batch sizes
  model.jl               # Julia post-processing (optional)
```
```yaml
# manifest.yaml
format_version: "2.0"
name: "severity_grader"
executable_inputs:
  - name: "INPUT__0"
    dtype: "u8"
    shape: "whn"            # w width, h height; n is the batch axis (Julia column-major)
    dims: { w: 224, h: 224 }
executable_outputs:        # the model's raw outputs
  - name: "OUTPUT__0"
    dtype: "f32"
    shape: "zn"             # z logits, n batch
    dims: { z: 9 }
client_outputs:            # what the caller sees after model.jl post-processing
  - name: "grade"
    dtype: "i64"
    shape: "yn"             # y is the ordinal class, one per sample
    dims: { y: 1 }
batching:
  compiled_batch_sizes: [1, 3, 6]
```
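The manifest writes shapes in Julia's column-major order, while KServe V2 tensor data is flattened in row-major order. A column-major array with dims (w, h, n) has the same memory layout as a row-major array with shape [n, h, w], so one plausible mapping is simply to reverse the axes on the wire. The sketch below assumes exactly that; `wire_metadata` and the dtype table are hypothetical and cover only the three dtype codes that appear in this manifest.

```julia
# Sketch: turn a manifest tensor spec into KServe V2 metadata. Assumes the wire shape
# is the reverse of the Julia (column-major) shape; hypothetical helper and table.
const KSERVE_DTYPE = Dict("u8" => "UINT8", "f32" => "FP32", "i64" => "INT64")

function wire_metadata(name::String, dtype::String, shape::String,
                       dims::Dict{Char,Int}, batch::Int)
    julia_dims = [c == 'n' ? batch : dims[c] for c in shape]   # manifest order, column-major
    return (name = name, datatype = KSERVE_DTYPE[dtype], shape = reverse(julia_dims))
end

m = wire_metadata("INPUT__0", "u8", "whn", Dict('w' => 224, 'h' => 224), 3)
# (name = "INPUT__0", datatype = "UINT8", shape = [3, 224, 224])
```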
```julia
# model.jl
using XLAServer: NamedTensor, register_model

# The model emits CORAL ordinal-regression logits. The grade is the number of consecutive
# leading logits above zero, stopping at the first that is not. That early break is awkward to
# express as a static XLA graph, but it is a plain loop in Julia, running right next to the model.
function postprocess(out::Vector{NamedTensor})
    logits = out[1].data::Array{Float32}        # OUTPUT__0 logits, (z, batch) in Julia column-major
    Z, B = size(logits)
    grade = zeros(Int64, 1, B)
    @inbounds for b in 1:B
        n = 0
        for i in 1:Z
            logits[i, b] > 0f0 ? (n = i) : break
        end
        grade[1, b] = n
    end
    return NamedTensor[NamedTensor("grade", grade)]
end

register_model("severity_grader"; postprocess=postprocess)
```
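Because the bundle ships one compiled module per batch size (1, 3, and 6 here), coalesced requests have to be fitted onto those sizes. The sketch below picks the smallest compiled size that holds the requests, splits when there are more than the largest size, and zero-pads the batch axis, which is the last dimension in the manifest's "whn" layout. `plan_batches` and `pad_batch` are hypothetical helpers, and the real server may choose differently.

```julia
# Sketch: fit n coalesced requests onto the compiled batch sizes from the manifest.
# Hypothetical helpers; the real server may choose differently.
function plan_batches(n::Int, compiled::Vector{Int})
    sizes = sort(compiled)
    plan = Tuple{Int,Int}[]                          # (compiled size, real samples)
    while n > 0
        take = min(n, last(sizes))
        push!(plan, (sizes[findfirst(>=(take), sizes)], take))
        n -= take
    end
    return plan
end

# Zero-pad the batch axis (the last dim, as in the manifest's "whn") up to size b.
function pad_batch(x::Array{T,N}, b::Int) where {T,N}
    n = size(x, N)
    n <= b || throw(ArgumentError("batch $n exceeds compiled size $b"))
    out = zeros(T, Base.front(size(x))..., b)
    selectdim(out, N, 1:n) .= x
    return out
end

plan_batches(8, [1, 3, 6])                # [(6, 6), (3, 2)]
x = rand(UInt8, 224, 224, 2)              # two real samples, laid out w × h × n
y = pad_batch(x, 3)                       # size (224, 224, 3); padded outputs are dropped afterwards
```

Padding trades wasted slots for fewer launches: four requests run as one batch-6 execution rather than a batch-3 plus a batch-1. Which is cheaper depends on how large the fixed per-launch overhead is relative to per-sample compute.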

Multiple GPUs and shared memory

Each worker drives one GPU and serves the full KServe API on its own, so a single-GPU deployment needs nothing else. To scale out, you run one worker per GPU and put a gateway in front. The gateway automatically detects which worker is serving which model, using the RepositoryIndex each worker exposes, and presents a single API endpoint to clients. Every request is routed to the worker that holds the requested model with no manual configuration, and the request is forwarded unchanged, so clients see one server rather than a fleet.
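The gateway's lookup can be pictured as a table built from every worker's RepositoryIndex response. In this sketch the worker addresses and the model names other than severity_grader are hypothetical, `build_routes` and `route` are illustrative helpers, and each model is assumed to live on exactly one worker.

```julia
# Sketch of a gateway routing table built from each worker's RepositoryIndex response.
# Worker addresses and most model names are hypothetical; one worker per model is assumed.
function build_routes(index::Dict{String,Vector{String}})
    routes = Dict{String,String}()
    for (worker, models) in index, model in models
        haskey(routes, model) && error("model $model is listed by both $(routes[model]) and $worker")
        routes[model] = worker
    end
    return routes
end

function route(routes::Dict{String,String}, model::String)
    haskey(routes, model) || error("no worker serves model $model")
    return routes[model]       # the request is forwarded to this worker unchanged
end

index = Dict("gpu0:8001" => ["severity_grader", "defect_detector"],
             "gpu1:8001" => ["part_classifier"])
routes = build_routes(index)
route(routes, "part_classifier")   # "gpu1:8001"
```

A real gateway would refresh this table as workers start, stop, or load new bundles; that refresh loop is not shown.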

For large tensors, the server also implements NVIDIA Triton’s system shared-memory extension. When the client runs on the same host as the worker, it can register a shared-memory region and reference it from input and output tensors, so the tensor data never travels over the socket. Existing Triton shared-memory clients work unchanged.
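The core idea behind shared memory is that the request carries only a region, an offset, and a byte size, while the bytes themselves sit in memory both processes can map. The sketch below uses a file-backed `Mmap` mapping as a stand-in for a POSIX shared-memory segment and runs both sides in one process; `write_tensor!` and `read_tensor` are hypothetical helpers, not Triton's or XLAServer's API.

```julia
using Mmap

# Sketch: tensor bytes live in a mapped region and are referenced by (offset, byte size).
# A file-backed mapping stands in for a POSIX shared-memory segment; hypothetical helpers.
function write_tensor!(region::Vector{UInt8}, offset::Int, x::Array{Float32})
    bytes = reinterpret(UInt8, vec(x))
    copyto!(region, offset + 1, bytes, 1, length(bytes))
    return length(bytes)                              # byte size the request would reference
end

function read_tensor(region::Vector{UInt8}, offset::Int, byte_size::Int, dims::Tuple)
    raw = view(region, offset+1:offset+byte_size)
    return reshape(reinterpret(Float32, raw), dims)   # zero-copy view of the region
end

x = Float32[1 2; 3 4]
y, nbytes = open(tempname(), "w+") do io
    region = Mmap.mmap(io, Vector{UInt8}, 256)        # grows the file to 256 bytes
    n = write_tensor!(region, 64, x)                  # client side: write into the region
    (copy(read_tensor(region, 64, n, size(x))), n)    # server side: read it back, copied before close
end
```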

Status

There is a complete path through every layer: load a StableHLO bundle, compile through Reactant/PJRT, schedule, serve over KServe V2 gRPC, return a result. The on-demand weight cache and the scheduler are implemented and tested. Conversion tooling turns a Lux.jl model or a PyTorch nn.Module (via torch.export and torchax) into a server-loadable bundle.

We are also working on Revise.jl support, so the server and the Julia pre/post-processing in model.jl can be edited and hot-reloaded without a full restart while developing.

It is deliberately narrow. It is XLA-centric and static-graph-centric, and it is not trying to be an LLM serving stack, a multi-framework server like Triton, or a hyperscale system. If you need those, they exist and do their jobs well.

What I am asking

  • Is there appetite in the Julia community for an XLA-based, KServe-compatible inference server?
  • If you serve models in production, would simple on-demand weight loading change what fits on your hardware? What does your model mix and traffic pattern look like?