Hi, I have written the code below following the examples in the Flux model zoo (conv_mnist). I have a feeling I am doing something wrong, especially on the validation set. I have some experience with PyTorch, where you would typically compute the validation loss/accuracy inside a with torch.no_grad(): block to stop gradients from being tracked.
# How it's done in PyTorch
for X_train, y_train in train_loader:
    yhat_train = model(X_train)
    loss = loss_func(yhat_train, y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # ...

model.eval()
with torch.no_grad():
    for X_test, y_test in test_loader:
        yhat_test = model(X_test)
        # ...
Is there something similar in Flux, or is the implementation below already correct?
# Flux
losses = Dict(:train => [], :test => [])
accuracy = Dict(:train => [], :test => [])
train_batch_total = 0
train_loss = 0f0
train_acc = 0
test_batch_total = 0
test_loss = 0f0
test_acc = 0
for epoch in 1:args.epochs
    # Training
    @showprogress for (x, y) in train_loader
        x, y = device(x), device(y)  # device = Flux.cpu
        batchsize = size(x)[end]
        gs = gradient(params(model)) do
            ŷ = model(x)
            logitcrossentropy(ŷ, y)
        end
        # Just for tracking losses and accuracy
        ŷ = model(x)
        train_loss += logitcrossentropy(ŷ, y) * batchsize
        train_acc += (onecold(cpu(ŷ)) .== onecold(cpu(y))) |> sum
        train_batch_total += batchsize
        # Update the model parameters
        Flux.Optimise.update!(opt, params(model), gs)
    end
    push!(losses[:train], round(train_loss / train_batch_total, digits = 4))
    push!(accuracy[:train], round(train_acc / train_batch_total * 100, digits = 4))
    println("train loss : $(losses[:train][end]) train acc : $(accuracy[:train][end])")

    # Testing
    @showprogress for (x, y) in test_loader
        x, y = device(x), device(y)
        batchsize = size(x)[end]
        ŷ = model(x)
        test_loss += logitcrossentropy(ŷ, y) * batchsize
        test_acc += (onecold(cpu(ŷ)) .== onecold(cpu(y))) |> sum
        test_batch_total += batchsize
    end
    push!(losses[:test], round(test_loss / test_batch_total, digits = 4))
    push!(accuracy[:test], round(test_acc / test_batch_total * 100, digits = 4))
    println("test loss : $(losses[:test][end]) test acc : $(accuracy[:test][end])")
end
return losses, accuracy
Output:
[ Info: Training on CPU
Progress: 100%|█████████████████████████████| Time: 0:00:09
train loss : 0.5273 train acc : 85.3833
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.1704 test acc : 94.89
Progress: 100%|█████████████████████████████| Time: 0:00:08
train loss : 0.3378 train acc : 90.485
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.1392 test acc : 95.73
Progress: 100%|█████████████████████████████| Time: 0:00:09
train loss : 0.2607 train acc : 92.585
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.122 test acc : 96.17
Progress: 100%|█████████████████████████████| Time: 0:00:08
train loss : 0.2169 train acc : 93.7896
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.1081 test acc : 96.645
Progress: 100%|█████████████████████████████| Time: 0:00:08
train loss : 0.1881 train acc : 94.5997
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.098 test acc : 96.962
Progress: 100%|█████████████████████████████| Time: 0:00:08
train loss : 0.1677 train acc : 95.1728
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.0909 test acc : 97.1867
Progress: 100%|█████████████████████████████| Time: 0:00:08
train loss : 0.1519 train acc : 95.6181
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.085 test acc : 97.3857
Progress: 100%|█████████████████████████████| Time: 0:00:08
train loss : 0.1394 train acc : 95.9644
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.0796 test acc : 97.5462
Progress: 100%|█████████████████████████████| Time: 0:00:09
train loss : 0.1292 train acc : 96.2535
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.0762 test acc : 97.6478
Progress: 100%|█████████████████████████████| Time: 0:00:09
train loss : 0.1206 train acc : 96.4877
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.0727 test acc : 97.75
The train and test losses for each epoch look quite off. All I can think is that something is wrong with the way I compute the accumulated losses/accuracy.
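Writing this out, I notice that the counters (train_loss, train_acc, and so on) are initialised outside the epoch loop and never reset, so each epoch actually reports a running average over all epochs so far. If that is the culprit, I assume moving the initialisation inside the loop would give true per-epoch numbers, something like:

for epoch in 1:args.epochs
    # Reset the running totals so each epoch reports its own average
    train_batch_total, train_loss, train_acc = 0, 0f0, 0
    test_batch_total, test_loss, test_acc = 0, 0f0, 0
    # ... same training and testing loops as above ...
end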
The other thought was that training is somehow still happening at test time, but that doesn't appear to be the case, since the equivalent of optimizer.step() (Flux.Optimise.update!) is never called during evaluation. Either way, I would like to know whether this is the right way to compute validation losses/accuracy in Flux.
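On the no_grad question itself: the closest analogue to model.eval() I could find in the Flux docs is Flux.testmode!, which as far as I can tell switches layers such as Dropout and BatchNorm to inference behaviour rather than disabling gradient tracking. My understanding (an assumption I would like confirmed) is that Zygote only computes gradients inside a gradient(...) call, so a plain forward pass is never tracked and no torch.no_grad() equivalent is needed. A sketch of what I imagine the evaluation phase should look like:

Flux.testmode!(model)    # like model.eval() in PyTorch (toggles Dropout/BatchNorm)
for (x, y) in test_loader
    ŷ = model(x)         # outside gradient(...), so no gradients are recorded (my assumption)
    # ... accumulate loss/accuracy as above ...
end
Flux.trainmode!(model)   # switch back to training behaviour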