Flux.jl confusion matrix

flux

#1

I’m using Flux to implement a handwritten digit classifier based on the MNIST dataset (example: https://github.com/FluxML/model-zoo/blob/master/vision/mnist/mlp.jl).

Is there a way to compute the confusion matrix in order to evaluate the performance of the model? I need precision, recall and F1 score.

Thanks!


#2

As a quick hack, you can find a simple implementation of a confusion matrix here: https://github.com/OxoaResearch/j4pr.jl/blob/master/src/lib/libutils.jl (may need a bit of tweaking for Julia >0.6 …)


#3

Thanks! In the end, I implemented it myself.
I also discovered that there’s a confusmat function in MLBase but I’m not sure if/how it works with Flux.


#4

Should work as well. The basic confusion matrix needs just two vectors (references and predictions). Btw, I believe MLBase is informally deprecated in favour of LearnBase.
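To illustrate the "two vectors" point, here is a minimal sketch in plain Julia with no package dependencies. The function name `confusion_matrix` and the `classes` argument are made up for this example and are not part of Flux, MLBase or LearnBase. It assumes rows index the reference (true) class and columns index the predicted class.

```julia
# Hypothetical helper (not from any package):
# cm[i, j] counts samples whose reference class is classes[i]
# and whose predicted class is classes[j].
function confusion_matrix(truth::AbstractVector{<:Integer},
                          pred::AbstractVector{<:Integer},
                          classes=0:9)
    length(truth) == length(pred) ||
        throw(ArgumentError("truth and pred must have the same length"))
    # Map each class label to its row/column position.
    index = Dict(c => i for (i, c) in enumerate(classes))
    cm = zeros(Int, length(classes), length(classes))
    for (t, p) in zip(truth, pred)
        cm[index[t], index[p]] += 1
    end
    return cm
end

# Toy data with three classes (0, 1, 2), purely illustrative.
truth = [0, 1, 2, 2, 1, 0]
pred  = [0, 2, 2, 2, 1, 1]
cm = confusion_matrix(truth, pred, 0:2)
```

For MNIST the default `classes=0:9` would apply; the diagonal then holds the correctly classified counts per digit.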


#5

Oh, I see. That makes sense, as MLBase hasn’t been updated in the last few months.

I’m not sure about using it with Flux, though. How would it work with, for example, this model? https://github.com/FluxML/model-zoo/blob/master/vision/mnist/conv.jl


#6

The code you pointed to only covers training. Applying the model (i.e. inference) should output a vector of 10 posterior probabilities for each sample (10 classes, 0 to 9). From there, the predicted label is the class with the maximum posterior probability, and the predicted labels can then be fed into the confusion matrix together with the reference labels.
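As a rough sketch of that decoding step, assuming the model output is arranged as a matrix with one row per class and one column per sample (the helper name `predicted_labels` is hypothetical and the scores below are invented for illustration):

```julia
# Hypothetical helper: pick, for each sample (column), the class whose
# score is highest. `classes` maps row positions to labels, e.g. 0:9.
function predicted_labels(scores::AbstractMatrix, classes=0:9)
    size(scores, 1) == length(classes) ||
        throw(ArgumentError("expected one row of scores per class"))
    return [classes[argmax(col)] for col in eachcol(scores)]
end

# Three classes (rows) and three samples (columns), made-up scores.
scores = [0.7 0.1 0.2;
          0.2 0.3 0.5;
          0.1 0.6 0.3]
labels = predicted_labels(scores, 0:2)
```

In Flux, `onecold` plays this role; passing the label set as a second argument, as in `onecold(m(x), 0:9)`, should give digit labels rather than 1-based positions.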


#7

Thanks. It seems that some testing is performed throughout the training process, doesn’t it? Instead of accuracy, it would be useful to be able to plug in a full confusion matrix computation:

# Prepare test set (first 1,000 images)
tX = cat(float.(MNIST.images(:test)[1:1000])..., dims = 4) |> gpu
tY = onehotbatch(MNIST.labels(:test)[1:1000], 0:9) |> gpu

...

accuracy(x, y) = mean(onecold(m(x)) .== onecold(y))
evalcb = throttle(() -> @show(accuracy(tX, tY)), 10)
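
For the metrics themselves, here is a minimal sketch of how per-class precision, recall and F1 could be derived once a confusion matrix is available. It assumes rows are true classes and columns are predicted classes, and it uses the convention that an undefined ratio (a class that was never predicted or never occurred) counts as 0. The name `class_metrics` is hypothetical, and the matrix is a toy example, not MNIST output.

```julia
using LinearAlgebra: diag

# Hypothetical helper: cm[i, j] counts samples of true class i
# that were predicted as class j.
function class_metrics(cm::AbstractMatrix{<:Integer})
    tp = diag(cm)                      # correct predictions per class
    predicted = vec(sum(cm, dims=1))   # column sums: how often each class was predicted
    actual = vec(sum(cm, dims=2))      # row sums: how often each class occurred
    safe_div(a, b) = b == 0 ? 0.0 : a / b   # assumed convention: 0 when undefined
    prec = safe_div.(tp, predicted)
    rec = safe_div.(tp, actual)
    f1 = safe_div.(2 .* prec .* rec, prec .+ rec)
    return (precision=prec, recall=rec, f1=f1)
end

# Toy 3-class confusion matrix, purely illustrative.
cm = [1 1 0;
      0 1 1;
      0 0 2]
metrics = class_metrics(cm)
```

A callback along the lines of `evalcb` could then print these values in place of accuracy, building the matrix from `onecold(m(tX))` and `onecold(tY)` (moved to the CPU first if the model runs on the GPU).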

#8

Couldn’t say, I’m really not familiar with Flux. Neural models backpropagate errors at the sample level, not the class level. In practice, they do optimize the confusion matrix indirectly, by trying to learn the correct classes :)