Flux memory usage high in SRCNN


First let me preface this by saying I am new to Julia and machine learning. I just wanted to dip my toes into machine learning for fun, started out with OCR using an MLP first and now I’m trying to implement a super resolution neural network (SRCNN), with a lot of help from ChatGPT, the original paper and other resources.

This is all of my code, which I think is already the MWE for my problem.

using Flux
using Images
using JLD2

# Create super resolution convolutional neural network model
function srcnn()
    model = Chain(
        # 3 Channel input for RGB
        Conv((9, 9), 3 => 64, pad=(2, 2), relu),
        Conv((1, 1), 64 => 32, pad=(2, 2), relu),
        Conv((5, 5), 32 => 3, pad=(2, 2)),
    )

    return model
end

# Return tuples of (low_res, high_res) images
function preprocess_image(image_path, scale_factor=4)
    high_res = load(image_path)
    w, h = size(high_res) ./ scale_factor
    w = Int64(floor(w))
    h = Int64(floor(h))

    low_res = imresize(high_res, (w, h))
    upscale = imresize(low_res, size(high_res))

    upscale_float = Flux.unsqueeze(float32.(permutedims(channelview(upscale), (2, 3, 1))) ./ 255, dims=4)
    high_res_float = Flux.unsqueeze(float32.(permutedims(channelview(high_res), (2, 3, 1))) ./ 255, dims=4)
    return (upscale_float, high_res_float)
end

# Return vector of tuples of (low_res, high_res) training data
function iterate_over_images(train_dir)
    file_paths = readdir(train_dir, join=true)
    training_data = []
    for path in file_paths
        push!(training_data, preprocess_image(path))
    end

    return training_data
end

function main()
    model = srcnn()
    loss(x, y) = Flux.mse(x, y)

    opt = Flux.setup(Adam(), model)

    @info "Reading in images"
    training_dir = "./train/"
    training_data = iterate_over_images(training_dir)

    @info "Starting the training loop"
    for epoch in 1:5
        for i in 1:length(training_data) 
            x = training_data[i][1]
            y = training_data[i][2]
            grads = Flux.gradient(model) do m
                result = m(x)
                loss(result, y)
            end

            Flux.update!(opt, model, grads[1])
        end
    end

    # JLD2.jldsave("small_model.jld2", model)
    JLD2.@save "small_model.jld2" model
end

main()


I have a folder containing 3 high resolution images, all of different sizes, that I am using to train. The problem occurs when I include more than one image in training: I reach an OOM error and the process is killed by the OS (just before the update! line of the second iteration, I'm pretty sure).

I’m running this through WSL2 with 32 GB of RAM available (per free -mh) and no GPU passthrough.

This may not be the place to ask, but if anyone could tell me whether this program is a correct implementation, that would be very helpful.


Edit: I’m also running Julia version 1.10.3

I don’t see obvious mistakes.

I wonder if type instability is hurting you. Flux.gradient closes over x, which comes from an Any container… something like this:

julia> out = []; push!(out, [1 2 3], [4 5 6]); out
2-element Vector{Any}:
 [1 2 3]
 [4 5 6]

julia> map(identity, out)
2-element Vector{Matrix{Int64}}:
 [1 2 3]
 [4 5 6]

That could be fixed by inserting map(identity, training_data), i.e. by replacing the end of iterate_over_images and the start of main with something like this:

    return map(identity, training_data)
end

function main(training_data = iterate_over_images("./train/"))
    model = srcnn()
    loss(x, y) = Flux.mse(x, y)
    opt = Flux.setup(Adam(), model)
    @info "Starting the training loop"
    for epoch in 1:5

Edit, on reading again:

On second thought, maybe my explanation is less likely. There are only 3 images, but how large are they? Conv((9, 9), 3 => 64, pad=(2, 2), relu) has to allocate something about 21 times as big as the image. Maybe this is simply too large, and needs stride=5 or a larger scale_factor?
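To put a number on "about 21 times as big": the first layer's output has 64 channels where the input has 3, so its activations alone are roughly 64/3 ≈ 21× the input array. A quick sketch (the 4000 × 2667 size is an assumption for illustration, ignoring the small shrinkage from the kernel and padding):

```julia
# Rough activation size in bytes for a conv layer producing `cout` channels
# on an h × w Float32 image (spatial shrinkage from kernel/padding ignored).
conv_activation_bytes(h, w, cout) = h * w * cout * sizeof(Float32)

input_bytes  = conv_activation_bytes(4000, 2667, 3)   # ≈ 128 MB for the input
output_bytes = conv_activation_bytes(4000, 2667, 64)  # ≈ 2.7 GB for layer 1 alone
```

And that is just the forward pass; gradient computation keeps intermediates around too.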

The map(identity, training_data) tip sounds like it will help with run time if I were to train on many more images.

They are pretty high quality images from Pexels, around 1 MB each. So the first Conv layer should allocate around 21 MB? Then the second layer around 10 MB, and the third layer around 3 MB? Not sure if my thought process is correct, but I will certainly try a different stride and scale_factor when I get home. Thanks!

Unfortunately, I couldn’t get the stride working, because the output image dimensions change and no longer match the target.
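The standard convolution output-size formula shows why the dimensions change (the numbers below assume one 4000-pixel dimension and the kernel/pad from the first layer of my model):

```julia
# Output length along one spatial dim for input length n, kernel k,
# padding p, stride s:  out = fld(n + 2p - k, s) + 1
conv_out(n; k, p, s) = fld(n + 2p - k, s) + 1

conv_out(4000; k = 9, p = 2, s = 1)  # 3996: pad=2 already shrinks it slightly
conv_out(4000; k = 9, p = 2, s = 5)  # 800: stride 5 shrinks it ~5x, so the MSE
                                     # against the full-resolution target no
                                     # longer lines up
```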

I calculated the memory the first layer would allocate too. Given an image of size 4000 x 2667 x 3 (and Float32s), according to @allocated it is 384048128 bytes; after applying the first layer, around 21 allocations of approximately the same size gives 384048128 * 21 = 8065010688 bytes, about 8 GB. So it looks like I’m just using images that are too large.

I revisited the paper and some other resources too, and realised they suggest that working on just the Y channel of the YCbCr colour space is sufficient. So I’ll have to change that part too.
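For what it's worth, the Y (luma) channel is just a weighted sum of the RGB components; the weights below are the standard BT.601 ones used in YCbCr conversion (this is a sketch of the per-pixel maths, not the Images.jl API):

```julia
# BT.601 luma from normalized RGB components, as used in YCbCr conversion.
# A single Float32 channel like this would replace the 3-channel input,
# cutting the input (and first-layer weight count) by a factor of 3.
luma(r, g, b) = 0.299f0 * r + 0.587f0 * g + 0.114f0 * b

luma(1.0f0, 1.0f0, 1.0f0)  # ≈ 1.0: white has full luma
luma(0.0f0, 0.0f0, 1.0f0)  # ≈ 0.114: blue contributes least
```

The model's first layer would then become Conv((9, 9), 1 => 64, ...), and the Cb/Cr channels can be upscaled with plain interpolation at inference time.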
