Flux.jl confusion matrix

I’m using Flux to implement a handwritten digit classifier based on the MNIST dataset (e.g. https://github.com/FluxML/model-zoo/blob/master/vision/mnist/mlp.jl).

Is there any way to compute the confusion matrix in order to evaluate the performance of the model (I need precision, recall and the F1 score)?

Thanks!

As a quick hack, you can find a simple implementation of a confusion matrix in j4pr.jl/libutils.jl at master · OxoaResearch/j4pr.jl · GitHub (it may need a bit of tweaking for Julia >0.6 …)

Thanks! In the end, I implemented it myself.
I also discovered that there’s a confusmat function in MLBase but I’m not sure if/how it works with Flux.

Should work as well. A basic confusion matrix needs just two vectors (references and predictions). Btw, I believe MLBase is informally deprecated in favour of LearnBase.

Oh, I see - that makes sense, as MLBase hasn’t been updated in the last few months.

I’m not sure about using it with Flux - how would it work for example in this model: https://github.com/FluxML/model-zoo/blob/master/vision/mnist/conv.jl

The model you pointed to is just the training code. Applying it, i.e. running inference, should output a vector of 10 posterior probabilities for each sample (10 classes, digits 0 to 9). From there, the predicted label can be extracted (the one with the maximum posterior probability); the predicted labels can then be fed into the confusion matrix …
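In code, the “max posterior probability” step is just an argmax over each sample’s column of outputs; Flux’s onecold does this for you, but a plain-Julia sketch (with a made-up toy matrix, illustrative values only) looks like:

```julia
# Toy posterior matrix: 3 classes × 2 samples (values are illustrative only).
posteriors = [0.1 0.7;
              0.8 0.2;
              0.1 0.1]

# Predicted label per sample = index of the maximum posterior in its column.
# (Flux's onecold performs the same extraction.)
predicted = [argmax(posteriors[:, j]) for j in 1:size(posteriors, 2)]
```

Note that `predicted` comes out as class indices (here 1:3); for MNIST digits you would map index i back to digit i - 1.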

Thanks. Some testing seems to be performed throughout the training process, doesn’t it? Instead of accuracy, it would be useful to be able to plug in a full confusion-matrix computation:

# Prepare test set (first 1,000 images)
tX = cat(float.(MNIST.images(:test)[1:1000])..., dims = 4) |> gpu
tY = onehotbatch(MNIST.labels(:test)[1:1000], 0:9) |> gpu

...

accuracy(x, y) = mean(onecold(m(x)) .== onecold(y))
evalcb = throttle(() -> @show(accuracy(tX, tY)), 10)
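One way that could work (a hedged sketch, not tested against this exact model) is to mirror the accuracy callback: use onecold to turn the model outputs and the one-hot targets into label vectors, then count the pairs into a matrix. The `confusion` helper below is hypothetical; only the commented-out line depends on the `m`, `tX`, `tY` from the snippet above.

```julia
# Hypothetical helper: count (truth, prediction) pairs into an n×n matrix,
# with rows indexed by the true class and columns by the predicted class.
function confusion(truths, preds, n = 10)
    C = zeros(Int, n, n)
    for (t, p) in zip(truths, preds)
        C[t, p] += 1
    end
    return C
end

# Wiring it into the training loop, analogous to the accuracy callback
# (depends on m, tX, tY and Flux's onecold/throttle from the snippet above):
# evalcb = throttle(() -> @show(confusion(onecold(tY), onecold(m(tX)))), 10)
```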

Couldn’t say, I’m really not familiar with Flux. Neural models backpropagate errors at the sample level, not the class level. In practice, they do optimize the confusion matrix by trying to learn the correct classes :)

Sorry for reviving this, but is MLBase still maintained? I noticed there were some discussions on its GitHub page about merging it into StatsBase.
I was also searching for a confusion matrix implementation.


I guess so, as it is part of JuliaStats, yet it receives very little attention. I would not really rely on it.

How did you implement it?

Can’t really remember, it was for a paper during my MSc - I think I did it by hand (wrote the code from scratch).

Alright, thanks

See Performance Measures · MLJ (alan-turing-institute.github.io)

julia> using CategoricalArrays, MLJ

julia> yhat = rand(1:10, 100)|>CategoricalArray;

julia> y = rand(1:10, 100)|>CategoricalArray;

julia> ConfusionMatrix()(yhat, y)
┌ Warning: The classes are un-ordered,
│ using order: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10].
│ To suppress this warning, consider coercing to OrderedFactor.
└ @ MLJBase C:\Users\usrname\.julia\packages\MLJBase\rQDaq\src\measures\confusion_matrix.jl:122
10×10 Matrix{Int64}:
 1  3  1  3  1  0  1  0  1  1
 1  0  2  0  1  3  0  1  0  0
 1  2  1  0  3  1  0  1  1  1
 0  1  0  1  0  1  2  1  1  4
 1  1  1  1  3  1  2  0  3  3
 1  0  2  1  2  0  1  0  0  1
 0  1  0  1  1  2  1  1  0  1
 1  1  1  1  1  1  2  0  0  1
 2  0  0  2  1  1  1  1  0  0
 3  3  1  0  0  0  0  0  2  0
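Since the original question asked for precision, recall and F1, note that all three fall straight out of any such matrix. A from-scratch sketch, assuming rows index the true class and columns the predicted class (check the convention of whichever implementation produced your matrix):

```julia
# Per-class precision, recall and F1 from a confusion matrix C,
# assuming C[true_class, predicted_class] counts.
function prf1(C::AbstractMatrix{<:Integer})
    tp        = [C[i, i] for i in axes(C, 1)]            # true positives per class
    precision = tp ./ max.(vec(sum(C, dims = 1)), 1)     # tp / predicted-positive (guards /0)
    recall    = tp ./ max.(vec(sum(C, dims = 2)), 1)     # tp / actual-positive    (guards /0)
    f1 = [p + r == 0 ? 0.0 : 2p * r / (p + r) for (p, r) in zip(precision, recall)]
    return (precision = precision, recall = recall, f1 = f1)
end
```

MLJ also ships precision, recall and F1 as built-in measures, so the hand-rolled version is mainly useful when you already hold a raw matrix from elsewhere.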