[ANN] AdversarialPrediction.jl | Easily optimize generic performance metrics in differentiable learning

I am excited to announce a new approach for optimizing generic performance metrics (including non-decomposable metrics) in differentiable learning. AdversarialPrediction.jl provides a Julia implementation of the framework.

Paper: https://arxiv.org/abs/1912.00965
Code: https://github.com/rizalzaf/AdversarialPrediction.jl
Example: https://github.com/rizalzaf/AP-examples

Here is a preview of what the framework can do. Suppose we want to optimize the F-beta score of a CNN model:

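```julia
using Flux
using AdversarialPrediction
import AdversarialPrediction: define, constraint

# CNN producing one score per sample; Dense(4*4*50, ...) implies 28x28 single-channel inputs
model = Chain(
  Conv((5, 5), 1=>20, relu), MaxPool((2,2)),
  Conv((5, 5), 20=>50, relu), MaxPool((2,2)),
  x -> reshape(x, :, size(x, 4)),
  Dense(4*4*50, 500), Dense(500, 1), vec
)

# F-beta written in terms of the confusion matrix
@metric FBeta beta
function define(::Type{FBeta}, C::ConfusionMatrix, beta)
    return ((1 + beta^2) * C.tp) / (beta^2 * C.ap + C.pp)
end
f2_score = FBeta(2)
special_case_positive!(f2_score)   # the formula is 0/0 when there are no positives at all

# ap_objective replaces the usual cross-entropy loss;
# train_set is an iterable of (x, y) minibatches
objective(x, y) = ap_objective(model(x), y, f2_score)
Flux.train!(objective, params(model), train_set, ADAM(1e-3))
```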

We only need to write the definition of the F-beta metric and plug it into training by using ap_objective instead of the standard cross-entropy objective.
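For intuition about what `define` describes, here is a minimal plain-Julia sketch (it does not use AdversarialPrediction.jl) that counts confusion-matrix entries for hard 0/1 predictions and evaluates the same F-beta formula. The helper names `confusion_counts` and `fbeta` are hypothetical, and the field names only mimic those used in `define`. Counts computed from hard predictions are piecewise constant in the model's scores, so they give no useful gradient, which is why training goes through `ap_objective` rather than this formula directly.

```julia
# Hypothetical helper: confusion-matrix counts for binary labels (1 = positive, 0 = negative).
function confusion_counts(y::AbstractVector{<:Integer}, yhat::AbstractVector{<:Integer})
    length(y) == length(yhat) || throw(DimensionMismatch("y and yhat must have equal length"))
    tp = count(i -> y[i] == 1 && yhat[i] == 1, eachindex(y, yhat))
    tn = count(i -> y[i] == 0 && yhat[i] == 0, eachindex(y, yhat))
    ap = count(==(1), y)      # actual positives
    pp = count(==(1), yhat)   # predicted positives
    n = length(y)
    return (tp = tp, tn = tn, ap = ap, an = n - ap, pp = pp, pn = n - pp, all = n)
end

# Same formula as the `define` body for FBeta: (1 + β²)·tp / (β²·ap + pp)
fbeta(C, beta) = ((1 + beta^2) * C.tp) / (beta^2 * C.ap + C.pp)

y    = [1, 1, 0, 0, 1]
yhat = [1, 0, 0, 1, 1]
C = confusion_counts(y, yhat)
fbeta(C, 2)   # ≈ 0.667 (tp = 2, ap = 3, pp = 3)
```

With `beta = 2`, recall is weighted more heavily than precision; `beta = 1` gives the usual F1 score.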

A more complicated metric? No problem. Just write it inside the define function. For example, Cohen's kappa score:

```julia
@metric Kappa
function define(::Type{Kappa}, C::ConfusionMatrix)
    pe = (C.ap * C.pp + C.an * C.pn) / C.all^2   # agreement expected by chance
    num = (C.tp + C.tn) / C.all - pe             # observed agreement minus chance
    den = 1 - pe
    return num / den
end

kappa = Kappa()
special_case_positive!(kappa)
special_case_negative!(kappa)
```
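As a sanity check on the kappa formula, here is a plain-Julia sketch that evaluates it on fixed counts. `kappa_score` is a hypothetical name, and the counts are a NamedTuple mimicking the `ConfusionMatrix` fields used above, not the package's type. It also illustrates why kappa is degenerate on both sides: if every label and prediction is positive (or every one is negative), the chance agreement `pe` equals 1 and the formula becomes 0/0, which is presumably why both `special_case_positive!` and `special_case_negative!` are called.

```julia
# Hypothetical helper mirroring the `define` body for Kappa.
# C is a NamedTuple with fields tp, tn, ap, an, pp, pn, all.
function kappa_score(C)
    pe = (C.ap * C.pp + C.an * C.pn) / C.all^2   # agreement expected by chance
    po = (C.tp + C.tn) / C.all                   # observed agreement (accuracy)
    return (po - pe) / (1 - pe)
end

# 100 samples, 40 actual positives, 50 predicted positives, 35 true positives
tp, ap, pp, n = 35, 40, 50, 100
tn = n - ap - pp + tp   # = 45 samples that are negative in both labels and predictions
C = (tp = tp, tn = tn, ap = ap, an = n - ap, pp = pp, pn = n - pp, all = n)
kappa_score(C)          # ≈ 0.6 (po = 0.8, pe = 0.5)

# Degenerate case: everything positive, so pe = 1 and the result is 0/0
all_pos = (tp = 10, tn = 0, ap = 10, an = 0, pp = 10, pn = 0, all = 10)
kappa_score(all_pos)    # NaN
```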

Trade-offs between two metrics? AdversarialPrediction.jl can handle those as well. For example, we can optimize precision subject to recall being at least a given threshold:

```julia
# Precision given recall metric: maximize precision subject to recall >= th
@metric PrecisionGvRecall th
function define(::Type{PrecisionGvRecall}, C::ConfusionMatrix, th)
    return C.tp / C.pp            # precision
end
function constraint(::Type{PrecisionGvRecall}, C::ConfusionMatrix, th)
    return C.tp / C.ap >= th      # recall must be at least th
end

precision_gv_recall = PrecisionGvRecall(0.8)   # require recall >= 0.8
special_case_positive!(precision_gv_recall)
cs_special_case_positive!(precision_gv_recall, true)
```
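To see what the metric and its constraint mean on concrete numbers, here is a plain-Julia sketch that reports precision only when recall meets the threshold. `precision_given_recall` and its convention of returning `nothing` for infeasible counts are hypothetical illustrations, not how the package treats the constraint; in the package, `constraint` is consumed by the optimization behind `ap_objective` rather than applied as a post-hoc filter.

```julia
# Hypothetical helper: precision subject to recall >= th.
# Returns `nothing` when the recall constraint is violated.
function precision_given_recall(C, th)
    recall = C.tp / C.ap      # same expression as the `constraint` body
    recall >= th || return nothing
    return C.tp / C.pp        # same expression as the `define` body
end

C = (tp = 35, ap = 40, pp = 50)   # recall = 0.875, precision = 0.7
precision_given_recall(C, 0.8)    # 0.7: constraint satisfied
precision_given_recall(C, 0.9)    # nothing: recall 0.875 < 0.9
```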

Note: the package currently supports Flux v0.9 (Tracker-based autodiff).

– Rizal
