How to use crossentropy in a multi-class classification Flux.jl neural network?

I am having trouble defining the crossentropy loss using Flux.jl.

using Flux,StatsBase
model = Flux.Chain(
  Dense(13*16, 128, relu),
  Dense(128, 64, relu),
  Dense(64, 32, relu),
  Dense(32, 4, relu))

loss(x,y) = crossentropy(model(x),y)
opt = ADAM(params(model))

I have set up the model to predict one of four classes, but I can't get the loss function to work. What form does y need to take?

For example, the y for a single record is coded as [0.0, 1.0, 0.0, 0.0], but running crossentropy on it gives Inf (Tracked).
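An infinite loss on a one-hot target usually means the zeros in y ended up inside a logarithm. The plain-Julia sketch below reproduces that. It assumes two argument conventions: predictions first with the log taken of the predictions (the order Flux's `crossentropy(ŷ, y)` uses), and reference distribution first with the log taken of the second argument. The helpers `ce_pred_first` and `ce_ref_first` are hypothetical stand-ins, not code from either library.

```julia
# Hypothetical stand-ins for two cross-entropy argument conventions (not library code).

# Predictions first: -Σ yᵢ log(ŷᵢ). Entries where y is zero contribute nothing.
ce_pred_first(yhat, y) = -sum(b == 0 ? 0.0 : b * log(a) for (a, b) in zip(yhat, y))

# Reference distribution first: -Σ pᵢ log(qᵢ). The second argument is the one logged.
ce_ref_first(p, q) = -sum(a == 0 ? 0.0 : a * log(b) for (a, b) in zip(p, q))

yhat = [0.1, 0.7, 0.1, 0.1]  # model output for one record (assumed to be a probability vector)
y    = [0.0, 1.0, 0.0, 0.0]  # one-hot target from the question

good = ce_pred_first(yhat, y)  # -log(0.7), a finite loss
bad  = ce_ref_first(yhat, y)   # takes log(0.0) == -Inf for the zeros in y, so the loss is Inf
```

Swapping the arguments of `ce_ref_first` would give the same finite value as `ce_pred_first`, which is why it matters which `crossentropy` actually gets called.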

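Flux expects y to be a matrix with one one-hot column per sample. crossentropy(model(x), onehotbatch(y, 0:1)) should work, assuming that y is an Int vector of labels made up of 0s and 1s; for the four classes here, pass the full label set instead (for example 1:4).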

onehotbatch is a function defined in Flux.
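To make the expected shape concrete, here is a plain-Julia sketch of the layout that `onehotbatch(labels, labelset)` produces: one row per entry of the label set and one column per sample, so four classes and three records give a 4×3 matrix. The helper `onehot_matrix` is hypothetical and returns an ordinary `Float64` matrix, whereas Flux returns its own one-hot matrix type; the label set 1:4 is an assumption for the four classes in the question.

```julia
# Hypothetical helper mirroring the layout of Flux's onehotbatch:
# rows follow `labelset`, columns follow the samples in `labels`.
function onehot_matrix(labels::AbstractVector, labelset)
    classes = collect(labelset)
    Y = zeros(Float64, length(classes), length(labels))
    for (j, label) in enumerate(labels)
        i = findfirst(==(label), classes)
        i === nothing && throw(ArgumentError("label $label is not in the label set"))
        Y[i, j] = 1.0  # mark the class of sample j
    end
    return Y
end

labels = [2, 4, 1]              # integer class labels for three records (assumed to be in 1:4)
Y = onehot_matrix(labels, 1:4)  # 4×3 matrix; column j is the target for record j
```

Each column of `Y` has the same form as the [0.0, 1.0, 0.0, 0.0] vector in the question; stacking them side by side is what turns a batch of targets into the matrix the loss compares against `model(x)`.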

The model zoo contains many example models that use crossentropy; you can use them to see how common models and patterns are implemented in Flux.

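Both Flux.jl and StatsBase export crossentropy, so the unqualified call was apparently not reaching Flux's version. StatsBase's crossentropy(p, q) takes the reference distribution first and the logarithm of its second argument, so passing the one-hot y second leads to log(0) and an infinite loss. Qualifying the call as loss(x,y) = Flux.crossentropy(model(x),y) solves the problem.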
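The name clash can be reproduced without either package. In this sketch two hypothetical modules, `LibA` and `LibB`, both export a function called `score` with opposite argument roles. After `using` both, the bare name `score` cannot be trusted to mean either one (depending on session state it is either reported as ambiguous or stays bound to whichever binding was resolved first), while the qualified names always call the intended function.

```julia
# Two hypothetical modules exporting the same name, in the way Flux and
# StatsBase both export `crossentropy`.
module LibA
export score
score(x, y) = x - y  # convention A: first argument minus second
end

module LibB
export score
score(x, y) = y - x  # convention B: the arguments play opposite roles
end

using .LibA, .LibB

# Qualified calls are unambiguous, just like `Flux.crossentropy(model(x), y)`.
a = LibA.score(3, 1)
b = LibB.score(3, 1)
```

Here `a` is 2 and `b` is -2: the same inputs give different answers depending on which module's function runs, which is exactly the kind of silent difference that qualifying the name rules out.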