I am training a CNN on CIFAR10 (via Flux.DataLoader) and am trying to apply some stochastic data transformations to each batch of data at runtime.
My goal is to apply the following operations to each batch I get from the dataloader.
- Pad with zeros (32x32 → 40x40 images).
- Random crop (40x40 → 32x32 images).
- Flip images horizontally with 50% probability.
Is there a best practice for achieving this in Julia?
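For reference, the three steps above can be sketched in plain Julia with no augmentation library at all (an untested sketch; `augment_batch` is a name I made up, operating on a WHCN `Float32` batch as produced by MLDatasets):

```julia
# Sketch: zero-pad, random-crop, and random-flip a WHCN batch in plain Julia.
# x has size (32, 32, 3, N); the result has the same size.
function augment_batch(x::Array{Float32,4}; pad=4)
    w, h, c, n = size(x)
    out = similar(x)
    padded = zeros(Float32, w + 2pad, h + 2pad, c)   # reusable 40x40x3 scratch buffer
    for i in 1:n
        padded .= 0
        padded[pad+1:pad+w, pad+1:pad+h, :] .= @view x[:, :, :, i]   # zero-pad to 40x40
        ox, oy = rand(0:2pad), rand(0:2pad)                          # random crop offset
        img = @view padded[ox+1:ox+w, oy+1:oy+h, :]                  # crop back to 32x32
        if rand() < 0.5
            out[:, :, :, i] .= @view img[end:-1:1, :, :]             # flip along width
        else
            out[:, :, :, i] .= img
        end
    end
    return out
end
```

This avoids the RGB conversion entirely, at the cost of writing the loop by hand.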
Augmentor.jl seems like it might be the way to go, but it requires me to convert between data formats, and it does not seem to have an option for padding.
I found an example on the dev version of Augmentor.jl's docs (link), which relies on MappedArrays instead of Flux.DataLoader. This seems less readable, and as far as I can tell the data is not reshuffled every epoch.
Using Julia and Flux has generally been a breeze so far, with custom layers and learning rules being very simple to implement. So I was quite surprised that implementing standard data augmentation seems to take much more effort.
I have tried to implement a minimal working example, shown below. The interesting parts are probably the functions MWE and getdata. I did not find a good way to implement padding (and random cropping to 32x32 requires padding first), so the only transformation currently applied is FlipX(0.5).
I guess my questions are:
- Am I on the right track with using Augmentor.jl or are there better options?
- If Augmentor.jl is the way to go, then how could I implement the padding and random cropping?
- Do you have general ideas on how to make things cleaner/faster? For larger networks my current approach slows things down a bit.
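For the padding + random-cropping question specifically, the direction I have been experimenting with (unverified sketch: RCropSize is listed among Augmentor's operations, but I have not confirmed this exact pipeline works; `pad_batch` is a helper name I made up) is to zero-pad manually before handing the batch to the pipeline:

```julia
# Sketch: pad manually with zeros (Augmentor has no pad op as far as I can
# tell), then let the pipeline random-crop back to 32x32 and flip.
using Augmentor

pl = RCropSize(32, 32) |> FlipX(0.5) |> SplitChannels() |> PermuteDims((2, 3, 1))

# x is an H x W x N array of pixels (e.g. RGB, as produced by _colorview)
function pad_batch(x; pad=4)
    h, w, n = size(x)
    padded = zeros(eltype(x), h + 2pad, w + 2pad, n)   # zero(RGB) is black
    padded[pad+1:pad+h, pad+1:pad+w, :] .= x
    return padded
end

# in the training loop:
# xaug = zeros(Float32, 32, 32, 3, batchsize)
# augmentbatch!(xaug, pad_batch(x), pl)
```

The SplitChannels/PermuteDims tail is the same layout conversion used in my MWE below.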
```julia
using Augmentor, MLDatasets
using Flux, Flux.Optimise
using Flux: onehotbatch, onecold
using Flux.Losses: logitcrossentropy

function getdata(batchsize)
    xtrain, ytrain = MLDatasets.CIFAR10.traindata(Float32)
    xtest, ytest = MLDatasets.CIFAR10.testdata(Float32)

    # Per-channel CIFAR-10 mean and standard deviation
    m = reshape([0.4914009f0 0.4821584f0 0.4465309f0], (1, 1, 3, 1))
    s = reshape([0.20230277f0 0.19941312f0 0.2009607f0], (1, 1, 3, 1))
    xtrain = (xtrain .- m) ./ s
    xtest = (xtest .- m) ./ s

    # Convert training data to RGB to work with augmentbatch!()
    xtrain = MLDatasets.CIFAR10._colorview(RGB, permutedims(xtrain, (3, 1, 2, 4)))

    ytrain, ytest = Flux.onehotbatch(ytrain, 0:9), Flux.onehotbatch(ytest, 0:9)

    trainloader = Flux.DataLoader((xtrain, ytrain), batchsize=batchsize, shuffle=true, partial=false)
    testloader = Flux.DataLoader((xtest, ytest), batchsize=batchsize, partial=false)
    return trainloader, testloader
end

function LeNet5(; imgsize=(28, 28, 1), nclasses=10)
    out_conv_size = (imgsize[1] ÷ 4 - 3, imgsize[2] ÷ 4 - 3, 16)
    return Chain(
        Conv((5, 5), imgsize[end] => 6, relu),
        MaxPool((2, 2)),
        Conv((5, 5), 6 => 16, relu),
        MaxPool((2, 2)),
        flatten,
        Dense(prod(out_conv_size), 120, relu),
        Dense(120, 84, relu),
        Dense(84, nclasses)
    )
end

function loss_and_accuracy(data_loader, net, device)
    acc = 0.0f0; ls = 0.0f0; num = 0
    for (x, y) in data_loader
        x, y = x |> device, y |> device
        pred = net(x)
        ls += logitcrossentropy(pred, y, agg=sum)
        acc += sum(onecold(cpu(pred)) .== onecold(cpu(y)))
        num += size(y, 2)
    end
    return ls / num, acc / num
end

function MWE()
    pl = FlipX(0.5) |> SplitChannels() |> PermuteDims((2, 3, 1))
    device = gpu
    batchsize = 128
    trainloader, testloader = getdata(batchsize)
    opt = ADAM(0.0001)
    net = LeNet5(imgsize=(32, 32, 3), nclasses=10) |> device
    ps = Flux.params(net)

    for epoch in 1:5
        for (x, y) in trainloader
            xaug = zeros(Float32, 32, 32, 3, batchsize)
            augmentbatch!(xaug, x, pl)
            xaug, y = xaug |> device, y |> device
            gs = gradient(ps) do
                logitcrossentropy(net(xaug), y)
            end
            update!(opt, ps, gs)
        end
        test_loss, test_acc = loss_and_accuracy(testloader, net, device)
        @info """Epoch $epoch:
            Test: Acc(θ): $(round(test_acc * 100f0, digits=2))%  Loss: $(round(test_loss, digits=6))"""
    end
end

MWE()
```
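On the speed question: one small, untested idea I have considered is hoisting the `xaug` buffer allocation out of the inner loop, since it currently allocates a fresh 32x32x3x128 array every batch. A self-contained sketch of the pattern (`apply_aug!` is a made-up stand-in for the in-place `augmentbatch!(xaug, x, pl)` call):

```julia
# Sketch: preallocate the augmentation buffer once and reuse it every batch.
apply_aug!(buf, x) = (buf .= x; buf)   # stand-in for augmentbatch!(buf, x, pl)

batchsize = 128
xaug = zeros(Float32, 32, 32, 3, batchsize)      # allocated once, outside the loop

for step in 1:10
    x = rand(Float32, 32, 32, 3, batchsize)       # stand-in for a DataLoader batch
    apply_aug!(xaug, x)                           # writes in place, no new output array
end
```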