Fast Parallel ROC AUC Calculation

Never mind, I figured it out.


"""
    roc_auc(ŷ, y, weights; sort_perm = parallel_sort_perm(ŷ))

Compute the weighted area under the ROC curve (true-positive rate vs.
false-positive rate) for scores `ŷ`, labels `y` and per-example `weights`.

An example counts as positive when its label is `> 0.5f0`; every other example
counts as negative. Weights are accumulated in `Float64`.

# Arguments
- `ŷ`: scores; only used to build the default `sort_perm`.
- `y`: labels, reordered by `sort_perm` before use.
- `weights`: per-example weights, reordered by `sort_perm` before use.

# Keywords
- `sort_perm`: permutation applied to `y` and `weights`. Defaults to
  `parallel_sort_perm(ŷ)`; pass it explicitly to reuse a permutation.

# Returns
The sum of the per-chunk partial areas returned by `parallel_iterate`.

# Notes
- There is no guard for a class with zero total weight: the rate divisions
  below then divide by zero.
- NOTE(review): the sweep treats all weight at or after the current position as
  "predicted positive", which matches scores sorted in ascending order. This
  presumes `parallel_sort_perm` yields an ascending permutation; confirm.
"""
function roc_auc(ŷ, y, weights; sort_perm = parallel_sort_perm(ŷ))
  # Bind the reordered arrays to fresh, single-assignment names. Reassigning the
  # arguments `y` / `weights` and then capturing them in the closures below
  # would make Julia box them (Core.Box), so every element access in the hot
  # loops would be dynamically typed.
  sorted_y       = parallel_apply_sort_perm(y, sort_perm)
  sorted_weights = parallel_apply_sort_perm(weights, sort_perm)

  # tpr = true_pos/total_pos
  # fpr = false_pos/total_neg
  # ROC is tpr vs fpr

  # Pass 1: per-chunk totals of positive and negative weight.
  thread_pos_weights, thread_neg_weights = parallel_iterate(length(sorted_y)) do thread_range
    pos_weight = 0.0
    neg_weight = 0.0
    for i in thread_range
      w = Float64(sorted_weights[i])
      if sorted_y[i] > 0.5f0
        pos_weight += w
      else
        neg_weight += w
      end
    end
    pos_weight, neg_weight
  end

  total_pos_weight = sum(thread_pos_weights)
  total_neg_weight = sum(thread_neg_weights)

  # Pass 2: each chunk sweeps its own range and accumulates its slice of the area.
  thread_aucs = parallel_iterate(length(sorted_y)) do thread_range
    # Weight still counted as predicted-positive when this chunk starts: the
    # totals of this chunk and every later one.
    # NOTE(review): indexing by Threads.threadid() assumes parallel_iterate runs
    # chunk k on thread k and returns per-chunk results in that order, with one
    # entry per thread — confirm against parallel_iterate.
    true_pos_weight  = sum(@view thread_pos_weights[Threads.threadid():Threads.nthreads()])
    false_pos_weight = sum(@view thread_neg_weights[Threads.threadid():Threads.nthreads()])

    auc = 0.0

    last_fpr = false_pos_weight / total_neg_weight
    for i in thread_range
      # Moving past example i removes it from the predicted-positive set.
      w = Float64(sorted_weights[i])
      if sorted_y[i] > 0.5f0
        true_pos_weight -= w
      else
        false_pos_weight -= w
      end
      fpr = false_pos_weight / total_neg_weight
      tpr = true_pos_weight  / total_pos_weight
      # Area only accrues when fpr moves: add a rectangle of width
      # (last_fpr - fpr) and height equal to the current tpr.
      if fpr != last_fpr
        auc += (last_fpr - fpr) * tpr
      end
      last_fpr = fpr
    end

    auc
  end

  sum(thread_aucs)
end
4 Likes