Hi,
I am training a CNN on CIFAR10 (via Flux.DataLoader) and am trying to apply some stochastic data transformations to each batch of data at runtime.
My goal is to apply the following operations to each batch I get from the dataloader.
- Pad with zeros (yielding 32x32 → 40x40 images).
- Random crop (yielding 40x40 → 32x32 images).
- Flip images horizontally with 50% probability.
Is there a best practice for achieving this in Julia?
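For concreteness, this is roughly what I mean in plain array terms (just an illustrative sketch on a Flux-style WHCN batch, with a made-up helper name; not how I would like to do it in practice):

# Illustrative only: zero-pad, randomly crop, and randomly flip one WHCN batch.
function augment_batch(x::Array{Float32,4}; pad=4)
    w, h, c, n = size(x)
    padded = zeros(Float32, w + 2pad, h + 2pad, c, n)      # zero padding: 32x32 -> 40x40
    padded[pad+1:pad+w, pad+1:pad+h, :, :] .= x
    out = similar(x)
    for i in 1:n
        i0, j0 = rand(0:2pad), rand(0:2pad)                # random crop offset
        crop = padded[i0+1:i0+w, j0+1:j0+h, :, i]          # 40x40 -> 32x32
        rand() < 0.5 && (crop = reverse(crop, dims=1))     # horizontal flip with 50% probability
        out[:, :, :, i] = crop
    end
    return out
end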
Augmentor.jl seems like it might be the way to go, but it requires converting between data formats, and it does not seem to have an option for padding.
I found an example in the dev version of Augmentor.jl’s docs (link), which relies on MappedArrays instead of Flux.DataLoader. That approach seems less readable, and as far as I can tell the data is not shuffled at every epoch.
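For context, the pattern there looks roughly like this as far as I understand it (dummy images below, and I may be misreading the example):

using Augmentor, MappedArrays, ImageCore
imgs = [rand(RGB{Float32}, 32, 32) for _ in 1:100]          # dummy stand-in for the training images
pl = FlipX(0.5) |> SplitChannels() |> PermuteDims((2, 3, 1))
lazy = mappedarray(img -> augment(img, pl), imgs)            # augmented lazily on access, never reshuffled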
Using Julia and Flux has generally been a breeze so far, with custom layers and learning rules being very simple to implement. So I was quite surprised that implementing standard data augmentation seems to take much more effort.
I have tried to put together a minimal working example, shown below. The interesting parts are probably the functions MWE and getdata. I did not find a good way to implement padding (which I need before I can random-crop back to 32x32), so the only transformation applied at the moment is FlipX(0.5).
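(For the padding itself, would something like PaddedViews.jl be the right building block? I am thinking of something along these lines:

using PaddedViews
img = rand(Float32, 32, 32)
padded = PaddedView(0f0, img, (-3:36, -3:36))   # 32x32 -> 40x40 with a 4-pixel zero border

But I do not see how to combine that with a random crop inside an Augmentor pipeline.)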
I guess my questions are:
- Am I on the right track with using Augmentor.jl or are there better options?
- If Augmentor.jl is the way to go, then how could I implement the padding and random cropping?
- Do you have general ideas on how to make things cleaner/faster? For larger networks my current approach slows things down a bit.
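Regarding the last question, one idea I had is to preallocate the augmentation buffer once and (if I read the Augmentor docs correctly) pass CPUThreads() to augmentbatch! for multithreaded augmentation, roughly like this (dummy data, and I am not sure this is the recommended pattern):

using Augmentor, ImageCore
# CPUThreads() comes from ComputationalResources.jl (may need using ComputationalResources)
pl = FlipX(0.5) |> SplitChannels() |> PermuteDims((2, 3, 1))
imgs = [rand(RGB{Float32}, 32, 32) for _ in 1:128]          # stand-in for one training batch
xaug = zeros(Float32, 32, 32, 3, length(imgs))               # allocate once, reuse every iteration
augmentbatch!(CPUThreads(), xaug, imgs, pl)                  # multithreaded batch augmentation (I think)

Here is the full MWE: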
using Augmentor, MLDatasets
using Flux, Flux.Optimise
using Flux: onehotbatch, onecold
using Flux.Losses: logitcrossentropy
using ImageCore: colorview, RGB  # to wrap the raw Float32 arrays as RGB images for Augmentor
function getdata(batchsize)
    xtrain, ytrain = MLDatasets.CIFAR10.traindata(Float32)
    xtest, ytest = MLDatasets.CIFAR10.testdata(Float32)

    # Per-channel CIFAR10 mean and standard deviation for normalization
    m = reshape([0.4914009f0 0.4821585f0 0.4465309f0], (1, 1, 3, 1))
    s = reshape([0.20230277f0 0.19941312f0 0.2009607f0], (1, 1, 3, 1))
    xtrain = (xtrain .- m) ./ s
    xtest = (xtest .- m) ./ s

    # Convert training data to 32x32 RGB images so it works with augmentbatch!()
    xtrain = colorview(RGB, permutedims(xtrain, (3, 1, 2, 4)))

    ytrain, ytest = Flux.onehotbatch(ytrain, 0:9), Flux.onehotbatch(ytest, 0:9)

    trainloader = Flux.DataLoader((xtrain, ytrain), batchsize=batchsize, shuffle=true, partial=false)
    testloader = Flux.DataLoader((xtest, ytest), batchsize=batchsize, partial=false)
    return (trainloader, testloader)
end
function LeNet5(; imgsize=(28,28,1), nclasses=10)
    out_conv_size = (imgsize[1]÷4 - 3, imgsize[2]÷4 - 3, 16)
    return Chain(
        Conv((5, 5), imgsize[end]=>6, relu),
        MaxPool((2, 2)),
        Conv((5, 5), 6=>16, relu),
        MaxPool((2, 2)),
        flatten,
        Dense(prod(out_conv_size), 120, relu),
        Dense(120, 84, relu),
        Dense(84, nclasses)
    )
end
function loss_and_accuracy(data_loader, net, device)
    acc = 0.0f0; ls = 0.0f0; num = 0
    for (x, y) in data_loader
        x, y = x |> device, y |> device
        pred = net(x)
        ls += logitcrossentropy(pred, y, agg=sum)  # sum here, so dividing by num gives a per-sample mean
        acc += sum(onecold(cpu(pred)) .== onecold(cpu(y)))
        num += size(y, 2)
    end
    return ls / num, acc / num
end
function MWE()
    # Augmentation pipeline: random flip, then convert RGB images back to WHC Float arrays
    pl = FlipX(0.5) |> SplitChannels() |> PermuteDims((2, 3, 1))

    device = gpu
    batchsize = 128
    trainloader, testloader = getdata(batchsize)

    opt = ADAM(0.0001)
    net = LeNet5(imgsize=(32, 32, 3), nclasses=10)
    net = net |> device
    ps = Flux.params(net)

    for epoch = 1:5
        for (x, y) in trainloader
            # Augment the whole batch on the CPU, then move it to the device
            xaug = zeros(Float32, 32, 32, 3, batchsize)
            augmentbatch!(xaug, x, pl)
            xaug, y = xaug |> device, y |> device

            gs = gradient(ps) do
                logitcrossentropy(net(xaug), y)
            end
            update!(opt, ps, gs)
        end

        test_loss, test_acc = loss_and_accuracy(testloader, net, device)
        @info """Epoch: $epoch:
        Test: Acc(θ): $(round(test_acc*100f0, digits=2))% Loss: $(round(test_loss, digits=6))
        """
    end
end

MWE()