Never mind, I figured it out.
"""
    roc_auc(ŷ, y, weights; sort_perm = parallel_sort_perm(ŷ))

Compute the weighted area under the ROC curve (true-positive rate vs.
false-positive rate) for predictions `ŷ` against labels `y`.

# Arguments
- `ŷ`: predicted scores. Only used to build the default `sort_perm`.
- `y`: labels; an element counts as positive when it is `> 0.5f0`, negative
  otherwise.
- `weights`: per-example weights, accumulated in `Float64`.

# Keywords
- `sort_perm`: permutation applied to `y` and `weights` before the sweep.
  Defaults to `parallel_sort_perm(ŷ)`; pass a precomputed permutation to avoid
  re-sorting.

# Returns
The sum of the per-chunk partial areas as a `Float64`. If the total positive or
the total negative weight is zero, the rates are `0/0` and the result is `NaN`.

# Notes
Each example is stepped over individually, so examples with tied scores are not
merged into a single threshold step.
"""
function roc_auc(ŷ, y, weights; sort_perm = parallel_sort_perm(ŷ))
    # Bind the reordered arrays to fresh names instead of reassigning the
    # arguments: a variable that is assigned more than once and then captured by
    # a closure gets boxed, which makes every access in the loops below dynamic.
    sorted_y = parallel_apply_sort_perm(y, sort_perm)
    sorted_weights = parallel_apply_sort_perm(weights, sort_perm)

    # tpr = true_pos/total_pos
    # fpr = false_pos/total_neg
    # ROC is tpr vs fpr

    # Pass 1: per-chunk totals of positive and negative weight.
    # Presumably parallel_iterate returns one collection per element of the
    # tuple produced by the block, with one entry per thread — verify against
    # its definition.
    thread_pos_weights, thread_neg_weights =
        parallel_iterate(length(sorted_y)) do thread_range
            pos_weight = 0.0
            neg_weight = 0.0
            for i in thread_range
                if sorted_y[i] > 0.5f0
                    pos_weight += Float64(sorted_weights[i])
                else
                    neg_weight += Float64(sorted_weights[i])
                end
            end
            pos_weight, neg_weight
        end

    total_pos_weight = sum(thread_pos_weights)
    total_neg_weight = sum(thread_neg_weights)

    # Pass 2: each chunk sweeps its own range and accumulates its share of the
    # area; the shares are summed at the end.
    thread_aucs = parallel_iterate(length(sorted_y)) do thread_range
        # Start from the positive/negative weight held by this chunk and every
        # later one.
        # NOTE(review): this assumes the chunk handled here is chunk number
        # Threads.threadid() and that chunks are ordered by thread id. That only
        # holds if parallel_iterate pins chunks to threads statically — confirm.
        tid = Threads.threadid()
        nthreads = Threads.nthreads()
        true_pos_weight = sum(@view thread_pos_weights[tid:nthreads])
        false_pos_weight = sum(@view thread_neg_weights[tid:nthreads])

        auc = 0.0
        last_fpr = false_pos_weight / total_neg_weight
        for i in thread_range
            # Moving the threshold past example i removes it from the set of
            # predicted positives.
            if sorted_y[i] > 0.5f0
                true_pos_weight -= Float64(sorted_weights[i])
            else
                false_pos_weight -= Float64(sorted_weights[i])
            end
            fpr = false_pos_weight / total_neg_weight
            tpr = true_pos_weight / total_pos_weight
            # Add a rectangle only when fpr moved (a NaN fpr also compares
            # unequal, so NaN propagates into the result).
            if fpr != last_fpr
                auc += (last_fpr - fpr) * tpr
            end
            last_fpr = fpr
        end
        auc
    end

    return sum(thread_aucs)
end