Fast Parallel ROC AUC Calculation

I’m trying to calculate the receiver operating characteristic area under the curve (ROC AUC) for large amounts of data. I’ve parallelized the sorting, but I’ve got too much SARS-CoV-2 in me at the moment to figure out the math to parallelize the integration. I did a bit of searching but could only find single-threaded implementations. Is anyone aware of a highly parallel Julia-language ROC AUC calculation?

function roc_auc(ŷ, y, weights;
                 sort_perm       = parallel_sort_perm(ŷ),
                 total_weight    = parallel_float64_sum(weights),
                 positive_weight = parallel_float64_sum(y .* weights))
  negative_weight  = total_weight - positive_weight
  true_pos_weight  = positive_weight
  false_pos_weight = negative_weight

  # tpr = true_pos/total_pos
  # fpr = false_pos/total_neg
  # ROC is tpr vs fpr

  auc = 0.0

  last_fpr = false_pos_weight / negative_weight # = 1.0
  # Sweep the threshold upward through the sorted scores; each step moves one
  # item from the predicted-positive side to the predicted-negative side.
  for i in sort_perm
    if y[i] > 0.5f0
      true_pos_weight -= Float64(weights[i])
    else
      false_pos_weight -= Float64(weights[i])
    end
    fpr = false_pos_weight / negative_weight
    tpr = true_pos_weight  / positive_weight
    if fpr != last_fpr
      auc += (last_fpr - fpr) * tpr
    end
    last_fpr = fpr
  end

  auc
end
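
For scale, a call looks roughly like this (sizes, class balance, and weights here are made up for illustration; it needs the supporting functions below):

# Illustrative inputs: one million scores, ~10% positives, unit weights.
ŷ       = rand(Float32, 1_000_000)
y       = Float32.(rand(1_000_000) .< 0.1)
weights = ones(Float32, 1_000_000)

auc = roc_auc(ŷ, y, weights)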

Supporting functions:

# Sample sort: partition indices into nthreads() bins by value, sort each bin
# on its own thread, then concatenate.
using Random  # for MersenneTwister

function parallel_sort_perm(arr)
  sample_count = Threads.nthreads() * 20
  if length(arr) < sample_count || Threads.nthreads() == 1
    return sortperm(arr; alg = Base.Sort.MergeSort)
  end

  rng = MersenneTwister(1234)

  samples = sort(map(_ -> arr[rand(rng, 1:length(arr))], 1:sample_count))

  bin_splits = map(thread_i -> samples[round(Int, thread_i / Threads.nthreads() * sample_count)], 1:(Threads.nthreads() - 1))

  # One set of bins per binning thread to avoid write contention:
  # thread_bins_bins[t][k] holds the indices thread t assigned to sorting thread k.
  thread_bins_bins = map(_ -> map(_ -> Int64[], 1:Threads.nthreads()), 1:Threads.nthreads())
  # :static pins one contiguous chunk of iterations to each thread, so
  # threadid() is a stable index below (requires Julia 1.5+).
  Threads.@threads :static for i in 1:length(arr)
    thread_bins = thread_bins_bins[Threads.threadid()]

    x = arr[i]
    bin_i = Threads.nthreads()
    for k in 1:length(bin_splits)
      if bin_splits[k] > x
        bin_i = k
        break
      end
    end

    push!(thread_bins[bin_i], Int64(i))
  end

  outs = map(_ -> Int64[], 1:Threads.nthreads())

  Threads.@threads :static for _ in 1:Threads.nthreads()
    # Gather, from every binning thread, the bin destined for this thread.
    my_thread_bins = map(thread_i -> thread_bins_bins[thread_i][Threads.threadid()], 1:Threads.nthreads())

    my_out = Vector{Int64}(undef, sum(length.(my_thread_bins)))

    my_i = 1
    for j = 1:length(my_thread_bins)
      bin = my_thread_bins[j]
      if length(bin) > 0
        my_out[my_i:(my_i + length(bin)-1)] = bin
        my_i += length(bin)
      end
    end

    sort!(my_out; alg = Base.Sort.MergeSort, by = (i -> arr[i]))
    outs[Threads.threadid()] = my_out
  end

  out = Vector{Int64}(undef, length(arr))

  Threads.@threads :static for _ in 1:Threads.nthreads()
    my_out = outs[Threads.threadid()]

    start_i = sum(length.(outs)[1:Threads.threadid()-1]) + 1

    out[start_i:(start_i+length(my_out)-1)] = my_out
  end

  out
end
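
A quick sanity check (illustrative): the parallel permutation should order the values exactly like the built-in sortperm, even if tied indices land in a different order:

arr = rand(Float32, 100_000)
@assert arr[parallel_sort_perm(arr)] == arr[sortperm(arr)]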

# f should be a function that takes an index range and returns a tuple of reduction values.
#
# parallel_iterate unzips those tuples into a tuple of arrays of reduction values and returns that.
function parallel_iterate(f, count)
  thread_results = Vector{Any}(undef, Threads.nthreads())

  Threads.@threads :static for thread_i in 1:Threads.nthreads()
    start = div((thread_i-1) * count, Threads.nthreads()) + 1
    stop  = div( thread_i    * count, Threads.nthreads())
    thread_results[thread_i] = f(start:stop)
  end

  if isa(thread_results[1], Tuple)
    # Mangling so you get a tuple of arrays.
    Tuple(collect.(zip(thread_results...)))
  else
    thread_results
  end
end
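
For example (hypothetical), a single-pass per-thread min/max reduction comes back unzipped as a tuple of arrays:

arr = rand(1_000_000)  # assumes length(arr) >= Threads.nthreads() so no chunk is empty
mins, maxs = parallel_iterate(length(arr)) do thread_range
  minimum(@view arr[thread_range]), maximum(@view arr[thread_range])
end
overall_min, overall_max = minimum(mins), maximum(maxs)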

function parallel_float64_sum(arr)
  thread_sums = parallel_iterate(length(arr)) do thread_range
    thread_sum = 0.0
    for i in thread_range
      thread_sum += Float64(arr[i])
    end
    thread_sum
  end
  sum(thread_sums)
end
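
Sanity check (illustrative; ≈ because the summation order differs from Base's pairwise sum):

weights = rand(Float32, 10_000)
@assert parallel_float64_sum(weights) ≈ sum(Float64, weights)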

Never mind, I figured it out. The key realization is that the running true- and false-positive weights at the start of each thread's chunk are just suffix sums of the per-chunk weight totals, so each thread can integrate its own slice of the ROC curve independently and the partial AUCs sum to the total.
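
In symbols: sort by ŷ ascending and split the sorted data into chunks $C_1, \dots, C_T$ with per-chunk positive/negative weight totals $P_t$ and $N_t$. Then

$$\mathrm{AUC} = \sum_{t=1}^{T} \sum_{i \in C_t} (\mathrm{fpr}_{i-1} - \mathrm{fpr}_i)\,\mathrm{tpr}_i,$$

where the cumulative weights entering chunk $t$ are the suffix sums $\mathrm{TP}_t = \sum_{s \ge t} P_s$ and $\mathrm{FP}_t = \sum_{s \ge t} N_s$. Each inner sum needs only its own chunk plus those two scalars, so the chunks integrate independently: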


function roc_auc(ŷ, y, weights; sort_perm = parallel_sort_perm(ŷ))
  y       = parallel_apply_sort_perm(y, sort_perm)
  weights = parallel_apply_sort_perm(weights, sort_perm)

  # tpr = true_pos/total_pos
  # fpr = false_pos/total_neg
  # ROC is tpr vs fpr

  thread_pos_weights, thread_neg_weights = parallel_iterate(length(y)) do thread_range
    pos_weight = 0.0
    neg_weight = 0.0
    for i in thread_range
      if y[i] > 0.5f0
        pos_weight += Float64(weights[i])
      else
        neg_weight += Float64(weights[i])
      end
    end
    pos_weight, neg_weight
  end

  total_pos_weight = sum(thread_pos_weights)
  total_neg_weight = sum(thread_neg_weights)

  thread_aucs = parallel_iterate(length(y)) do thread_range
    # Starting cumulative weights for this chunk: everything from this chunk
    # onward is still on the predicted-positive side of the threshold.
    # (Relies on parallel_iterate's :static schedule so threadid() == chunk index.)
    true_pos_weight  = sum(@view thread_pos_weights[Threads.threadid():Threads.nthreads()])
    false_pos_weight = sum(@view thread_neg_weights[Threads.threadid():Threads.nthreads()])

    auc = 0.0

    last_fpr = false_pos_weight / total_neg_weight
    for i in thread_range
      if y[i] > 0.5f0
        true_pos_weight -= Float64(weights[i])
      else
        false_pos_weight -= Float64(weights[i])
      end
      fpr = false_pos_weight / total_neg_weight
      tpr = true_pos_weight  / total_pos_weight
      if fpr != last_fpr
        auc += (last_fpr - fpr) * tpr
      end
      last_fpr = fpr
    end

    auc
  end

  sum(thread_aucs)
end
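
parallel_apply_sort_perm above just materializes arr[sort_perm] in parallel; I didn't paste it, but a minimal sketch on top of parallel_iterate would be:

function parallel_apply_sort_perm(arr, sort_perm)
  out = Vector{eltype(arr)}(undef, length(arr))
  parallel_iterate(length(arr)) do thread_range
    for i in thread_range
      out[i] = arr[sort_perm[i]]
    end
    nothing  # nothing to reduce; we only want the side effect
  end
  out
end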