Significantly Higher VRAM Usage and Slower Training on Flux Compared to PyTorch

I’ve been working on a project where I’m converting Timm models to Flux along with their pre-trained weights. Everything is going well so far, but I noticed that Flux has significantly higher VRAM consumption and takes noticeably longer to train than PyTorch (roughly 1.6x per epoch in the runs below).

To confirm this, I wrote simple PyTorch and Flux scripts that each train a ResNet-34 model on 8,000 samples from the AID dataset. To make sure the difference wasn’t down to my own mistakes, I used the ResNet implementation from Metalhead on the Flux side.

Here are the respective scripts:

PyTorch

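import timm
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# AID is a Dataset class defined elsewhere (not shown).


class TimmClassifier(pl.LightningModule):

    def __init__(self, model: str, num_classes: int, learning_rate: float = 1e-3, pretrained: bool = True):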
        super().__init__()
        self.save_hyperparameters()
        self.model = timm.create_model(
            model, pretrained=pretrained, num_classes=num_classes
        )
        self.loss_fn = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def _shared_step(self, batch, stage: str):
        images, labels = batch
        logits = self(images)
        loss = self.loss_fn(logits, labels)
        preds = logits.argmax(dim=1)
        acc = (preds == labels).float().mean()
        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", acc, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        return [optimizer], [scheduler]


def train_resnet34(
    dataset,
    num_classes: int,
    val_split: float = 0.2,
    batch_size: int = 16,
    max_epochs: int = 20,
    learning_rate: float = 1e-4,
    pretrained: bool = True,
    num_workers: int = 4,
    accelerator: str = "auto",
):
    # --- Split dataset ---
    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size
    train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True,
    )

    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True,
    )

    # --- Model ---
    model = TimmClassifier(
        model="resnet34",
        num_classes=num_classes,
        learning_rate=learning_rate,
        pretrained=pretrained,
    )

    # --- Callbacks ---
    checkpoint_cb = ModelCheckpoint(
        monitor="val_loss", mode="min", save_top_k=1, filename="best"
    )
    early_stop_cb = EarlyStopping(monitor="val_loss", patience=5, mode="min")

    # --- Trainer ---
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        callbacks=[checkpoint_cb, early_stop_cb],
        log_every_n_steps=10,
        precision="32-true", 
    )
    trainer.fit(model, train_loader, val_loader)


def main():
    dataset = AID(path="../../Data/AID")
    train_resnet34(dataset, num_classes=len(dataset.classes), pretrained=False, max_epochs=5)


if __name__ == "__main__":
    main()

Flux

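using Flux, Metalhead, Tsunami
using CUDA, cuDNN  # needed for Flux to run on an NVIDIA GPU
# ImageDataset is defined elsewhere (not shown).

struct FineTuneEncoder{E} <: FluxModule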
    encoder::E
end

function FineTuneEncoder(num_classes::Int)
    encoder = Metalhead.ResNet(34, pretrain=false, nclasses=num_classes)
    return FineTuneEncoder(encoder)
end

Flux.Optimisers.trainable(x::FineTuneEncoder) = (; x.encoder)


(model::FineTuneEncoder)(x) = model.encoder(x)

function loss_and_accuracy(model::FineTuneEncoder, batch)
    x, y = batch
    ŷ = model(x)
    return Flux.logitcrossentropy(ŷ, y), Tsunami.accuracy(ŷ, y)
end

function Tsunami.train_step(model::FineTuneEncoder, trainer, batch)
    loss, acc = loss_and_accuracy(model, batch)
    Tsunami.log(trainer, "loss/train", loss, prog_bar=true)
    Tsunami.log(trainer, "accuracy/train", acc, prog_bar=true)
    return loss
end

function Tsunami.val_step(model::FineTuneEncoder, trainer, batch)
    loss, acc = loss_and_accuracy(model, batch)
    Tsunami.log(trainer, "loss/val", loss)
    Tsunami.log(trainer, "accuracy/val", acc)
end

function Tsunami.configure_optimisers(m::FineTuneEncoder, trainer)
    opt_rule = Flux.Optimisers.Adam(1e-4)
    opt_state = Flux.Optimisers.setup(opt_rule, m)
    return opt_state
end

function run_train()
    # Prepare the dataset
    dataset = ImageDataset("../../Data/AID", imsize=(224, 224))
    train_data, test_data = Flux.MLUtils.splitobs(dataset, at=0.8, shuffle=true)
    train_loader = Flux.DataLoader(train_data, batchsize=16, collate=true, parallel=true, shuffle=true)
    test_loader = Flux.DataLoader(test_data, batchsize=16, collate=true, parallel=true)

    # Create and train the model
    model = FineTuneEncoder(length(dataset.labels))
    trainer = Trainer(max_epochs=5)
    Tsunami.fit!(model, trainer, train_loader, test_loader)
end

These are the results for ResNet-34 on a machine with 64 GB of RAM and an NVIDIA RTX 5090 with 32 GB of VRAM. VRAM usage is reported by nvtop.

Flux: 19.1 GB VRAM - 11 seconds / epoch

PyTorch: 2.6 GB VRAM - 7 seconds / epoch

I also tried another experiment with a ViT base model, which produced the following.

Flux: 16.2 GB VRAM - 38 seconds / epoch

PyTorch: 5.1 GB VRAM - 23 seconds / epoch
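To put the two sets of numbers side by side, here is a quick sketch that turns them into Flux/PyTorch ratios. The figures are copied from the results above; the dictionary layout and the function name are just mine for illustration.

```python
# Figures copied from the results above: (VRAM in GB, seconds per epoch).
results = {
    "ResNet-34": {"Flux": (19.1, 11), "PyTorch": (2.6, 7)},
    "ViT base": {"Flux": (16.2, 38), "PyTorch": (5.1, 23)},
}


def flux_over_pytorch(model: str) -> tuple[float, float]:
    """Return (VRAM ratio, epoch-time ratio) of Flux relative to PyTorch."""
    flux_vram, flux_secs = results[model]["Flux"]
    torch_vram, torch_secs = results[model]["PyTorch"]
    return flux_vram / torch_vram, flux_secs / torch_secs


for name in results:
    vram_ratio, time_ratio = flux_over_pytorch(name)
    print(f"{name}: {vram_ratio:.1f}x VRAM, {time_ratio:.2f}x time per epoch")
# ResNet-34: 7.3x VRAM, 1.57x time per epoch
# ViT base: 3.2x VRAM, 1.65x time per epoch
```

So the slowdown is similar for both models, while the VRAM gap is much larger for ResNet-34 than for ViT base.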

To determine if Julia was perhaps just allocating more buffer space without actually using it, I tried to increase the batch size until I got an OOM. PyTorch was able to handle a batch size of 128 with ViT base, while Flux could only go up to 24 after throwing OOM errors at 32.
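The probing itself was nothing fancy. A minimal sketch of the idea, assuming a `try_step` callable that runs one training step and raises on OOM: `OutOfMemory`, `largest_fitting_batch`, and the toy memory model are all hypothetical stand-ins, and the capacity and per-sample numbers are made up, not measurements.

```python
from typing import Callable, Iterable, Optional


class OutOfMemory(RuntimeError):
    """Stand-in for the framework's real OOM exception (hypothetical)."""


def largest_fitting_batch(try_step: Callable[[int], None],
                          candidates: Iterable[int]) -> Optional[int]:
    """Try batch sizes in increasing order; return the last one that ran without OOM."""
    best = None
    for batch_size in candidates:
        try:
            try_step(batch_size)
        except OutOfMemory:
            break
        best = batch_size
    return best


def make_fake_step(capacity_gb: float, fixed_gb: float, per_sample_gb: float):
    """Toy memory model: usage = fixed + per_sample * batch_size (made-up numbers)."""
    def step(batch_size: int) -> None:
        if fixed_gb + per_sample_gb * batch_size > capacity_gb:
            raise OutOfMemory(f"batch size {batch_size} does not fit")
    return step


sizes = [8, 16, 24, 32, 64, 128]
print(largest_fitting_batch(make_fake_step(32.0, 2.0, 1.0), sizes))  # 24
print(largest_fitting_batch(make_fake_step(32.0, 2.0, 0.2), sizes))  # 128
```

The point of doing it this way is that it measures what actually fits, independent of how much memory the allocator has merely reserved.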

Does anyone have any idea what might be causing this discrepancy?

I would test two things:

See how this is done in Flux.train!

It looks like Tsunami.jl, which is what I’m using for training, already caches with GPUArrays:

## SINGLE EPOCH TRAINING LOOP
for (batch_idx, batch) in enumerate(train_dataloader)
    fit_state.step += 1
    fit_state.batchsize = MLUtils.numobs(batch)

    hook(on_train_batch_start, model, trainer, batch, batch_idx)
        
    GPUArrays.@cached trainer.cache begin
        out, grad = gradient_train_step(model, trainer, batch, batch_idx)
    end
        
    hook(on_before_update, model, trainer, out, grad)
        
    GPUArrays.@cached trainer.cache begin 
        update!(trainer.optimisers, model, grad)
    end

    if fit_state.step == trainer.max_steps
        fit_state.should_stop = true
    end

    hook(on_train_batch_end, model, trainer, out, batch, batch_idx)
        
    ProgressMeter.next!(train_progbar,
        showvalues = values_for_train_progbar(trainer.metalogger),
        valuecolor = :yellow, 
        final = fit_state.should_stop || batch_idx == _length(train_dataloader),
        keep = fit_state.should_stop || fit_state.epoch == trainer.max_epochs
    )

    fit_state.should_stop && break
end

I’ll try re-writing the training loop to use Reactant to see if there’s any improvement.

In my experience w/ Lux.jl, it seems to allocate more than it actually uses. A PyTorch model that I know takes around 3GB for a forward/backward pass w/ the same batch size shows 24GB when training it with Lux.jl. When I increase the batch size by 8x, nvidia-smi then reports it’s using 48GB 🤔
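To illustrate why a tool like nvidia-smi can report far more than what is live, here is a toy model of a caching allocator. It is only a sketch of the general idea and not how CUDA.jl, Reactant, or PyTorch actually manage memory; the class and its fields are made up. (On the PyTorch side, the same distinction shows up as torch.cuda.memory_allocated versus torch.cuda.memory_reserved.)

```python
class ToyCachingPool:
    """Toy model of a caching GPU allocator; not any real library's allocator."""

    def __init__(self) -> None:
        self.in_use = 0   # bytes held by live arrays
        self.cached = 0   # bytes freed by the program but kept by the pool

    @property
    def reserved(self) -> int:
        """Roughly what an external tool such as nvidia-smi would see."""
        return self.in_use + self.cached

    def alloc(self, nbytes: int) -> None:
        reused = min(self.cached, nbytes)  # serve from the cache first
        self.cached -= reused
        self.in_use += nbytes

    def free(self, nbytes: int) -> None:
        self.in_use -= nbytes
        self.cached += nbytes              # kept for reuse, not returned to the driver

    def reclaim(self) -> None:
        self.cached = 0                    # hand cached blocks back to the driver


pool = ToyCachingPool()
pool.alloc(3_000)
pool.free(2_000)
print(pool.in_use, pool.reserved)  # 1000 3000
pool.alloc(1_500)
print(pool.in_use, pool.reserved)  # 2500 3000
pool.reclaim()
print(pool.in_use, pool.reserved)  # 2500 2500
```

In this picture the externally visible number only tells you the high-water mark, which is why pushing the batch size until OOM is a more reliable comparison.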

Yes, that’s why we’re so happy to have Reactant now. I wonder how much of this was forced by having to respect the Zygote way of doing things at the time those layers were written.

The thing that’s likely keeping many people from using Reactant.jl is the documentation. It’s not totally clear how to do things like DDP style training, how to use Reactant in practice when also using SciMLSensitivity.jl, etc.

That being said it’s a very promising direction for ML in Julia. Some very cool stuff is possible like being able to directly load PyTorch models via StableHLO: PyTorch StableHLO Support · Issue #2065 · EnzymeAD/Reactant.jl · GitHub

Yes, but with SciMLSensitivity it’s pretty rare that you need a GPU, and for small networks on the CPU, Julia with Enzyme or Mooncake is largely enough. There are some cases where you need a big network (PDE + spectral solver is one), but they are actually pretty rare.

For now, the best thing to do is to look at the already implemented Lux + Reactant cases and learn by example.

I think Reactant may change a lot before 1.0, though, which is a reason not to hurry building big docs for now. For instance:

  • if the need for @trace can be removed
  • if switching options becomes more user friendly
  • if the backend can be chosen on the fly, which would simplify other repos that use Reactant (edit: it has always been there)
  • if Metal gets added

Etc., etc.

It’s funny, though, that even this early I prefer the Julia + Reactant way of handling MLIR to JAX or Mojo.

Docs are always welcome! If there’s something missing, do let us know.

I have no idea what “DDP style training” is so that’s at least one reason why that’s not in a doc xD

As for in-progress features: yeah, automatically removing the need for @trace is something we’re looking into, and there is an open PR adding Metal (and we just added initial Trainium support too).

I’m curious what you mean by switching/backend issues?

I think I saw it somewhere (someone asking, or an issue): being able to explicitly request CPU/GPU/TPU at compile time and/or at array creation, instead of or in addition to the global set_backend approach. From a user’s perspective it doesn’t add that much, but for another package that wants to depend on Reactant it may be nice to have (especially if the package already dispatches on the KernelAbstractions backend selector).
It would also be very cool to have a talk like https://m.youtube.com/watch?v=XuMDzRmRPPQ&pp=ygUPSnVsaWFodWIgQ3VUaWxl for Reactant. I’m sure JuliaHub would be OK with doing that? The JuliaCon talks are great but limited in terms of complexity and length.

to_rarray [and similar] take an optional device or backend (see Core Reactant API | Reactant.jl), defaulting to the global one if not specified.

I think they’ve had those args since the API itself was created, so maybe it’s just not as well documented (help welcome!)? Or are you thinking of something else?

Oh sorry, yes, I didn’t know that, and indeed other people may not either. I’ll make a docs PR about it when I can. Thank you!

I re-implemented my training script to use Reactant and Lux:

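using Lux, Boltz, Reactant, Enzyme, Optimisers, Random, ProgressMeter
import Flux  # only used for Flux.DataLoader and Flux.MLUtils
# ImageDataset is defined elsewhere (not shown).

function train_lux()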
    lux_model = Lux.Chain(
        #Boltz.Vision.ResNet(34), 
        Boltz.Vision.ViT(:base), 
        Lux.Dense(1000, 30), 
    )

    rng = Random.default_rng()
    ps, st = Lux.setup(rng, lux_model)

    # Move to Reactant device
    dev = Lux.reactant_device()
    cdev = Lux.cpu_device()
    ps, st = ps |> dev, st |> dev

    imsize = 256
    dataset = ImageDataset("../../Data/AID", imsize=(imsize, imsize))
    train_data, test_data = Flux.MLUtils.splitobs(dataset, at=0.8, shuffle=true)
    train_loader = Flux.DataLoader(train_data, batchsize=128, shuffle=true, collate=true, parallel=true, partial=false) |> dev
    test_loader = Flux.DataLoader(test_data, batchsize=128, shuffle=true, collate=true, parallel=true, partial=false) |> dev

    model_compiled = Reactant.@compile lux_model(first(train_loader)[1], ps, Lux.testmode(st))

    opt = Optimisers.Adam(1f-4)
    tstate = Lux.Training.TrainState(lux_model, ps, st, opt)

    # 2. Training Loop
    loss_fn = Lux.CrossEntropyLoss(;logits=Val(true)) # or your custom loss function

    for epoch in 1:10
        @info "Epoch $epoch"
        total_loss = 0.0
        ProgressMeter.@showprogress for (xdata, ydata) in train_loader
            _, loss, _, tstate = Lux.Training.single_train_step!(
                Lux.AutoEnzyme(), loss_fn, (xdata, ydata), tstate
            )
            total_loss += loss
        end
        @info "loss:" total_loss / length(train_loader)

        total_acc = 0.0
        st_ = Lux.testmode(tstate.states)
        ProgressMeter.@showprogress for (x, y) in test_loader
            ŷ, st_ = model_compiled(x, tstate.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            acc = accuracy(ŷ, y)
            total_acc += acc
        end
        @info "Epoch $epoch - Accuracy: $(total_acc / length(test_loader))"
    end
end

function accuracy(y_pred, y_true)
    # softmax is monotonic, so argmax over the raw logits picks the same classes
    pred_idx = map(i -> i[1], argmax(y_pred, dims=1))
    true_idx = map(i -> i[1], argmax(y_true, dims=1))
    # fraction of columns (samples) where the predicted class matches the label
    return sum(pred_idx .== true_idx) / size(y_true, 2)
end

These are the results for ViT Base:

PyTorch: 5.1 GB VRAM - 23 seconds / epoch

Flux: 16.2 GB VRAM - 38 seconds / epoch

Lux: Unknown GB VRAM - 25 seconds / epoch

I couldn’t determine the VRAM usage for Lux/Reactant due to pre-allocation. However, I was able to increase the batch size to a maximum of 128, which is significantly higher than under Flux (24) and equivalent to PyTorch at FP32 precision.

From these, it seems that the main issue is how Flux/Zygote allocates intermediate arrays, which results in much higher memory requirements than Lux/Reactant. If this is the direction that the Julia community is moving, I’ll probably modify my project to use Lux instead of Flux.

Does StableHLO produce native Julia models, where you can extract and modify layers, or is it essentially a black box like ONNX? The reason I ask is that my current project involves implementing Timm models as pure Julia equivalents in Flux, then I define a load_params! method that takes a PyTorch state_dict from the matching Timm model/layer and loads the parameters into the corresponding Flux layer. This has the advantage of producing a model that can be used like any other Flux layer, but obviously requires a fair amount of work to duplicate the original PyTorch code.
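For context, the first half of that workflow looks roughly like this on the Python side: the flat state_dict keys get grouped by layer path so each group can be handed to the matching layer. This is only a sketch; `group_by_layer` and the example keys are made up for illustration, the values stand in for tensors, and the array layout conversion between PyTorch and Flux is not covered here.

```python
from collections import defaultdict


def group_by_layer(state_dict: dict) -> dict:
    """Group flat 'layer.path.param' keys into {layer_path: {param_name: value}}."""
    layers = defaultdict(dict)
    for key, value in state_dict.items():
        # Split on the last dot: everything before it is the layer path.
        layer_path, _, param_name = key.rpartition(".")
        layers[layer_path][param_name] = value
    return dict(layers)


# Made-up keys in the naming style PyTorch uses; strings stand in for tensors.
example = {
    "conv1.weight": "W0",
    "bn1.weight": "g0",
    "bn1.bias": "b0",
    "layer1.0.conv1.weight": "W1",
}
grouped = group_by_layer(example)
print(grouped["bn1"])   # {'weight': 'g0', 'bias': 'b0'}
print(sorted(grouped))  # ['bn1', 'conv1', 'layer1.0.conv1']
```

Each group then maps onto one layer of the Julia model, which is what makes the result usable like any other Flux layer.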

It doesn’t produce native Julia models, but you can use it inside the rest of your native Julia model. Well, I suppose as long as you compile it with Reactant.jl.

Also see here, I’m working on a timm port for Lux.jl: [ANN] Jimm.jl: Lux ports of timm image backbones, with HuggingFace pretrained weights

It would likely be pretty easy to do the same sort of workflow I did here to port things over to Flux.jl, but I’m not personally familiar with it.

That’s almost exactly what I’m working on (my working name was even Jimm). I’ll take a look and see if I can contribute. So far, I’ve implemented all variants of Timm’s VisionTransformer, ConvNeXt (both v1 and v2), and Eva (basically ViT with rotary positional embeddings used by SAM3). I also have implementations for Swin, PVT, and Twins, but I didn’t get around to adding pre-trained weights yet. It should be relatively straightforward to convert from Flux to Lux.

We should be able to write a StableHLO -> native Julia arrays converter in Reactant/MLIR [and have been meaning to, but it’s not currently a high priority; if anyone wants to give it a go, please reach out and I’d be happy to help you get started!]

That would be great! I don’t have any ViT added yet. ConvNeXt V1 and V2 are done, along with BiT ResNetV2. I’m not sure of all the differences between Flux.jl and Lux.jl, but I went with Lux.jl because I wanted everything to work nicely with the SciML ecosystem.

Having any modern pretrained backbones is already a huge roadblock removed for people working with computer vision problems in Julia.

This was exactly my thinking. My research involves land cover classification, and not being able to access state-of-the-art vision models in Julia has been a huge issue. As a result of my project, I successfully fine-tuned vit_pe_spatial_base_patch16_512.fb, the same vision encoder used in Meta’s SAM 3, achieving SOTA metrics across several benchmarks. However, high memory usage severely limited my ability to train larger models and use larger batch sizes, which prompted this topic. Since Lux seems to resolve this issue, I’m happy to shift my focus to Jimm.

It seems we were literally thinking the exact same thing at the exact same time. I mostly made the announcement for the package to find collaborators; whenever you are ready I can get you full access to the repo.