Why is my NN overfitting?

Here is the Julia code. The network has one hidden layer of 40 units, i.e. weight matrices of shape (784, 40) and (40, 10), and uses ReLU as the activation function.

using Flux
using MLDatasets
using ImageCore
using Distributions

p = Binomial(1, 0.5)  # Bernoulli(0.5), used to draw the dropout mask

one_hot_encoding(x) = unique(x) .== permutedims(x)  # one-hot encode labels, one column per sample
train_size = 1000

train_x, train_y = MNIST.traindata(1:train_size);  # first 1000 training images
test_x, test_y = MNIST.testdata();

labels = one_hot_encoding(train_y)
test_labels = one_hot_encoding(test_y)

α, iterations, hidden_size, pixels_per_image, num_labels = 0.005, 351, 40, 28 * 28, 10  # learning rate, epoch count, layer sizes

# initialize both weight matrices uniformly in [-0.1, 0.1)
weights₀ = 0.2 * rand(Float16, (pixels_per_image, hidden_size)) .- 0.1
weights₁ = 0.2 * rand(Float16, (hidden_size, num_labels)) .- 0.1

relu(x) = (x > 0) * x  # note: shadows the relu exported by Flux

relu2deriv(output) = output > 0  # derivative of ReLU in terms of its output

for epoch in 1:iterations
    error, correct = 0, 0
    for i in 1:train_size
        layer₀ = reshape(train_x[:, :, i], 1, 28*28)  # input as a 1×784 row vector
        label = reshape(labels[:, i], 1, 10)

        # forward pass: hidden layer with ReLU
        layer₁ = relu.(layer₀ * weights₀)
        # inverted dropout: drop each unit with probability 0.5 and scale
        # the survivors by 2 to keep the expected activation unchanged
        dropout_mask = rand(p, size(layer₁))
        layer₁ = (layer₁ .* dropout_mask) * 2

        layer₂ = layer₁ * weights₁  # linear output layer
        error += sum((label - layer₂) .^ 2)
        correct += Int(argmax(layer₂) == argmax(label))

        # backward pass: delta at the output, propagated through weights₁
        # and the ReLU; dropped units get zero gradient
        layer₂∇ = label - layer₂
        layer₁∇ = (layer₂∇ * weights₁') .* relu2deriv.(layer₁)
        layer₁∇ = layer₁∇ .* dropout_mask

        # gradient-descent update (global is needed when running as a script)
        global weights₁ += α * (layer₁' * layer₂∇)
        global weights₀ += α * (layer₀' * layer₁∇)
    end

    if rem(epoch, 10) == 0
        println("=== Epoch $(epoch)")
        correct /= train_size  # average the training metrics over the training set
        error /= train_size
        @show error correct

        # evaluate on the first 1000 test images (no dropout at test time)
        test_size = 1000
        test_error, test_correct = (0.0, 0)
        for i in 1:test_size
            layer₀ = reshape(test_x[:, :, i], 1, 28*28)
            label = reshape(test_labels[:, i], 1, 10)
            layer₁ = relu.(layer₀ * weights₀)
            layer₂ = layer₁ * weights₁
            test_error += sum((label - layer₂) .^ 2)
            test_correct += Int(argmax(layer₂) == argmax(label))
        end
        test_error /= test_size
        test_correct /= test_size
        @show test_error test_correct
    end
end

After some epochs:

=== Epoch 10
error = 0.5317613511400684
correct = 0.657
test_error = 1.0676295123314394
test_correct = 0.16
=== Epoch 20
error = 0.4272197286452044
correct = 0.768
test_error = 1.1270212709101741
test_correct = 0.155
=== Epoch 30
error = 0.39200475799496176
correct = 0.798
test_error = 1.1643866582090718
test_correct = 0.167
=== Epoch 40
error = 0.36184040950515167
correct = 0.811
test_error = 1.1864618514078764
test_correct = 0.16
=== Epoch 50
error = 0.34593441115950146
correct = 0.84
test_error = 1.1715820512126938
test_correct = 0.167
=== Epoch 60
error = 0.3204440837764096
correct = 0.868
test_error = 1.2286283167970908
test_correct = 0.159
=== Epoch 70
error = 0.3277778215104479
correct = 0.868
test_error = 1.2249013683103738
test_correct = 0.17
=== Epoch 80
error = 0.32435537191933905
correct = 0.855
test_error = 1.2358375100629013
test_correct = 0.159
=== Epoch 90
error = 0.2999773978624835
correct = 0.884
test_error = 1.2149573376491554
test_correct = 0.164

As you can see, during training the model does lower the error and improve the accuracy, but not on the test dataset; the test accuracy stays terrible. I don't understand why. Can someone help me? Thanks a lot.
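
One check that narrows this down (a quick sketch using the variables already defined above) is whether the two datasets even get the same encoding:

# With the unique-based encoder, the row order of the one-hot matrix
# follows the order in which labels first appear in each dataset.
unique(train_y) == unique(test_y)  # false means the same digit maps to
                                   # different rows in train and test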

Never mind, I found it: the one-hot encoding function was wrong. unique(x) returns the labels in order of first appearance, so labels and test_labels assigned different rows to the same digit, and the model was scored at test time against a permuted labeling. After switching to one_hot_encoding(x) = Flux.onehotbatch(x, 0:9) it works fine now.
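
A minimal sketch of the fixed encoder, with a round-trip check via Flux.onecold added by me for illustration (it is not part of the script above):

using Flux

# Pin the class order to 0:9 so that row k always means digit k - 1,
# in every dataset.
one_hot_encoding(x) = Flux.onehotbatch(x, 0:9)

labels = one_hot_encoding(train_y)       # 10 × train_size
test_labels = one_hot_encoding(test_y)   # 10 × length(test_y)

# Decoding goes the other way: onecold picks the class with the largest entry.
@assert Flux.onecold(labels, 0:9) == train_y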

Some epochs after the fix:

=== Epoch 10
error = 0.44058362405905876
correct = 0.7625
test_error = 0.33555505288043963
test_correct = 0.8598
=== Epoch 20
error = 0.38800343770568324
correct = 0.8105
test_error = 0.2999266758982489
test_correct = 0.8766
=== Epoch 30
error = 0.3698209317837888
correct = 0.8255
test_error = 0.28587186851676244
test_correct = 0.8863
=== Epoch 40
error = 0.35673825132233666
correct = 0.836
test_error = 0.2858764929491342
test_correct = 0.8881
=== Epoch 50
error = 0.340204680724397
correct = 0.85
test_error = 0.2693888361659312
test_correct = 0.8872
=== Epoch 60
error = 0.3217428207530412
correct = 0.8695
test_error = 0.27357312921315236
test_correct = 0.8848
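
For reference, the same architecture can also be built from Flux's own layers. This is only a sketch of an equivalent setup (names like model, X, Y, and the epoch count are mine), using the older Flux.train! API that matches the MNIST.traindata call above, not a drop-in replacement for the manual loop:

using Flux
using MLDatasets

train_x, train_y = MNIST.traindata(1:1000)

# Flux layers expect features × batch, one sample per column.
X = Float32.(reshape(train_x, 28 * 28, :))
Y = Flux.onehotbatch(train_y, 0:9)

# One hidden layer of 40 ReLU units with dropout, as in the manual version.
model = Chain(Dense(28 * 28, 40, relu), Dropout(0.5), Dense(40, 10))

loss(x, y) = Flux.Losses.mse(model(x), y)
opt = Descent(0.005)

for epoch in 1:100
    Flux.train!(loss, Flux.params(model), [(X, Y)], opt)
end

# Decode predictions and measure training accuracy.
accuracy = sum(Flux.onecold(model(X), 0:9) .== train_y) / length(train_y)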