How can I get the loss functions in a multi-class classification Flux.jl neural network to work with Flux.onehotbatch?

I am trying to set up a model to predict an 8-label classification problem. Suppose the labels are ycat = ["a", "b", "c", "d", "e", "f", "g", "h"] and I have a sample set of labels to train and test on: ysample = ["a", "a", "g", "d", ...].

I was told I should build a one-hot encoding of the labels if I want to build a classification model.

ycat_len = length(ycat)                    # number of classes (8)
output = Flux.onehotbatch(ysample, ycat)   # one column per sample, one row per class
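
For reference, Flux.onecold inverts this encoding, which makes for a handy sanity check:

Flux.onecold(output, ycat)   # maps each one-hot column back to its label; reproduces ysample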

I am aware that this question has a near duplicate in "How to use crossentropy in a multi-class classification Flux.jl neural network?". However, the output in that question seems to be a 1-D array, while in my case it is a 2-D matrix:

> output
8×5000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:
 0  0  0  0  0  0  0  0  0  0  0  0  0  …  0  0  0  0  0  0  0  0  0  0  0  0
 1  0  0  0  0  0  1  0  0  0  1  0  0     1  0  0  0  0  0  0  0  0  1  1  0
 0  0  0  0  1  0  0  0  0  0  0  0  0     0  0  0  0  0  1  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  1  0  0  0  0  0  0  0  0  0  0
 0  1  0  0  0  0  0  0  0  1  0  1  0     0  0  0  1  0  0  1  1  0  0  0  1
 0  0  1  1  0  0  0  1  0  0  0  0  1  …  0  0  1  0  0  0  0  0  1  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  1  0  0  1  0  0  0  0     0  0  0  0  1  0  0  0  0  0  0  0

My input, meanwhile, has the form

> input
1×5000 Array{FeaturedGraph{SimpleWeightedGraph{Int64,Float32},Array{Float32,2}},2}:

Using the following setup has resulted in catastrophically low accuracy rates, much worse than guessing labels at random.

using Flux
using Flux: @epochs
using Statistics: mean   # mean is used in the evaluation callback below

train_input  = input[1:num_train];       # graph inputs
train_output = output[:, 1:num_train];   # matching one-hot columns
train_data   = zip(train_input, train_output);

model = Chain(CNConv(num_features=>atom_fea_len),            # graph convolution layers
              [CNConv(atom_fea_len=>atom_fea_len) for i in 1:num_conv-1]...,
              CNMeanPool(crys_fea_len, 0.1),                 # pool node features to a fixed-length vector
              [Dense(crys_fea_len, crys_fea_len, softplus) for i in 1:num_hidden_layers-1]...,
              Dense(crys_fea_len, ycat_len), softmax);       # per-class probabilities

loss_kl(x, y) = Flux.kldivergence(model(x), y)              # KL divergence between prediction and target
evalcb() = @show(mean(loss_kl.(test_input, test_output)))   # periodic evaluation on the test split

opt = ADAM(0.01)
@epochs num_epochs Flux.train!(loss_kl, params(model), train_data, opt, cb=Flux.throttle(evalcb, 5))
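
For completeness, accuracy is measured along these lines (a sketch, not the exact code; Flux.onecold maps a probability vector or one-hot column back to its label):

predicted = [Flux.onecold(model(x), ycat) for x in test_input]
actual    = Flux.onecold(test_output, ycat)
accuracy  = mean(predicted .== actual)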

A quick dig into the causes of this bad performance reveals that the loss function outputs a single value rather than a 1×5000 matrix. In other words, rather than getting a loss value for every element of the training set, we get a single condensed value with no practical meaning. How can I fix this?

Are you looking at the value logged by evalcb? That should be a single number because mean aggregates across all dimensions by default…
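
A quick demonstration of that default, with toy numbers:

using Statistics

losses = [0.2 0.4; 0.6 0.8]
mean(losses)          # 0.5: collapses every dimension into one number
mean(losses; dims=2)  # 2×1 matrix of per-row means, one value per row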

As for the performance evaluation, it's difficult to tell what's going on without example input data and your custom layer definitions. Have you tried using Flux's (logit) cross-entropy instead of plain KL divergence? Depending on how the loss curve looks, reducing the learning rate may help as well.
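
Roughly what I mean, as a minimal sketch (model_logits, x, and y below are placeholders, not your actual definitions): drop the trailing softmax from the Chain, since logitcrossentropy applies a numerically stable softmax internally.

using Flux

model_logits = Chain(Dense(16, 8))   # placeholder for your Chain without the final softmax
loss_ce(x, y) = Flux.logitcrossentropy(model_logits(x), y)

x = rand(Float32, 16)                                           # placeholder input
y = Flux.onehot("a", ["a", "b", "c", "d", "e", "f", "g", "h"])  # one one-hot target
loss_ce(x, y)                                                   # scalar loss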

Sorry for not clarifying. When I checked, it is the loss_kl.(test_input, test_output) part that outputs a single value rather than a matrix/array; mean() simply takes that single value and returns it unchanged.

I would very much appreciate a way to make loss_kl broadcast column-wise, so that each sample gets its own loss, rather than element-wise. Thanks for your attention!
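
If test_input is a Vector of graphs and test_output the matching one-hot matrix, a minimal sketch would be to pair each graph with a column explicitly via eachcol, instead of broadcasting over the matrix (which pairs individual matrix elements with inputs):

per_sample_loss = [loss_kl(x, y) for (x, y) in zip(test_input, eachcol(test_output))]
mean(per_sample_loss)   # aggregate only after computing one loss per sample

# The same pairing applies to the training data, so train! sees (graph, column) pairs:
train_data = zip(train_input, eachcol(train_output))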