Computing loss/accuracy on the validation set in Flux

Hi, I have written the code below following the conv_mnist example in the Flux model zoo. I have a feeling I am doing something wrong, especially on the validation set. I have some experience with PyTorch, where you would usually compute the validation loss/accuracy inside `with torch.no_grad():`, so that autograd does not record those operations and no gradients are computed for them.

```python
# How it's done in PyTorch
import torch

for X_train, y_train in train_loader:
    yhat_train = model(X_train)
    loss = loss_func(yhat_train, y_train)
    # ...

with torch.no_grad():  # autograd records no operations inside this block
    for X_test, y_test in test_loader:
        yhat_test = model(X_test)
        # ...
```

Is there something similar in Flux, or is the implementation below correct?

```julia
# Flux
using Flux
using Flux: onecold
using Flux.Losses: logitcrossentropy
using ProgressMeter

    # Body of the training function (as in conv_mnist); model, train_loader,
    # test_loader, opt, device and args are defined earlier in that function.
    losses = Dict(:train => [], :test => [])
    accuracy = Dict(:train => [], :test => [])

    # Running totals, initialised once before the epoch loop
    train_batch_total = 0
    train_loss = 0f0
    train_acc = 0

    test_batch_total = 0
    test_loss = 0f0
    test_acc  = 0

    for epoch in 1:args.epochs
        # Training
        @showprogress for (x, y) in train_loader
            x, y = device(x), device(y) # device = Flux.cpu
            batchsize = size(x)[end]
            gs = gradient(params(model)) do
                ŷ = model(x)
                logitcrossentropy(ŷ, y)
            end

            # Just for tracking losses and accuracy
            ŷ = model(x)
            train_loss += logitcrossentropy(ŷ, y) * batchsize
            train_acc += (onecold(cpu(ŷ)) .== onecold(cpu(y))) |> sum
            train_batch_total += batchsize

            # Update Optimizer
            Flux.Optimise.update!(opt, params(model), gs)
        end
        push!(losses[:train], round(train_loss/train_batch_total, digits = 4))
        push!(accuracy[:train], round(train_acc/train_batch_total * 100, digits = 4))
        println("train loss : $(losses[:train][end]) train acc : $(accuracy[:train][end])")

        # Testing
        @showprogress for (x, y) in test_loader
            x, y = device(x), device(y)
            batchsize = size(x)[end]
            ŷ = model(x)
            test_loss += logitcrossentropy(ŷ, y) * batchsize
            test_acc += (onecold(cpu(ŷ)) .== onecold(cpu(y))) |> sum
            test_batch_total += batchsize
        end
        push!(losses[:test], round(test_loss/test_batch_total, digits = 4))
        push!(accuracy[:test], round(test_acc/test_batch_total * 100, digits = 4))
        println("test loss : $(losses[:test][end]) test acc : $(accuracy[:test][end])")
    end

    return losses, accuracy
```
Output:

```text
[ Info: Training on CPU
Progress: 100%|█████████████████████████████| Time: 0:00:09
train loss : 0.5273 train acc : 85.3833
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.1704 test acc : 94.89
Progress: 100%|█████████████████████████████| Time: 0:00:08
train loss : 0.3378 train acc : 90.485
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.1392 test acc : 95.73
Progress: 100%|█████████████████████████████| Time: 0:00:09
train loss : 0.2607 train acc : 92.585
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.122 test acc : 96.17
Progress: 100%|█████████████████████████████| Time: 0:00:08
train loss : 0.2169 train acc : 93.7896
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.1081 test acc : 96.645
Progress: 100%|█████████████████████████████| Time: 0:00:08
train loss : 0.1881 train acc : 94.5997
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.098 test acc : 96.962
Progress: 100%|█████████████████████████████| Time: 0:00:08
train loss : 0.1677 train acc : 95.1728
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.0909 test acc : 97.1867
Progress: 100%|█████████████████████████████| Time: 0:00:08
train loss : 0.1519 train acc : 95.6181
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.085 test acc : 97.3857
Progress: 100%|█████████████████████████████| Time: 0:00:08
train loss : 0.1394 train acc : 95.9644
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.0796 test acc : 97.5462
Progress: 100%|█████████████████████████████| Time: 0:00:09
train loss : 0.1292 train acc : 96.2535
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.0762 test acc : 97.6478
Progress: 100%|█████████████████████████████| Time: 0:00:09
train loss : 0.1206 train acc : 96.4877
Progress: 100%|█████████████████████████████| Time: 0:00:00
test loss : 0.0727 test acc : 97.75
```

The train and test numbers for each epoch look quite off: in every epoch the test loss is lower, and the test accuracy higher, than on the training set. All I can think of is that something is wrong with the way I accumulate the losses/accuracy. The other possibility is that training somehow still happens at test time, but that seems unlikely, since no optimizer update (the equivalent of PyTorch's `optimizer.step()`) is performed in the evaluation loop. Either way, I would like to know whether this is the right way to compute the validation loss/accuracy.
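The accumulation suspicion can be checked with plain arithmetic. In the code above, `train_loss`, `train_acc`, `train_batch_total` and their `test_` counterparts are initialised once, before `for epoch in 1:args.epochs`, and never reset, so each printed value is an average over all epochs so far rather than over the current epoch only. The sketch below uses made-up batch losses (not the numbers from this post) and two hypothetical helpers, `report_cumulative` and `report_per_epoch`, to compare the two ways of accumulating; zeroing the accumulators at the top of the epoch loop gives the second behaviour. Separately, the training numbers are collected batch by batch while the model is still being updated, whereas the test numbers are computed after the epoch, which may on its own make the training metrics look worse.

```julia
# Made-up (mean batch loss, batch size) pairs for three epochs, for illustration only.
epochs = [
    [(0.9, 50), (0.5, 50)],    # epoch 1
    [(0.3, 50), (0.2, 50)],    # epoch 2
    [(0.15, 50), (0.1, 50)],   # epoch 3
]

# Hypothetical helper: accumulators created once, before the epoch loop
# (the pattern used in the question's code).
function report_cumulative(epochs)
    total_loss, total_n = 0.0, 0
    reported = Float64[]
    for batches in epochs
        for (loss, bs) in batches
            total_loss += loss * bs    # mean batch loss -> sum of per-sample losses
            total_n += bs
        end
        push!(reported, total_loss / total_n)   # average over ALL epochs seen so far
    end
    return reported
end

# Hypothetical helper: accumulators reset at the start of every epoch.
function report_per_epoch(epochs)
    reported = Float64[]
    for batches in epochs
        total_loss, total_n = 0.0, 0   # fresh accumulators for this epoch
        for (loss, bs) in batches
            total_loss += loss * bs
            total_n += bs
        end
        push!(reported, total_loss / total_n)   # average over this epoch only
    end
    return reported
end

println(report_cumulative(epochs))   # ≈ [0.7, 0.475, 0.358]
println(report_per_epoch(epochs))    # ≈ [0.7, 0.25, 0.125]
```

Both versions agree on the first epoch; from then on the cumulative version lags behind, because the poor early epochs keep contributing to every later average.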


I have skimmed the code and it seems fine. Have you tried substituting `train_loader` for `test_loader`, to check whether the two loops perform exactly the same computations?
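On the `torch.no_grad()` question: Flux (through Zygote) only differentiates code that runs inside a `gradient` call, so a plain `model(x)` in the test loop builds no gradient information and needs no wrapper. What can differ between training and evaluation is the behaviour of layers such as `Dropout` and `BatchNorm`, which `Flux.testmode!` switches to inference mode. Following the suggestion above to run the same computation on both loaders, here is a minimal sketch of one shared evaluation function, assuming Flux 0.13 or later; `evaluate_metrics`, the toy model and the random batches are hypothetical and only for illustration.

```julia
using Flux
using Flux: onehotbatch, onecold
using Flux.Losses: logitcrossentropy
using Random

# Hypothetical helper: sample-weighted mean loss and accuracy (%) over any
# iterable of (x, y) batches. It is only called outside `gradient`, so Zygote
# records nothing and no `no_grad`-style wrapper is needed.
function evaluate_metrics(model, loader)
    Flux.testmode!(model)                  # Dropout/BatchNorm in inference mode
    total_loss, total_correct, total_n = 0f0, 0, 0   # fresh accumulators on every call
    for (x, y) in loader
        ŷ = model(x)
        n = size(x, ndims(x))              # batch dimension is the last one
        total_loss += logitcrossentropy(ŷ, y) * n
        total_correct += sum(onecold(cpu(ŷ)) .== onecold(cpu(y)))
        total_n += n
    end
    Flux.testmode!(model, :auto)           # restore Flux's default automatic mode
    return (loss = total_loss / total_n, acc = 100 * total_correct / total_n)
end

# Toy usage with random data (hypothetical model and batches).
Random.seed!(0)
model = Chain(Dense(4 => 8, relu), Dropout(0.5), Dense(8 => 3))
xs = rand(Float32, 4, 20)
ys = onehotbatch(rand(1:3, 20), 1:3)
loader = [(xs[:, 1:10], ys[:, 1:10]), (xs[:, 11:20], ys[:, 11:20])]

println(evaluate_metrics(model, loader))
```

Called as `evaluate_metrics(model, train_loader)` and `evaluate_metrics(model, test_loader)` at the end of each epoch, it measures both sets with the same weights and fresh accumulators, so the two numbers are directly comparable.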