How to use crossentropy in a multi-class classification Flux.jl neural network?

first-steps
flux

#1

I am having trouble defining the crossentropy loss using Flux.jl.

using Flux, StatsBase
model = Flux.Chain(
  Dense(13*16, 128, relu),
  Dense(128, 64, relu),
  Dense(64, 32, relu),
  Dense(32, 4, relu),
  softmax);

loss(x,y) = crossentropy(model(x),y)
opt = ADAM(params(model))

I have set up the model to try to predict a 4-label classification problem, but I can’t seem to get the loss function to work. What form does y have to take?

For example, my y for a single record can be encoded as [0.0, 1.0, 0.0, 0.0], but running crossentropy on it gives Inf (Tracked).


#2

Flux expects y to be a matrix. crossentropy(model(x), onehotbatch(y, 0:3)) should work, assuming that y is a vector of integer labels 0–3.

onehotbatch is a function defined in Flux
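For example, a minimal sketch (assuming integer labels 0–3 to match the 4-class problem above):

```julia
using Flux: onehotbatch, onecold

# Integer class labels for three samples, drawn from the four classes 0:3
y = [1, 0, 3]

# One row per class, one column per sample: a 4×3 one-hot matrix
Y = onehotbatch(y, 0:3)

# onecold inverts the encoding, recovering the original labels
labels = onecold(Y, 0:3)  # → [1, 0, 3]
```

Each column of Y is then in the right shape to compare against a column of softmax output.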


#3

The model-zoo contains many example models that use crossentropy. You can use them to see how common models and patterns are implemented in Flux.


#4

Both Flux.jl and StatsBase export crossentropy, so the unqualified name is ambiguous. Qualifying the call, loss(x,y) = Flux.crossentropy(model(x), y), solves the problem.
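To illustrate (a sketch with made-up data, not the original model; assumes 4 classes and 8 samples):

```julia
using Flux, StatsBase  # both export `crossentropy`

# Fake "model output": a 4×8 matrix of class probabilities (columns sum to 1)
ŷ = softmax(randn(4, 8))

# Matching one-hot targets, one column per sample
Y = Flux.onehotbatch(rand(0:3, 8), 0:3)

# The bare name `crossentropy` clashes between the two packages;
# qualifying it picks Flux's definition unambiguously
loss = Flux.crossentropy(ŷ, Y)
```

Alternatively, importing only what you need (e.g. `using Flux: crossentropy`) avoids the clash without qualifying every call.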