Significantly Higher VRAM Usage and Slower Training on Flux Compared to PyTorch

I’ve been working on a project where I’m converting Timm models to Flux along with their pre-trained weights. Everything is going well so far, but I noticed that Flux has significantly higher VRAM consumption and takes about twice as long to train compared to PyTorch.

To confirm this, I wrote a simple PyTorch and Flux script that trains a ResNet-34 model on 8,000 samples from the AID dataset. To make sure the difference wasn’t down to my own mistakes, I used the ResNet implementation from Metalhead.

Here are the respective scripts:

PyTorch

class TimmClassifier(pl.LightningModule):
    """LightningModule wrapping a timm backbone for image classification.

    Trains with cross-entropy loss and logs per-stage loss/accuracy
    ("train", "val", "test") via a shared step implementation.
    """

    def __init__(self, model: str, num_classes: int, learning_rate: float = 1e-3, pretrained: bool = True):
        super().__init__()
        self.save_hyperparameters()
        # The classifier head is resized by timm to match num_classes.
        self.model = timm.create_model(model, pretrained=pretrained, num_classes=num_classes)
        self.loss_fn = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def _shared_step(self, batch, stage: str):
        # Common forward/loss/metric logic for all three stages.
        images, labels = batch
        logits = self(images)
        loss = self.loss_fn(logits, labels)
        predictions = logits.argmax(dim=1)
        accuracy = (predictions == labels).float().mean()
        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", accuracy, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test")

    def configure_optimizers(self):
        # Adam plus a cosine LR schedule over 10 steps of the scheduler.
        opt = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
        return [opt], [sched]


def train_resnet34(
    dataset,
    num_classes: int,
    val_split: float = 0.2,
    batch_size: int = 16,
    max_epochs: int = 20,
    learning_rate: float = 1e-4,
    pretrained: bool = True,
    num_workers: int = 4,
    accelerator: str = "auto",
):
    """Train a ResNet-34 classifier on `dataset` with a held-out validation split.

    Splits the dataset deterministically (seed 42), builds train/val loaders,
    and fits a TimmClassifier with checkpointing and early stopping on val_loss.
    """
    # --- Deterministic train/validation split ---
    n_val = int(len(dataset) * val_split)
    n_train = len(dataset) - n_val
    train_ds, val_ds = random_split(
        dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42)
    )

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    # --- Model ---
    model = TimmClassifier(
        model="resnet34",
        num_classes=num_classes,
        learning_rate=learning_rate,
        pretrained=pretrained,
    )

    # --- Callbacks: keep only the best checkpoint; stop early if val_loss stalls ---
    callbacks = [
        ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1, filename="best"),
        EarlyStopping(monitor="val_loss", patience=5, mode="min"),
    ]

    # --- Trainer (full fp32 precision to match the Flux comparison) ---
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        callbacks=callbacks,
        log_every_n_steps=10,
        precision="32-true",
    )
    trainer.fit(model, train_loader, val_loader)


def main():
    """Load the AID dataset and train a ResNet-34 from scratch for 5 epochs."""
    aid = AID(path="../../Data/AID")
    train_resnet34(aid, num_classes=len(aid.classes), pretrained=False, max_epochs=5)

Flux

# Tsunami model wrapping an arbitrary encoder network (here a Metalhead ResNet).
# Parametrized on the encoder type `E` so the field stays concretely typed.
struct FineTuneEncoder{E} <: FluxModule
    encoder::E
end

# Convenience constructor: randomly initialised ResNet-34 with a
# `num_classes`-way classifier head (no pretrained weights).
function FineTuneEncoder(num_classes::Int)
    backbone = Metalhead.ResNet(34, pretrain=false, nclasses=num_classes)
    return FineTuneEncoder(backbone)
end

Flux.Optimisers.trainable(x::FineTuneEncoder) = (; x.encoder)


(model::FineTuneEncoder)(x) = model.encoder(x)

# Compute logit cross-entropy loss and accuracy for one (inputs, targets) batch.
function loss_and_accuracy(model::FineTuneEncoder, batch)
    inputs, targets = batch
    logits = model(inputs)
    loss = Flux.logitcrossentropy(logits, targets)
    acc = Tsunami.accuracy(logits, targets)
    return loss, acc
end

# One training step: log metrics to the progress bar and return the loss
# so Tsunami can differentiate it.
function Tsunami.train_step(model::FineTuneEncoder, trainer, batch)
    l, a = loss_and_accuracy(model, batch)
    Tsunami.log(trainer, "loss/train", l, prog_bar=true)
    Tsunami.log(trainer, "accuracy/train", a, prog_bar=true)
    return l
end

# One validation step: log metrics only, nothing is returned.
function Tsunami.val_step(model::FineTuneEncoder, trainer, batch)
    l, a = loss_and_accuracy(model, batch)
    Tsunami.log(trainer, "loss/val", l)
    Tsunami.log(trainer, "accuracy/val", a)
end

# Adam with a fixed 1e-4 learning rate over all trainable parameters.
function Tsunami.configure_optimisers(m::FineTuneEncoder, trainer)
    return Flux.Optimisers.setup(Flux.Optimisers.Adam(1e-4), m)
end

# End-to-end driver: load AID, build loaders, and train for 5 epochs.
function run_train()
    # Data: 224x224 images, shuffled 80/20 split, collated parallel loaders.
    dataset = ImageDataset("../../Data/AID", imsize=(224, 224))
    train_split, eval_split = Flux.MLUtils.splitobs(dataset, at=0.8, shuffle=true)
    train_loader = Flux.DataLoader(train_split; batchsize=16, collate=true, parallel=true, shuffle=true)
    eval_loader = Flux.DataLoader(eval_split; batchsize=16, collate=true, parallel=true)

    # Model + trainer
    model = FineTuneEncoder(length(dataset.labels))
    trainer = Trainer(max_epochs=5)
    Tsunami.fit!(model, trainer, train_loader, eval_loader)
end

These are the results for ResNet-34 on a machine with 64 GB of RAM and an NVIDIA RTX 5090 with 32 GB of VRAM. VRAM usage is reported by nvtop.

Flux: 19.1 GB VRAM - 11 seconds / epoch

PyTorch: 2.6 GB VRAM - 7 seconds / epoch

I also tried another experiment with a ViT base model, which produced the following.

Flux: 16.2 GB VRAM - 38 seconds / epoch

PyTorch: 5.1 GB VRAM - 23 seconds / epoch

To determine whether Julia was perhaps just reserving more buffer space than it actually used, I increased the batch size until I hit an OOM. With ViT base, PyTorch handled a batch size of 128, while Flux threw OOM errors at 32 and could only go up to 24.

Does anyone have any idea what might be causing this discrepancy?

1 Like