Hello,
First, let me preface this by saying I am new to Julia and machine learning. I wanted to dip my toes into machine learning for fun: I started out with OCR using an MLP, and now I’m trying to implement a super-resolution convolutional neural network (SRCNN), with a lot of help from ChatGPT, the original paper, and other resources.
This is all of my code, which I think is already the MWE for my problem.
```julia
using Flux
using Images
using JLD2

# Create super resolution convolutional neural network model
function srcnn()
    model = Chain(
        # 3 Channel input for RGB
        Conv((9, 9), 3 => 64, pad=(2, 2), relu),
        Conv((1, 1), 64 => 32, pad=(2, 2), relu),
        Conv((5, 5), 32 => 3, pad=(2, 2))
    )
    return model
end

# Return tuples of (low_res, high_res) images
function preprocess_image(image_path, scale_factor=4)
    high_res = load(image_path)
    w, h = size(high_res) ./ scale_factor
    w = Int64(floor(w))
    h = Int64(floor(h))
    low_res = imresize(high_res, (w, h))
    upscale = imresize(low_res, size(high_res))
    upscale_float = Flux.unsqueeze(float32.(permutedims(channelview(upscale), (2, 3, 1))) ./ 255, dims=4)
    high_res_float = Flux.unsqueeze(float32.(permutedims(channelview(high_res), (2, 3, 1))) ./ 255, dims=4)
    return (upscale_float, high_res_float)
end

# Return vector of tuples of (low_res, high_res) training data
function iterate_over_images(train_dir)
    file_paths = readdir(train_dir, join=true)
    training_data = []
    for path in file_paths
        push!(training_data, preprocess_image(path))
    end
    return training_data
end

function main()
    model = srcnn()
    loss(x, y) = Flux.mse(x, y)
    opt = Flux.setup(Adam(), model)

    @info "Reading in images"
    training_dir = "./train/"
    training_data = iterate_over_images(training_dir)

    @info "Starting the training loop"
    for epoch in 1:5
        for i in 1:length(training_data)
            x = training_data[i][1]
            y = training_data[i][2]
            grads = Flux.gradient(model) do m
                result = m(x)
                loss(result, y)
            end
            Flux.update!(opt, model, grads[1])
            GC.gc()
        end
    end

    # JLD2.jldsave("small_model.jld2", model)
    JLD2.@save "small_model.jld2" model
end

main()
```
Problem
I have a folder containing 3 high-resolution images, all of different sizes, that I am using for training. The problem occurs when I train on more than 1 image: I hit an OOM error and the process is killed by the OS (just before the `update!` line of the second iteration, I’m pretty sure).
I’m running this through WSL2 with 32 GB of RAM available (according to `free -mh`) and no GPU passthrough.
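For a sense of scale, here is a rough back-of-envelope sketch of how much memory the layer outputs alone might take for a single image. The 4000×3000 size is an assumed example, not a measurement of my actual images, and the helper names `conv_out` and `activation_bytes` are made up for illustration. It only counts the Float32 outputs of the three layers in the model above (stride 1, with the kernel sizes and padding from `srcnn()`), and ignores any temporary buffers that the convolution or the gradient computation allocate, so real usage is presumably higher.

```julia
# Output length of a stride-1 convolution along one dimension:
# out = in + 2*pad - kernel + 1
conv_out(n, k, p) = n + 2p - k + 1

# Sum of the Float32 output sizes (in bytes) of the three Conv layers
# for a single h×w input image. This is an estimate, not a measurement.
function activation_bytes(h, w)
    layers = [(9, 2, 64), (1, 2, 32), (5, 2, 3)]  # (kernel, pad, out channels)
    total = 0
    for (k, p, c) in layers
        h, w = conv_out(h, k, p), conv_out(w, k, p)
        total += h * w * c * sizeof(Float32)
    end
    return total
end

# Hypothetical 4000×3000 image (assumed size, for illustration only)
bytes = activation_bytes(4000, 3000)
println(round(bytes / 2^30, digits=2), " GiB for the layer outputs alone")
```

If this estimate is in the right ballpark, a few gigabytes per image before gradients are even computed might explain why full-size images are a problem on the CPU.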
This may not be the place to ask, but if anyone could tell me whether or not this program is a correct implementation, that would be very helpful.
Thanks.
Edit: I’m running Julia version 1.10.3.