Flux: How to minimise the garbage collection time?

I have code that uses several separate neural networks inside a for loop. Here are the timings for successive passes through the loop:

0.149468 seconds (10.69 k allocations: 1.331 MiB, 99.45% gc time)
0.001026 seconds (10.55 k allocations: 1.318 MiB)
0.000666 seconds (9.94 k allocations: 1.144 MiB)
0.000733 seconds (10.61 k allocations: 1.322 MiB)
0.150424 seconds (11.14 k allocations: 1.491 MiB, 99.48% gc time)
0.000966 seconds (10.65 k allocations: 1.326 MiB)
0.000790 seconds (10.84 k allocations: 1.464 MiB)
0.000808 seconds (11.34 k allocations: 1.508 MiB)
0.149979 seconds (10.99 k allocations: 1.358 MiB, 99.48% gc time)
0.000899 seconds (10.24 k allocations: 1.231 MiB)
0.000700 seconds (10.32 k allocations: 1.237 MiB)
0.000648 seconds (10.08 k allocations: 1.155 MiB)
0.149194 seconds (10.79 k allocations: 1.397 MiB, 99.49% gc time)

As you can see, garbage collection takes 3 orders of magnitude longer than actually executing the code. I profiled the code, and 95% of the allocations are made internally by Flux. This means that, to improve the execution time of my code, I would have to abandon Flux, which I do not want to do.
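To get the GC share per iteration as a number rather than reading `@time`'s printout, Julia's built-in `@timed` macro returns the elapsed time, bytes allocated and GC time as fields of a NamedTuple. A minimal sketch; `fake_step` is a hypothetical allocation-heavy stand-in, not the code that produced the timings above:

```julia
# Hypothetical allocation-heavy stand-in for one loop iteration (not the
# original code): it creates `n` small arrays that all stay alive until return.
function fake_step(n)
    buffers = Vector{Vector{Float64}}()
    for _ in 1:n
        push!(buffers, rand(100))   # one fresh heap allocation per iteration
    end
    return sum(sum, buffers)
end

fake_step(10)  # warm-up call so compilation is not included in the timings

for iter in 1:4
    stats = @timed fake_step(10_000)
    # stats.time and stats.gctime are in seconds; stats.bytes is bytes allocated
    gc_share = stats.time > 0 ? 100 * stats.gctime / stats.time : 0.0
    println("iteration $iter: $(round(stats.time; digits = 6)) s, ",
            "$(stats.bytes) bytes, $(round(gc_share; digits = 2))% GC")
end
```

Replacing `fake_step` with the real loop body would show which iterations trigger a collection; in the timings above the ~0.15 s spikes recur every fourth iteration.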

Is there a way to minimise allocations that Flux does internally?


Can you give a MWE of this? That should absolutely never happen.


Unfortunately, I cannot at the moment, because the code is part of a research project that has yet to be published.
However, it should be possible to reproduce it without much trouble:

  • Create 200 neural networks, each with ~5000 parameters.
  • Train the networks on the MNIST dataset so that each step 1) samples the training data and 2) updates the networks.
  • Repeat the last step in a loop.
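The recipe above can be sketched roughly as follows, assuming a recent Flux (0.14 or later, with the explicit `Flux.setup`/`Flux.update!` API). The two-layer `Dense` architecture and the random stand-in data are assumptions, chosen only to hit ~5000 parameters per network and keep the snippet self-contained; swap in MLDatasets' MNIST for a faithful reproduction.

```julia
using Flux
using Random

Random.seed!(0)

# Hypothetical architecture: 784*6 + 6 + 6*10 + 10 = 4780 parameters (~5000)
make_net() = Chain(Dense(784 => 6, relu), Dense(6 => 10))
nets = [make_net() for _ in 1:200]

# Stand-in for MNIST (assumption): random flattened 28x28 images and labels
n_samples = 10_000
x_all = rand(Float32, 784, n_samples)
y_all = Flux.onehotbatch(rand(0:9, n_samples), 0:9)

# One optimiser state per network
opt_states = [Flux.setup(Adam(1e-3), net) for net in nets]

function train_step!(nets, opt_states, x_all, y_all; batchsize = 128)
    for (net, state) in zip(nets, opt_states)
        # 1) sample the training data
        idx = rand(1:size(x_all, 2), batchsize)
        xb, yb = x_all[:, idx], y_all[:, idx]
        # 2) update the network in place
        grads = Flux.gradient(m -> Flux.Losses.logitcrossentropy(m(xb), yb), net)
        Flux.update!(state, net, grads[1])
    end
end

# Repeat the step in a loop, timing each pass
for step in 1:5
    @time train_step!(nets, opt_states, x_all, y_all)
end
```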

That sounds like a job for SimpleChains.jl

But that is not scalable: we want to increase the number of parameters by several orders of magnitude in the near future.

Then it won’t be GC-time limited :sweat_smile:. Optimizing small neural networks has very different properties and is not a good indicator of how larger neural networks will behave.


But the larger the network, the more time GC takes, so the bottleneck will still be there…

That scales linearly while the compute scales cubically.

The current gap between GC time and compute time is 3 orders of magnitude, so even if GC scales more slowly, GC time remains an important factor in the execution time of the code within the limits of the available computing capacity.


If

  • GC work is currently 1000x compute work, and
  • GC work scales linearly while compute work scales cubically,

then your break-even point will be at a 10√10-fold (roughly 30x) increase in the number of parameters. Past this point, compute will dominate GC.
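The break-even arithmetic, spelled out under the two assumptions in the bullets, with $k > 0$ the growth factor in the number of parameters and $c$ an arbitrary constant for today's compute time:

```latex
\begin{aligned}
T_{\mathrm{GC}}(k) &= 1000\,c\,k, \qquad T_{\mathrm{compute}}(k) = c\,k^{3},\\
T_{\mathrm{GC}}(k) = T_{\mathrm{compute}}(k) &\iff 1000\,k = k^{3} \iff k^{2} = 1000 \iff k = 10\sqrt{10} \approx 31.6.
\end{aligned}
```

For $k > 10\sqrt{10}$ the cubic term wins, so compute dominates GC.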

Even if the implementation is simple, this is a lot of work to ask of someone. You will probably get more useful help if you supply this code yourself; your code would also be closer to the real implementation, and thus a better MWE than someone else’s. Someone else would probably make many choices that do not reflect your real code.


Fair point.
I will supply one tomorrow.


The fact that compute time dominates is not an argument against minimising GC time…

Of course not, I agree with that.

What I meant was to highlight that there are two implementations which capture different performance trade-offs. If your model is small and dominated by GC time, you can minimize that by using SimpleChains. If you make your model 30x bigger, you might get more performance out of Flux. That being said, it’s probably also possible to reduce the GC time that Flux incurs.

Are the NNs structurally identical, just with different parameters? If so, using Lux and swapping parameters while reusing the same parameter memory would be a good fix.
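In Lux the parameters live outside the model, so a single model definition can be applied to many parameter sets. A minimal sketch of that structural part of the idea; the `Dense` sizes are a hypothetical ~5000-parameter architecture, `tanh` is used only to avoid extra imports, and memory reuse itself is not shown here:

```julia
using Lux
using Random

# One shared model definition (hypothetical architecture)
model = Chain(Dense(784 => 6, tanh), Dense(6 => 10))

# 200 independent (parameters, states) pairs for the same structure,
# each initialised from its own seeded RNG
param_sets = [Lux.setup(Xoshiro(i), model) for i in 1:200]

x = rand(Float32, 784, 128)

# model(x, ps, st) returns (output, new_state); keep only the output
outputs = [first(model(x, ps, st)) for (ps, st) in param_sets]
```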


In terms of their structural design, they are exactly the same. I will take a look at Lux; thanks!

@avikpal is there a recursive copyto! operation that you have implemented for doing such a copy on nested named tuples?

IIUC, the neural network structure is the same but the parameters are different. If so, I am not sure how copyto! helps, because we need to keep all the parameters in memory, right?

Regarding copyto!, the easiest way is to use ComponentArrays:

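using ComponentArrays

# Flatten each nested parameter NamedTuple into one contiguous array
ps1 = ComponentArray(ps1)
ps2 = ComponentArray(ps2)

# Copy the values of ps2 into the existing memory of ps1
copyto!(getdata(ps1), getdata(ps2))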

Though it would be handy to implement something like that for nested named tuples.
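For nested named tuples, such a helper can be a few lines of plain Julia. A sketch, assuming every leaf is an array (scalar leaves are not handled); `recursive_copyto!` here is a hypothetical name, not an existing Lux or ComponentArrays function:

```julia
# Hypothetical helper: copy every array leaf of `src` into the matching
# leaf of `dst` in place, so `dst` keeps its existing memory.
recursive_copyto!(dst::AbstractArray, src::AbstractArray) = copyto!(dst, src)

# NamedTuples must have the same keys; recurse field by field
function recursive_copyto!(dst::NamedTuple{K}, src::NamedTuple{K}) where {K}
    for k in K
        recursive_copyto!(getfield(dst, k), getfield(src, k))
    end
    return dst
end

# Tuples (e.g. from Chain-like containers) are handled element-wise
function recursive_copyto!(dst::Tuple, src::Tuple)
    length(dst) == length(src) || throw(DimensionMismatch("tuple lengths differ"))
    foreach(recursive_copyto!, dst, src)
    return dst
end

# Usage with two parameter sets of identical structure
ps1 = (layer_1 = (weight = zeros(2, 3), bias = zeros(2)),
       layer_2 = (weight = zeros(1, 2), bias = zeros(1)))
ps2 = (layer_1 = (weight = ones(2, 3), bias = ones(2)),
       layer_2 = (weight = ones(1, 2), bias = ones(1)))

w_before = ps1.layer_1.weight
recursive_copyto!(ps1, ps2)
```

After the call, `ps1` holds the values of `ps2` but its arrays are still the same objects as before, which is the point of copying rather than rebinding.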

Here is an MWE that I built on top of the following example: model-zoo/vision/conv_mnist at master · FluxML/model-zoo · GitHub

#!/usr/bin/env bash
#=
exec julia --optimize=3 --threads=6 "${BASH_SOURCE[0]}" "$@"
=#

using Flux
using Flux.Data: DataLoader
using Flux.Optimise: Optimiser, WeightDecay
using Flux: onehotbatch, onecold, flatten
using Flux.Losses: logitcrossentropy
using Statistics, Random
using Logging: with_logger
using ProgressMeter: @showprogress
import MLDatasets
import BSON
using CUDA

# We set default values for the arguments for the function `train`:

Base.@kwdef mutable struct Args
    η = 3e-4             ## learning rate
    λ = 0                ## L2 regularizer param, implemented as weight decay
    batchsize = 128      ## batch size
    epochs = 50          ## number of epochs
    seed = 0             ## set seed > 0 for reproducibility
    use_cuda = true      ## if true use cuda (if available)
    infotime = 1      ## report every `infotime` epochs
    checktime = 5        ## Save the model every `checktime` epochs. Set to 0 for no checkpoints.
    savepath = "runs/"   ## results path
end

# ## Data

# We create the function `get_data` to load the MNIST train and test data from [MLDatasets](https://github.com/JuliaML/MLDatasets.jl) and reshape them so that they are in the shape that Flux expects.

function get_data(args)
    xtrain, ytrain = MLDatasets.MNIST(:train)[:]
    xtest, ytest = MLDatasets.MNIST(:test)[:]

    xtrain = reshape(xtrain, 28, 28, 1, :)
    xtest = reshape(xtest, 28, 28, 1, :)

    ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)

    train_loader = DataLoader((xtrain, ytrain), batchsize=args.batchsize, shuffle=true)
    test_loader = DataLoader((xtest, ytest), batchsize=args.batchsize)

    return train_loader, test_loader
end

# The function `get_data` performs the following tasks:

# * **Loads MNIST dataset:** Loads the train and test set tensors. The shape of the train data is `28x28x60000` and the test data is `28x28x10000`.
# * **Reshapes the train and test data:**  Notice that we reshape the data so that we can pass it as arguments for the input layer of the model.
# * **One-hot encodes the train and test labels:** Creates a batch of one-hot vectors so we can pass the labels of the data as arguments for the loss function. For this example, we use the [logitcrossentropy](https://fluxml.ai/Flux.jl/stable/models/losses/#Flux.Losses.logitcrossentropy) function and it expects data to be one-hot encoded.
# * **Creates mini-batches of data:** Creates two DataLoader objects (train and test) that handle data mini-batches of size `128` (as defined above). We create these two objects so that we can iterate over the whole dataset in mini-batches when training and evaluating our model. Also, it shuffles the data points during each iteration (`shuffle=true`).

# ## Model

# We create the LeNet5 "constructor". It uses Flux's built-in [Convolutional and pooling layers](https://fluxml.ai/Flux.jl/stable/models/layers/#Convolution-and-Pooling-Layers):

function LeNet5(; imgsize=(28, 28, 1), nclasses=10)
    out_conv_size = (imgsize[1] ÷ 4 - 3, imgsize[2] ÷ 4 - 3, 16)

    # Returns a vector of 5 independent LeNet5 models with identical structure
    return [Chain(
        Conv((5, 5), imgsize[end] => 6, relu),
        MaxPool((2, 2)),
        Conv((5, 5), 6 => 16, relu),
        MaxPool((2, 2)),
        Dense(prod(out_conv_size), 120, relu),
        Dense(120, 84, relu),
        Dense(84, nclasses)
    ) for i in 1:5]
end


# ## Loss function

# We use the function [logitcrossentropy](https://fluxml.ai/Flux.jl/stable/models/losses/#Flux.Losses.logitcrossentropy) to compute the difference between
# the predicted and actual values (loss).

loss(ŷ, y) = logitcrossentropy(ŷ, y)

# Also, we create the function `eval_loss_accuracy` to output the loss and the accuracy during training:

function eval_loss_accuracy(loader, model, device)
    l = 0.0f0
    acc = 0
    ntot = 0
    for (x, y) in loader
        x, y = x |> device, y |> device
        ŷ = model(x)
        l += loss(ŷ, y) * size(x)[end]
        acc += sum(onecold(ŷ |> cpu) .== onecold(y |> cpu))
        ntot += size(x)[end]
    end
    return (loss=l / ntot |> round4, acc=acc / ntot * 100 |> round4)
end

# ## Utility functions
# We need a couple of functions to obtain the total number of the model's parameters. Also, we create a function to round numbers to four digits.

num_params(model) = sum(length, Flux.params(model))
round4(x) = round(x, digits=4)

# ## Train the model

# Finally, we define the function `train` that calls the functions defined above to train the model.

function train(; kws...)
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)
    use_cuda = args.use_cuda && CUDA.functional()

    if use_cuda
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    ## DATA
    train_loader, test_loader = get_data(args)

    model = LeNet5() |> device
    @info "LeNet5 model: $(num_params(model)) trainable params"

    ps = Flux.params.(model)

    opt = [ADAM(args.η) for i in 1:5]
    if args.λ > 0 ## add weight decay, equivalent to L2 regularization
        # one optimiser chain per model, matching the vector of models
        opt = [Optimiser(WeightDecay(args.λ), ADAM(args.η)) for i in 1:5]
    end

    function report(epoch, model)
        train = eval_loss_accuracy(train_loader, model, device)
        test = eval_loss_accuracy(test_loader, model, device)
        println("Epoch: $epoch   Train: $(train)   Test: $(test)")
    end

    @info "Start Training"
    for epoch in 1:args.epochs
        @time begin
            # Train the 5 independent models in parallel, one model per thread
            Threads.@threads for p_i in 1:5
                for (x, y) in train_loader
                    x, y = x |> device, y |> device
                    gs = Flux.gradient(ps[p_i]) do
                        ŷ = model[p_i](x)
                        loss(ŷ, y)
                    end
                    Flux.Optimise.update!(opt[p_i], ps[p_i], gs)
                end
            end
        end
    end
end

# The function `train` performs the following tasks:

# * Checks whether there is a GPU available and uses it for training the model. Otherwise, it uses the CPU.
# * Loads the MNIST data using the function `get_data`.
# * Creates five independent LeNet5 models, each with its own [ADAM optimiser](https://fluxml.ai/Flux.jl/stable/training/optimisers/#Flux.Optimise.ADAM) (with weight decay if `λ > 0`).
# * Creates the function `report` for computing the loss and accuracy (it is not called in this reduced example).
# * Runs the training loop using [Flux’s training routine](https://fluxml.ai/Flux.jl/stable/training/training/#Training). For each epoch, it trains the five models in parallel threads and, for each mini-batch:
#   * Computes the model’s predictions.
#   * Computes the loss.
#   * Updates the model’s parameters.

# ## Run the example

# We call the  function `train`:

if abspath(PROGRAM_FILE) == @__FILE__
    train()
end

And here is the output:

conv_mnist % ./conv_mnist.jl
[ Info: Training on CPU
[ Info: LeNet5 model: 222130 trainable params
[ Info: Start Training
 26.135067 seconds (62.60 M allocations: 37.208 GiB, 7.10% gc time, 63.96% compilation time)
  9.269353 seconds (1.46 M allocations: 34.171 GiB, 14.60% gc time)
  9.071051 seconds (1.46 M allocations: 34.171 GiB, 13.45% gc time)
  8.948899 seconds (1.46 M allocations: 34.171 GiB, 12.56% gc time)
  9.114095 seconds (1.46 M allocations: 34.171 GiB, 13.14% gc time)
  9.154796 seconds (1.46 M allocations: 34.171 GiB, 12.78% gc time)
  9.305023 seconds (1.46 M allocations: 34.171 GiB, 13.18% gc time)
  9.144356 seconds (1.46 M allocations: 34.171 GiB, 12.49% gc time)
  9.043380 seconds (1.46 M allocations: 34.171 GiB, 11.82% gc time)
  9.010278 seconds (1.46 M allocations: 34.171 GiB, 11.65% gc time)
  8.917086 seconds (1.46 M allocations: 34.171 GiB, 11.01% gc time)
  9.225805 seconds (1.46 M allocations: 34.171 GiB, 11.22% gc time)
  8.827871 seconds (1.46 M allocations: 34.171 GiB, 10.94% gc time)
  9.027693 seconds (1.46 M allocations: 34.171 GiB, 11.73% gc time)
  8.843056 seconds (1.46 M allocations: 34.171 GiB, 10.61% gc time)
  8.870301 seconds (1.46 M allocations: 34.171 GiB, 11.07% gc time)
  8.815562 seconds (1.46 M allocations: 34.171 GiB, 10.54% gc time)
  9.067778 seconds (1.46 M allocations: 34.171 GiB, 10.77% gc time)
  8.936289 seconds (1.46 M allocations: 34.171 GiB, 10.84% gc time)
  8.958682 seconds (1.46 M allocations: 34.171 GiB, 10.88% gc time)
  8.870017 seconds (1.46 M allocations: 34.171 GiB, 9.84% gc time)
  8.858053 seconds (1.46 M allocations: 34.171 GiB, 10.49% gc time)
  9.278745 seconds (1.46 M allocations: 34.171 GiB, 11.46% gc time)
  8.758612 seconds (1.46 M allocations: 34.171 GiB, 10.07% gc time)

That doesn’t reproduce the problem: GC time there is only about 10%.