I’ve just had a nice needless romp through the finer details of ML implementations in Lux.jl because of my misunderstanding of ROCAnalysis.jl… Turns out I should read about things before using them.
There are some important points about this implementation of the ROC curve:
- The ROC curve is computed over the false negative (miss) rate and the false positive (false alarm) rate, instead of the usual true positive rate and false positive rate. So, your curve will look upside-down compared to the textbook ROC, and a lower AUC is better.
- It takes target and non-target score arrays as input. So, if you had y_hat as a vector of logits (or probabilities) and y_true as the 1/0 labels, then you’d use:
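using ROCAnalysis
target = y_hat[y_true .== 1]      # scores of the positive (label 1) samples
nontarget = y_hat[y_true .== 0]   # scores of the negative (label 0) samples
roc(target, nontarget)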
If you reverse the arguments, you’ll get a ROC curve that looks normal, but mirrored. Probably not ideal.
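To make the "lower is better" point concrete, here is a minimal sketch in plain Julia (no packages) of how the two conventions relate. It assumes, as described above, that the package's area is the complement of the textbook AUC; `pairwise_auc` is a hypothetical helper of mine, not part of ROCAnalysis.jl, and the scores are made-up toy values.

```julia
# Hypothetical helper, not a ROCAnalysis.jl function: textbook AUC as the
# fraction of (target, nontarget) pairs where the target scores higher.
# Ties count as half a win.
function pairwise_auc(target, nontarget)
    wins = 0.0
    for t in target, n in nontarget
        wins += t > n ? 1.0 : (t == n ? 0.5 : 0.0)
    end
    return wins / (length(target) * length(nontarget))
end

# Toy data, made up for illustration.
y_hat  = [0.9, 0.8, 0.35, 0.6, 0.2, 0.1]
y_true = [1, 1, 1, 0, 0, 0]

target    = y_hat[y_true .== 1]
nontarget = y_hat[y_true .== 0]

conventional = pairwise_auc(target, nontarget)   # higher is better
complement   = 1 - conventional                  # "lower is better" area (assumed convention)
swapped      = pairwise_auc(nontarget, target)   # what swapping the arguments does to the number

println("conventional AUC: ", conventional)
println("complement:       ", complement)
println("swapped args:     ", swapped)
```

Under this pairwise definition, swapping the two arrays always gives exactly the complement, which is why a swapped call can look plausible while describing the opposite classifier.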
Anyways, after I figured all that out, I gave up on more packages and just used this:
function rocauc(y_true, y_pred)
    # Sort by descending score so each step lowers the decision threshold.
    sorted_indices = sortperm(y_pred, rev=true)
    sorted_true = y_true[sorted_indices]
    total_positive = count(isone, y_true)
    total_negative = length(y_true) - total_positive
    # Start the curve at the origin (threshold above every score).
    tpr, fpr = [0.0], [0.0]
    tp, fp = 0, 0
    # Note: tied scores are stepped through in sort order, not grouped.
    for label in sorted_true
        isone(label) ? (tp += 1) : (fp += 1)
        push!(tpr, tp / total_positive)
        push!(fpr, fp / total_negative)
    end
    # Trapezoidal rule over consecutive (fpr, tpr) points.
    auc = sum((fpr[i] - fpr[i-1]) * (tpr[i] + tpr[i-1]) / 2 for i in 2:length(fpr))
    return auc, fpr, tpr
end
This gives you the normal AUC (higher is better), with y_true as the 1/0 labels described above and y_pred as the associated probabilities. If anyone wants to use it, you can print the AUC and plot the ROC curve with:
using Plots
auc, fpr, tpr = rocauc(y_true, y_pred)
println("AUC: ", auc)
plot(fpr, tpr,
    label="ROC w/ AUC = $(round(auc, digits=4))",
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title="ROC Curve"
)
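One caveat with a step-per-sample curve like the one above: when a positive and a negative share the same score, the result depends on the order the sort happens to put them in. Below is a rough sketch of one way to handle that, emitting a ROC point only after a whole group of equal scores has been consumed. `rocauc_ties` is my own hypothetical name and the data is a toy example; treat it as a sketch under those assumptions rather than a vetted implementation.

```julia
# Hypothetical tie-aware variant: one ROC point per distinct score value.
function rocauc_ties(y_true, y_pred)
    order = sortperm(y_pred, rev=true)
    labels, scores = y_true[order], y_pred[order]
    P = count(isone, labels)          # number of positives
    N = length(labels) - P            # number of negatives
    tpr, fpr = [0.0], [0.0]           # start at the origin
    tp, fp = 0, 0
    for i in eachindex(labels)
        isone(labels[i]) ? (tp += 1) : (fp += 1)
        # Only emit a point once every sample with this score is counted.
        if i == lastindex(labels) || scores[i+1] != scores[i]
            push!(tpr, tp / P)
            push!(fpr, fp / N)
        end
    end
    # Trapezoidal rule; a tied group becomes a diagonal segment.
    auc = sum((fpr[i] - fpr[i-1]) * (tpr[i] + tpr[i-1]) / 2 for i in 2:length(fpr))
    return auc, fpr, tpr
end

# Toy example with one positive/negative tie at 0.5.
auc, fpr, tpr = rocauc_ties([1, 0, 1, 0], [0.9, 0.5, 0.5, 0.1])
println("AUC: ", auc)
```

With the tie drawn as a diagonal, the area matches the "ties count as half" convention of the pairwise (Mann-Whitney) view of AUC.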