I’ve been working on a project where I’m converting timm models to Flux along with their pre-trained weights. Everything has gone well so far, but I noticed that the Flux versions have significantly higher VRAM consumption and take about twice as long to train as their PyTorch counterparts.
To confirm this, I wrote simple PyTorch and Flux scripts that train a ResNet-34 model on 8,000 samples from the AID dataset. To make sure the difference wasn’t down to mistakes in my own conversion, I used the ResNet implementation from Metalhead.
Here are the respective scripts:
PyTorch
import pytorch_lightning as pl
import timm
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch import nn
from torch.utils.data import DataLoader, random_split

class TimmClassifier(pl.LightningModule):
    def __init__(self, model: str, num_classes: int, learning_rate: float = 1e-3, pretrained: bool = True):
        super().__init__()
        self.save_hyperparameters()
        self.model = timm.create_model(
            model, pretrained=pretrained, num_classes=num_classes
        )
        self.loss_fn = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def _shared_step(self, batch, stage: str):
        images, labels = batch
        logits = self(images)
        loss = self.loss_fn(logits, labels)
        preds = logits.argmax(dim=1)
        acc = (preds == labels).float().mean()
        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", acc, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        return [optimizer], [scheduler]
def train_resnet34(
    dataset,
    num_classes: int,
    val_split: float = 0.2,
    batch_size: int = 16,
    max_epochs: int = 20,
    learning_rate: float = 1e-4,
    pretrained: bool = True,
    num_workers: int = 4,
    accelerator: str = "auto",
):
    # --- Split dataset ---
    val_size = int(len(dataset) * val_split)
    train_size = len(dataset) - val_size
    train_ds, val_ds = random_split(
        dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
    )
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True,
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True,
    )
    # --- Model ---
    model = TimmClassifier(
        model="resnet34",
        num_classes=num_classes,
        learning_rate=learning_rate,
        pretrained=pretrained,
    )
    # --- Callbacks ---
    checkpoint_cb = ModelCheckpoint(
        monitor="val_loss", mode="min", save_top_k=1, filename="best"
    )
    early_stop_cb = EarlyStopping(monitor="val_loss", patience=5, mode="min")
    # --- Trainer ---
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        callbacks=[checkpoint_cb, early_stop_cb],
        log_every_n_steps=10,
        precision="32-true",
    )
    trainer.fit(model, train_loader, val_loader)
def main():
    # AID is the dataset class for the AID data (defined elsewhere)
    dataset = AID(path="../../Data/AID")
    train_resnet34(dataset, num_classes=len(dataset.classes), pretrained=False, max_epochs=5)

if __name__ == "__main__":
    main()
Flux
using Flux
using Metalhead
using Tsunami

struct FineTuneEncoder{E} <: FluxModule
    encoder::E
end

function FineTuneEncoder(num_classes::Int)
    encoder = Metalhead.ResNet(34, pretrain=false, nclasses=num_classes)
    return FineTuneEncoder(encoder)
end

Flux.Optimisers.trainable(x::FineTuneEncoder) = (; x.encoder)

(model::FineTuneEncoder)(x) = model.encoder(x)

function loss_and_accuracy(model::FineTuneEncoder, batch)
    x, y = batch
    ŷ = model(x)
    return Flux.logitcrossentropy(ŷ, y), Tsunami.accuracy(ŷ, y)
end

function Tsunami.train_step(model::FineTuneEncoder, trainer, batch)
    loss, acc = loss_and_accuracy(model, batch)
    Tsunami.log(trainer, "loss/train", loss, prog_bar=true)
    Tsunami.log(trainer, "accuracy/train", acc, prog_bar=true)
    return loss
end

function Tsunami.val_step(model::FineTuneEncoder, trainer, batch)
    loss, acc = loss_and_accuracy(model, batch)
    Tsunami.log(trainer, "loss/val", loss)
    Tsunami.log(trainer, "accuracy/val", acc)
end

function Tsunami.configure_optimisers(m::FineTuneEncoder, trainer)
    opt_rule = Flux.Optimisers.Adam(1e-4)
    opt_state = Flux.Optimisers.setup(opt_rule, m)
    return opt_state
end
function run_train()
    # Prepare the dataset (ImageDataset is a local dataset type, defined elsewhere)
    dataset = ImageDataset("../../Data/AID", imsize=(224, 224))
    train_data, test_data = Flux.MLUtils.splitobs(dataset, at=0.8, shuffle=true)
    train_loader = Flux.DataLoader(train_data, batchsize=16, collate=true, parallel=true, shuffle=true)
    test_loader = Flux.DataLoader(test_data, batchsize=16, collate=true, parallel=true)

    # Create and train the model
    model = FineTuneEncoder(length(dataset.labels))
    trainer = Trainer(max_epochs=5)
    Tsunami.fit!(model, trainer, train_loader, test_loader)
end

run_train()
These are the results for ResNet-34 on a machine with 64 GB of RAM and an NVIDIA RTX 5090 with 32 GB of VRAM. VRAM usage is reported by nvtop.
Flux: 19.1 GB VRAM - 11 seconds / epoch
PyTorch: 2.6 GB VRAM - 7 seconds / epoch
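For scale: ResNet-34 has roughly 21.8 M parameters, so in fp32 the weights, gradients, and Adam’s two moment buffers come to about 21.8M × 4 bytes × 4 ≈ 0.35 GB. Together with activations for a batch of 16 at 224×224 and the CUDA context, the PyTorch figure is about what I’d expect, while the Flux figure is far beyond anything the model itself should need.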
I also ran the same experiment with a ViT-Base model, which produced the following:
Flux: 16.2 GB VRAM - 38 seconds / epoch
PyTorch: 5.1 GB VRAM - 23 seconds / epoch
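For the Flux version, switching from ResNet to ViT only requires swapping the encoder constructor. A minimal sketch, assuming Metalhead's ViT constructor with the :base configuration (not necessarily the exact code from my run):

# Hypothetical ViT variant of the constructor above
function FineTuneViT(num_classes::Int)
    # :base selects the ViT-Base configuration in Metalhead
    encoder = Metalhead.ViT(:base; pretrain=false, nclasses=num_classes)
    return FineTuneEncoder(encoder)
end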
To determine whether Julia was perhaps just reserving buffer space without actually using it, I increased the batch size until I hit an out-of-memory (OOM) error. With ViT-Base, PyTorch handled a batch size of 128, while Flux topped out at 24, throwing OOM errors at 32.
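A more direct way to test the buffer-space theory is to query CUDA.jl from inside the session, since nvtop only sees the total that CUDA.jl's caching pool has reserved from the driver. A minimal sketch:

using CUDA

# Prints live allocations vs. what the memory pool has reserved;
# nvtop reports the reserved total, not what is actually in use.
CUDA.memory_status()

# Free unreferenced buffers, then return cached pool memory to the driver
GC.gc()
CUDA.reclaim()
CUDA.memory_status()

Running Julia with JULIA_CUDA_MEMORY_POOL=none disables the pool entirely (at the cost of slower allocations), which would separate pool over-reservation from genuinely larger live allocations.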
Does anyone have any idea what might be causing this discrepancy?