Flux.jl confusion matrix

I’m using Flux to implement a handwritten digit classifier based on the MNIST dataset (e.g. https://github.com/FluxML/model-zoo/blob/master/vision/mnist/mlp.jl).

Is there any way to compute a confusion matrix to evaluate the model’s performance? I need precision, recall and F1 score.
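Here is a minimal sketch of the metrics themselves in plain Julia, with no packages. Given a k×k confusion matrix C whose rows are true classes and whose columns are predicted classes (an assumed convention; some tools use the transpose), precision for class i is C[i,i] divided by the sum of column i. Recall is C[i,i] divided by the sum of row i, and F1 is the harmonic mean of the two. The function name prf1 and the 3-class matrix are made up for illustration.

```julia
# Per-class precision, recall and F1 from a k×k confusion matrix.
# Assumed convention: C[i, j] = number of samples of true class i predicted as class j.
function prf1(C::AbstractMatrix{<:Integer})
    tp = [C[i, i] for i in axes(C, 1)]   # correct predictions per class
    predicted = vec(sum(C, dims = 1))    # column sums: how often each class was predicted
    actual = vec(sum(C, dims = 2))       # row sums: how often each class really occurs
    # max.(x, 1) avoids 0/0 for classes never predicted or never present (tp is 0 there)
    prec = tp ./ max.(predicted, 1)
    rec = tp ./ max.(actual, 1)
    f1 = [p + r == 0 ? 0.0 : 2 * p * r / (p + r) for (p, r) in zip(prec, rec)]
    return prec, rec, f1
end

# Hypothetical 3-class example (rows = true class, columns = predicted class)
C = [5 1 0;
     2 3 1;
     0 0 4]
prec, rec, f1 = prf1(C)
macro_f1 = sum(f1) / length(f1)   # unweighted (macro) average over classes
```

Macro-averaging treats every class equally. Weighting each class's F1 by its row sum would give a support-weighted average instead.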


As a quick hack, you can find a simple implementation of a confusion matrix here: https://github.com/OxoaResearch/j4pr.jl/blob/master/src/lib/libutils.jl (it may need some tweaking for Julia > 0.6…)

Thanks! In the end, I implemented it myself.
I also discovered that there’s a confusmat function in MLBase, but I’m not sure if or how it works with Flux.

It should work as well. A basic confusion matrix needs just two vectors (references and predictions). Btw, I believe MLBase is informally deprecated in favour of LearnBase.
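The following is a minimal sketch of building the matrix from those two vectors in plain Julia, with no MLBase dependency. The function name confusion_matrix and the sample vectors are hypothetical. The label order is passed explicitly so that every digit gets a row and a column even if it never appears in the data.

```julia
# Confusion matrix from two label vectors: rows = true label, columns = predicted label.
# `labels` fixes the row/column order; a label not in `labels` raises a KeyError.
function confusion_matrix(refs::AbstractVector, preds::AbstractVector, labels::AbstractVector)
    length(refs) == length(preds) ||
        throw(DimensionMismatch("refs and preds must have the same length"))
    index = Dict(l => i for (i, l) in enumerate(labels))   # label -> row/column index
    C = zeros(Int, length(labels), length(labels))
    for (r, p) in zip(refs, preds)
        C[index[r], index[p]] += 1   # one count per (true, predicted) pair
    end
    return C
end

# Hypothetical references and predictions for MNIST-style digit labels
refs  = [0, 1, 2, 2, 1, 0, 9]
preds = [0, 2, 2, 2, 1, 0, 9]
C = confusion_matrix(refs, preds, 0:9)
```

Here C[2, 3] counts the single "1" that was predicted as "2". Note that row and column indices are shifted by one relative to the digit labels because Julia arrays are 1-based.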

Oh, I see; that makes sense, as MLBase hasn’t been updated in the last few months.

I’m not sure about using it with Flux, though. How would it work, for example, with this model? https://github.com/FluxML/model-zoo/blob/master/vision/mnist/conv.jl

The script you pointed to only covers training. Applying the trained model (i.e. inference) should output a vector of 10 posterior probabilities for each sample (10 classes, 0 to 9). From there, the predicted label is the class with the maximum posterior probability, and the predicted labels can then be fed into the confusion matrix together with the true labels…
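The "maximum posterior probability" step can be sketched in plain Julia as shown below. The sketch assumes the model output is a matrix with one column per sample and one row per class, in the order given by labels. The helper name predicted_labels and the 3-class scores are hypothetical. In Flux, onecold(ŷ, 0:9) does the same job.

```julia
# Predicted label for each sample = label of the row with the highest score in that column.
# Assumes P is (number of classes) × (number of samples).
predicted_labels(P::AbstractMatrix, labels::AbstractVector) =
    [labels[argmax(view(P, :, j))] for j in axes(P, 2)]

# Hypothetical posteriors for 3 samples over 3 classes (labels 0, 1, 2); each column sums to 1
P = [0.7 0.1 0.6;
     0.2 0.8 0.3;
     0.1 0.1 0.1]
ŷ = predicted_labels(P, 0:2)
```

For MNIST, P would be the 10×N model output and labels would be 0:9. The resulting vector, paired with the true labels, is what a confusion matrix needs.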

Thanks. It seems that some testing is performed during training, right? Instead of accuracy, it would be useful to be able to plug in a full confusion matrix computation:

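```julia
using Flux, Flux.Data.MNIST, Statistics  # Flux.Data.MNIST exists only in older Flux releases
using Flux: onehotbatch, onecold, throttle

# Prepare test set (first 1,000 images)
tX = cat(float.(MNIST.images(:test)[1:1000])..., dims = 4) |> gpu
tY = onehotbatch(MNIST.labels(:test)[1:1000], 0:9) |> gpu

# `m` is the model defined earlier in conv.jl
accuracy(x, y) = mean(onecold(m(x)) .== onecold(y))
# Print test accuracy at most once every 10 seconds during training
evalcb = throttle(() -> @show(accuracy(tX, tY)), 10)
```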
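One way to plug a full confusion matrix into that evaluation step is sketched below in plain Julia. Instead of counting matches as accuracy does, it tallies each (true class, predicted class) pair. The function name confusion_eval and the toy data are hypothetical. Here model is any callable that returns a (classes × samples) score matrix, and Y is a one-hot target matrix of the same shape, like the one onehotbatch produces.

```julia
# Confusion matrix for a batch: rows = true class index, columns = predicted class index.
# Assumes model(X) and Y are both (classes × samples) and Y is one-hot.
function confusion_eval(model, X, Y)
    Ŷ = model(X)
    k = size(Y, 1)
    C = zeros(Int, k, k)
    for j in axes(Y, 2)
        t = argmax(view(Y, :, j))   # true class index (position of the 1 in the one-hot column)
        p = argmax(view(Ŷ, :, j))   # predicted class index (highest score)
        C[t, p] += 1
    end
    return C
end

# Toy check with an identity "model" whose scores are just the input (hypothetical data)
X = [0.9  0.2 0.1 0.3;
     0.05 0.7 0.6 0.6;
     0.05 0.1 0.3 0.1]
Y = Bool[1 0 0 1;
         0 1 0 0;
         0 0 1 0]
C = confusion_eval(identity, X, Y)
acc = sum(C[i, i] for i in axes(C, 1)) / sum(C)   # accuracy = trace / total
```

In the training script, this might become something like evalcb = throttle(() -> display(confusion_eval(x -> cpu(m(x)), tX, cpu(tY))), 10). Moving the arrays to the CPU first avoids element-wise indexing of GPU arrays. This is an untested suggestion, not part of the model-zoo code.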

Couldn’t say, I’m really not familiar with Flux. Neural models backpropagate errors at the sample level, not the class level. In practice, though, they improve the confusion matrix indirectly by learning to predict the correct classes :)

Sorry for reviving this thread, but is MLBase still maintained? I noticed some discussion on its GitHub page about merging it into StatsBase.
I was also searching for a confusion matrix implementation.

I guess so, as it is part of JuliaStats, yet it receives very little attention. I would not really rely on it.

How did you implement it?

I can’t really remember; it was for a paper during my MSc. I think I did it by hand (wrote the code from scratch).

Alright, thanks