Binary classification with Flux

I originally asked this question on the Julia Slack, but it occurred to me that this might be a better place to ask:

I’ve used Flux successfully for regression and multi-class classification in the past, but today I got stuck on a very simple binary classification problem. I naively assumed I could start from a multi-class classifier (the `mlp.jl` example in the FluxML model-zoo), remove the one-hot encoding of the labels, reduce the number of outputs of the final layer to 1, and replace `logitcrossentropy` with `Flux.Losses.logitbinarycrossentropy` as the loss function. This turned out not to work: my model refused to learn at all.

I made a gist in which I tried to modify the model-zoo example above to do binary classification (predicting whether an MNIST image is a 0 or not).

It is still not working. What am I doing wrong here?
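For the single-output setup described above, the labels need to become a Float32 0/1 target with the same shape as the model output (1×N for a layer with one output applied to a batch of N samples). A minimal pure-Julia sketch of that label preparation, using hypothetical labels and a made-up row of logits standing in for real model output, also shows why the shape matters: Julia's broadcasting silently expands a 1×N row combined with a length-N vector into an N×N matrix.

```julia
# Hypothetical digit labels standing in for the MNIST targets
labels = [0, 3, 0, 7, 1]

# Binary target: 1.0 for "is a zero", 0.0 otherwise
y_vec = Float32.(labels .== 0)     # length-5 Vector{Float32}
y_row = reshape(y_vec, 1, :)       # 1×5 Matrix, same shape as a 1-output layer's output

# Stand-in for the model output: one logit per sample, shape 1×N
ŷ = reshape(Float32[2.0, -1.0, 0.5, -3.0, -0.2], 1, :)

size(ŷ .- y_row)   # (1, 5): elementwise, as intended
size(ŷ .- y_vec)   # (5, 5): a 1×N row broadcast against a length-N vector gives N×N
```

Reshaping the labels once, when building the data loader, keeps every later elementwise operation (loss, accuracy) aligned sample by sample.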

I’m not a deep learning expert.
I think the problem is a vanishing gradient (I tried a learning rate of 10^7 and the model works).
I also think the data are unbalanced (0 is only one of the ten digits); maybe you could try splitting the digits into even and odd instead.
Sorry for my English.
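On the class-imbalance concern above: a classifier that always predicts the majority class already scores an accuracy equal to that class's share of the labels, which makes a useful baseline. The plain-Julia sketch below uses made-up labels to compute that baseline and, as one common remedy (class weighting, a technique not otherwise mentioned in this thread), a class-weighted logit binary cross-entropy. `pos_weight` and `weighted_logit_bce` are hypothetical names, not Flux keywords or functions.

```julia
# Numerically stable log(σ(x))
logsigmoid(x) = min(x, zero(x)) - log1p(exp(-abs(x)))

# Hypothetical 0/1 labels (1 = "is a zero"); replace with the real training targets
y = Float32[0, 0, 1, 0, 0, 0, 1, 0, 0, 0]

neg_share = count(==(0), y) / length(y)                              # fraction of negative labels
baseline_acc = count(zeros(Float32, length(y)) .== y) / length(y)    # accuracy of "always predict 0"

# Class weighting (hypothetical helper): scale the positive-label term by pos_weight
pos_weight = neg_share / (1 - neg_share)

function weighted_logit_bce(z, y; pos_weight = 1)
    per_sample = @. -(pos_weight * y * logsigmoid(z) + (1 - y) * logsigmoid(-z))
    return sum(per_sample) / length(per_sample)
end

z = zeros(Float32, length(y))   # logits of a model that is unsure about every sample
weighted_logit_bce(z, y; pos_weight = pos_weight)
```

With `pos_weight` set to the negative-to-positive ratio, the two classes contribute equally to the total loss when the logits are near 0, so the majority class no longer dominates the early gradient.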


Thank you for your reply! Your English is at least as good as mine (I’m not a native speaker either).

Unbalanced data is not the problem. When I one-hot encode the two labels (zero and non-zero), go back to two outputs from the final layer, and use `Flux.Losses.logitcrossentropy`, everything works well.
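For intuition about how the two setups relate: with two logits and a one-hot target, the softmax cross-entropy depends only on the difference of the logits, and it equals the binary cross-entropy evaluated on that difference. The pure-Julia check below uses hand-written, hypothetical helpers `logit_ce` and `logit_bce` that mirror the math of `logitcrossentropy` and `logitbinarycrossentropy` rather than calling Flux.

```julia
# Numerically stable log(σ(x))
logsigmoid(x) = min(x, zero(x)) - log1p(exp(-abs(x)))

# Binary cross-entropy from a single logit z and a 0/1 label y
logit_bce(z, y) = -(y * logsigmoid(z) + (1 - y) * logsigmoid(-z))

# Softmax cross-entropy from a vector of logits and a one-hot target
function logit_ce(z::AbstractVector, onehot::AbstractVector)
    m = maximum(z)
    lse = m + log(sum(exp.(z .- m)))   # log-sum-exp, shifted for stability
    return -sum(onehot .* (z .- lse))
end

z_zero, z_other = 1.3, -0.4   # two logits: "is a zero" and "is not a zero"

two_outputs = logit_ce([z_zero, z_other], [1.0, 0.0])   # one-hot target: this sample is a zero
one_output  = logit_bce(z_zero - z_other, 1.0)          # single logit = difference of the two
```

So a single-output layer whose weights equal the difference of the two rows of the two-output layer represents the same classifier; the two setups may still train differently, since the parameterization differs.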

Were you ever able to find an answer to this problem? I’m having the same issue.

Can you post an MWE? The OP didn’t test different learning rates or many other configurations. For example, using a single output to accommodate `logitbinarycrossentropy` halves the number of parameters in the second layer, which may prevent the model from learning effectively.
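The arithmetic behind "halves": a dense layer mapping `n_in` inputs to `n_out` outputs has `n_in * n_out` weights plus `n_out` biases, so going from two outputs to one halves the output layer's parameter count. A trivial sketch, where `dense_param_count` is a hypothetical helper and the hidden width of 32 is an arbitrary assumption:

```julia
# Parameters in a dense layer: an n_out×n_in weight matrix plus an n_out-length bias
dense_param_count(n_in, n_out) = n_in * n_out + n_out

n_hidden = 32   # hypothetical width of the layer feeding the output
dense_param_count(n_hidden, 2)   # 66 with two outputs
dense_param_count(n_hidden, 1)   # 33 with one output
```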


The code below prints the following:

```
epoch 0:  loss,accuracy = (0.04393502f0, 0.498842289209262)
epoch 1:  loss,accuracy = (0.03497529f0, 0.7889908256880734)
epoch N:  loss,accuracy = (0.03497529f0, 0.7889908256880734)
```

The loss stops changing after the first epoch.
I realized that the model outputs 0 every time, and that 0.788… is the fraction of non-matches (negative labels) in the training set. Could the problem be that the classes are too imbalanced?

```julia
using Flux

function loss_and_accuracy(data_loader, model, device)
    acc = 0
    ls = 0.0f0
    num = 0
    for (x, y) in data_loader # calculate for each batch
        x, y = device(x), device(y)
        ŷ = model(x)
        ls += Flux.Losses.logitbinarycrossentropy(ŷ, y)
        guesses = [v[1] > 0.5 ? 1 : 0 for v ∈ ŷ]
        acc += sum(guesses .== y) * 1 / size(x, 2)
        num += size(x, 2)
    end
    return ls / num, acc / num
end

numfeatures = size(train_data.data[1], 1)
model = Chain(
        Dense(numfeatures, numfeatures, relu),
        Dense(numfeatures, 1, σ),
        )

ps = Flux.params(model)
opt = ADAM(3e-4)
loss(ŷ, y) = Flux.Losses.logitbinarycrossentropy(ŷ, y)

for epoch in 1:5
    for (x, y) in train_data
        gs = Flux.gradient(ps) do
            loss(model(x), y)
        end
        Flux.Optimise.update!(opt, ps, gs)

        loss_and_accuracy(train_loader, model, device)
    end
end
```
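A few things in the posted `loss_and_accuracy` look suspect independently of the model. `logitbinarycrossentropy` already averages over the batch by default, so dividing the accumulated value by `num` again shrinks the reported loss. The accuracy is normalized by the batch size and then again by `num`. If `y` is a 1×B row while `guesses` is a length-B vector, `guesses .== y` broadcasts to a B×B matrix. Once the final layer outputs logits, the decision threshold becomes 0 rather than 0.5. Separately, the training loop calls the helper after every batch and discards the result. Below is a pure-Julia sketch of an evaluation helper that avoids these issues, assuming the model returns a 1×B matrix of logits and each batch's labels are a 1×B row of 0/1 values. `loss_and_accuracy_sketch`, `toy_model`, and the batches are hypothetical stand-ins, and the loss is written out by hand instead of calling Flux.

```julia
# Numerically stable log(σ(x))
logsigmoid(x) = min(x, zero(x)) - log1p(exp(-abs(x)))

# Mean logit-BCE and accuracy over all batches, for a model returning 1×B logits
function loss_and_accuracy_sketch(data, model)
    total_loss = 0.0
    correct = 0
    num = 0
    for (x, y) in data
        ŷ = model(x)
        @assert size(ŷ) == size(y) "labels must have the same shape as the model output"
        # Sum (not mean) per batch, so a single division by num at the end gives the mean
        total_loss += sum(@. -(y * logsigmoid(ŷ) + (1 - y) * logsigmoid(-ŷ)))
        guesses = ŷ .> 0                      # logit > 0 is the same as σ(logit) > 0.5
        correct += count(guesses .== (y .> 0.5))
        num += size(x, 2)
    end
    return total_loss / num, correct / num
end

# Toy usage: a fixed linear "model" and two hand-made batches (features × samples)
W = [1.0 -1.0]                                # 1×2 weight matrix
toy_model(x) = W * x
batches = [
    ([2.0 0.0; 0.0 1.0], [1.0 0.0]),
    (reshape([0.5, 3.0], 2, 1), reshape([0.0], 1, 1)),
]
loss, acc = loss_and_accuracy_sketch(batches, toy_model)
```

Calling a helper like this once per epoch, outside the batch loop, and printing its result gives one number per epoch that is directly comparable across epochs.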

`logitbinarycrossentropy` fuses the sigmoid and the loss, so the second layer should not have a sigmoid activation; otherwise the sigmoid is applied twice.
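To see why the double sigmoid matters, here is a pure-Julia sketch (hypothetical helper names, with the loss written out to match the fused sigmoid-plus-BCE formula) of what the loss sees when the last layer already applies σ. The value passed to the loss is p = σ(z) ∈ (0, 1), so the per-sample loss is bounded below: log 2 ≈ 0.693 for label 0 and log(1 + e⁻¹) ≈ 0.313 for label 1. For a constant prediction, the loss is minimized by pushing σ(z) toward 0 whenever fewer than half the labels are positive, and a 0.5 threshold on σ(z) then predicts 0 everywhere. That would be consistent with the behavior described upthread, though it is not necessarily the only cause; σ also saturates as it is pushed toward 0, which shrinks the gradients.

```julia
sigmoid(x) = 1 / (1 + exp(-x))
logsigmoid(x) = min(x, zero(x)) - log1p(exp(-abs(x)))
logit_bce(z, y) = -(y * logsigmoid(z) + (1 - y) * logsigmoid(-z))

# With σ on the last layer, the "logit" handed to the loss is p = σ(z) ∈ (0, 1),
# and logit_bce applies σ a second time.
zs = range(-20, 20; length = 401)
loss_if_negative = [logit_bce(sigmoid(z), 0.0) for z in zs]   # label 0
loss_if_positive = [logit_bce(sigmoid(z), 1.0) for z in zs]   # label 1

minimum(loss_if_negative)   # stays above log(2) ≈ 0.693
minimum(loss_if_positive)   # stays above log(1 + exp(-1)) ≈ 0.313
```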

It’s difficult to tell what other issues might be present, or to run this locally, without `train_data` itself. Can you provide that as well (i.e. a full executable example)?


Here’s a sample of `train_data`. All values fall between 1.0f-12 and 1.0.

```
julia> train_data.data[1]
22815×1526 Matrix{Float32}

julia> train_data.data[1][:,1]
22815-element Vector{Float32}:
 1.0f-12
 0.0017195814
 0.0017195814
 0.023636289
 0.028990168
 0.044405
 0.043595545
 0.06990052
 0.055057727
 0.05537212
 0.09786336
 0.10126311
 0.10529525
 0.09850332
 0.120765366
 0.23463665
 0.19305018
 0.19305018
 ⋮
 0.9924742
 0.9953276
 1.0
 0.87539566
 0.8523039
 0.8453158
 0.83677983
 0.82271576
 0.84766895
 0.8416806
 0.83194155
 0.81610835
 0.7944943
 0.84520817
 0.812333
 0.79442024
```


Please have a read through “PSA: make it easier to help you”. A single sample of the data doesn’t help much towards an MWE that others can run locally and debug. Also, have you tried any of the suggestions mentioned upthread?