I am excited to announce a new approach for optimizing generic performance metrics (including non-decomposable metrics) in differentiable learning. AdversarialPrediction.jl provides a Julia implementation of the framework.

Paper: https://arxiv.org/abs/1912.00965

Code: https://github.com/rizalzaf/AdversarialPrediction.jl

Example: https://github.com/rizalzaf/AP-examples

Here is a preview of what the framework can do. Suppose we want to optimize the F-beta score of a CNN model:

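```julia
using Flux
using AdversarialPrediction
import AdversarialPrediction: define, constraint

# LeNet-style CNN; the 4*4*50 flattened size assumes 28×28 single-channel inputs (e.g. MNIST)
model = Chain(
    Conv((5, 5), 1=>20, relu), MaxPool((2,2)),
    Conv((5, 5), 20=>50, relu), MaxPool((2,2)),
    x -> reshape(x, :, size(x, 4)),
    Dense(4*4*50, 500), Dense(500, 1), vec   # one score per example
)

# Declare the metric type with one parameter, beta
@metric FBeta beta

# F-beta written in terms of confusion-matrix entries:
# tp = true positives, ap = actual positives, pp = predicted positives
function define(::Type{FBeta}, C::ConfusionMatrix, beta)
    return ((1 + beta^2) * C.tp) / (beta^2 * C.ap + C.pp)
end

f2_score = FBeta(2)
special_case_positive!(f2_score)

# The adversarial prediction objective replaces the usual cross-entropy loss
objective(x, y) = ap_objective(model(x), y, f2_score)

# `train_set` is assumed to be an iterable of (x, y) minibatches
Flux.train!(objective, params(model), train_set, ADAM(1e-3))
```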

We just need to write the definition of the F-beta metric and plug it into the network through `ap_objective` instead of the standard cross-entropy objective.
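To check that the `define` body above is the familiar F-beta score, here is a plain-Julia sketch (it does not use AdversarialPrediction.jl) that evaluates the same expression on hard confusion-matrix counts and compares it with the textbook precision/recall form. The helper names `fbeta_from_counts` and `fbeta_from_pr` and the toy counts are hypothetical, chosen only for illustration.

```julia
# Hypothetical helper: the F-beta expression from `define`, on plain counts.
# tp = true positives, ap = actual positives (tp + fn), pp = predicted positives (tp + fp)
fbeta_from_counts(tp, ap, pp, beta) = ((1 + beta^2) * tp) / (beta^2 * ap + pp)

# Hypothetical helper: textbook form, a weighted harmonic mean of precision and recall
function fbeta_from_pr(tp, ap, pp, beta)
    precision = tp / pp
    recall = tp / ap
    return (1 + beta^2) * precision * recall / (beta^2 * precision + recall)
end

# Toy counts: 8 true positives, 10 actual positives, 16 predicted positives
tp, ap, pp = 8, 10, 16
f2 = fbeta_from_counts(tp, ap, pp, 2)   # 40 / 56 ≈ 0.714
```

Substituting precision = tp/pp and recall = tp/ap into the harmonic-mean form and simplifying gives exactly (1 + β²)·tp / (β²·ap + pp), which is why the metric only needs `C.tp`, `C.ap`, and `C.pp`.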

A more complicated metric? No problem. Just write the metric inside the `define` function. For example, Cohen’s kappa score:

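```julia
# Cohen's kappa from confusion-matrix entries
@metric Kappa
function define(::Type{Kappa}, C::ConfusionMatrix)
    # agreement expected by chance
    pe = (C.ap * C.pp + C.an * C.pn) / C.all^2
    # observed agreement (accuracy) minus chance agreement
    num = (C.tp + C.tn) / C.all - pe
    den = 1 - pe
    return num / den
end

kappa = Kappa()
special_case_positive!(kappa)
special_case_negative!(kappa)
```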
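As a sanity check on the formula, the sketch below evaluates the same kappa expression on hard counts in plain Julia, without AdversarialPrediction.jl. The helper `kappa_from_counts` is hypothetical; it derives the marginals (`ap`, `an`, `pp`, `pn`) from the four cells so that each line mirrors the `define` body term by term.

```julia
# Hypothetical helper: Cohen's kappa from the four confusion-matrix cells
function kappa_from_counts(tp, fp, fn, tn)
    ap, an = tp + fn, fp + tn        # actual positives / negatives
    pp, pn = tp + fp, fn + tn        # predicted positives / negatives
    n = tp + fp + fn + tn            # total number of examples
    pe = (ap * pp + an * pn) / n^2   # agreement expected by chance
    num = (tp + tn) / n - pe         # observed agreement minus chance
    den = 1 - pe
    return num / den
end

kappa_from_counts(20, 5, 10, 15)   # ≈ 0.4 (observed 0.7, chance 0.5)
kappa_from_counts(10, 0, 0, 10)    # 1.0 for perfect agreement
```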

A trade-off between two metrics? AdversarialPrediction.jl can handle this as well. For example, we can optimize precision subject to recall being at least some threshold:

```julia
# Precision given recall metric: maximize precision subject to recall >= th
@metric PrecisionGvRecall th

# Objective: precision = tp / predicted positives
function define(::Type{PrecisionGvRecall}, C::ConfusionMatrix, th)
    return C.tp / C.pp
end

# Constraint: recall = tp / actual positives must be at least th
function constraint(::Type{PrecisionGvRecall}, C::ConfusionMatrix, th)
    return C.tp / C.ap >= th
end

precision_gv_recall = PrecisionGvRecall(0.8)   # require recall >= 0.8
special_case_positive!(precision_gv_recall)
cs_special_case_positive!(precision_gv_recall, true)
```
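To make the split between `define` and `constraint` concrete, here is a plain-Julia sketch that scores fixed counts the same way: precision is the quantity being optimized, and recall ≥ `th` decides whether a prediction is feasible. The helper `precision_given_recall` and the counts are hypothetical; the sketch only evaluates the metric on given counts and does not show how AdversarialPrediction.jl enforces the constraint during training.

```julia
# Hypothetical helper: evaluate the precision-given-recall metric on hard counts.
# tp = true positives, ap = actual positives, pp = predicted positives, th = recall threshold
function precision_given_recall(tp, ap, pp, th)
    precision = tp / pp          # what `define` returns
    recall = tp / ap
    feasible = recall >= th      # what `constraint` checks
    return (precision = precision, recall = recall, feasible = feasible)
end

r_ok  = precision_given_recall(9, 10, 15, 0.8)   # recall 0.9 meets the threshold
r_bad = precision_given_recall(7, 10, 8, 0.8)    # recall 0.7 violates it
```

Note that `r_bad` has the higher precision (0.875 vs 0.6), which is exactly the kind of prediction the recall constraint is meant to rule out.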

Note: it currently supports `Flux v0.9` (Tracker-based autodiff).

– Rizal