 # [ANN] AdversarialPrediction.jl | Easily optimize generic performance metrics in differentiable learning

I am excited to announce a new approach for optimizing generic performance metrics (including non-decomposable metrics) in differentiable learning. AdversarialPrediction.jl provides a Julia implementation of the framework.
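A metric is non-decomposable when it cannot be written as a sum or average of per-example losses, so it does not fit the usual per-example loss that cross-entropy training relies on. The plain-Julia sketch below uses no packages, and its `f1` helper and mini-batches are hypothetical. It shows that for F1, the average of the per-batch scores (about 0.533 here) differs from the F1 of the combined data (0.5).

```julia
# Plain Julia, no packages: F1 computed from hard 0/1 predictions and labels.
function f1(yhat::AbstractVector{<:Integer}, y::AbstractVector{<:Integer})
    tp = sum((yhat .== 1) .& (y .== 1))   # true positives
    pp = sum(yhat .== 1)                  # predicted positives
    ap = sum(y .== 1)                     # actual positives
    return 2tp / (pp + ap)
end

# Two hypothetical mini-batches (predictions, labels).
yhat1, y1 = [1, 1, 0, 0], [1, 0, 0, 0]
yhat2, y2 = [0, 0, 0, 1], [1, 1, 1, 1]

# Averaging per-batch F1 is not the same as F1 over all examples.
batch_mean = (f1(yhat1, y1) + f1(yhat2, y2)) / 2
full = f1(vcat(yhat1, yhat2), vcat(y1, y2))
println("mean of per-batch F1 = ", batch_mean, ", F1 of all data = ", full)
```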

Here is a preview of what the framework can do. Suppose we want to optimize the F-beta score of a CNN model:

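```julia
using Flux
using AdversarialPrediction
import AdversarialPrediction: define, constraint   # needed to add methods to these functions

# Assumes 28x28 grayscale input: two conv/pool stages leave a 4x4x50 feature map
model = Chain(
    Conv((5, 5), 1=>20, relu), MaxPool((2,2)),
    Conv((5, 5), 20=>50, relu), MaxPool((2,2)),
    x -> reshape(x, :, size(x, 4)),
    Dense(4*4*50, 500), Dense(500, 1), vec   # one score per sample
)

# Define the F-beta score in terms of confusion-matrix entries
@metric FBeta beta
function define(::Type{FBeta}, C::ConfusionMatrix, beta)
    return ((1 + beta^2) * C.tp) / (beta^2 * C.ap + C.pp)
end
f2_score = FBeta(2)
special_case_positive!(f2_score)

# Train with the adversarial prediction objective instead of cross-entropy
objective(x, y) = ap_objective(model(x), y, f2_score)
```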

We just need to write the definition of the F-beta metric and integrate it into the network by using `ap_objective` in place of the standard cross-entropy objective.
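The `define` body is plain arithmetic on confusion-matrix counts, so its value on hard predictions is easy to check by hand. The sketch below uses a hypothetical `Counts` struct as a stand-in for the package's `ConfusionMatrix`; it is not the package's type. It confirms that the formula above matches the usual weighted harmonic mean of precision and recall, where `beta = 2` weights recall more heavily than precision. It only evaluates the metric on fixed counts and does not show how `ap_objective` optimizes it during training.

```julia
# Hypothetical stand-in for ConfusionMatrix, holding only the fields F-beta uses.
struct Counts
    tp::Int  # true positives
    ap::Int  # actual positives (tp + fn)
    pp::Int  # predicted positives (tp + fp)
end

# Same formula as the define(::Type{FBeta}, ...) method above.
fbeta(C::Counts, beta) = ((1 + beta^2) * C.tp) / (beta^2 * C.ap + C.pp)

# Textbook form: weighted harmonic mean of precision and recall.
function fbeta_pr(C::Counts, beta)
    prec = C.tp / C.pp
    rec = C.tp / C.ap
    return (1 + beta^2) * prec * rec / (beta^2 * prec + rec)
end

C = Counts(30, 40, 50)   # made-up counts: tp = 30, fn = 10, fp = 20
println(fbeta(C, 2), " ", fbeta_pr(C, 2))
```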

More complicated metric? No problem. Just write the metric inside the `define` function. For example, Cohen’s kappa score:

```julia
@metric Kappa
function define(::Type{Kappa}, C::ConfusionMatrix)
    pe = (C.ap * C.pp + C.an * C.pn) / C.all^2   # agreement expected by chance
    num = (C.tp + C.tn) / C.all - pe             # observed agreement minus chance
    den = 1 - pe
    return num / den
end

kappa = Kappa()
special_case_positive!(kappa)
special_case_negative!(kappa)
```
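As with F-beta, the kappa formula can be sanity-checked outside the package. The hypothetical `kappa_counts` function below takes the four cells of a 2x2 confusion matrix and mirrors the `define` body above. It is not the package's API, and the counts are made up for illustration.

```julia
# Hypothetical plain-Julia mirror of the Kappa define body, written over the
# four cells of a 2x2 confusion matrix (not the package's ConfusionMatrix type).
function kappa_counts(tp, fn, fp, tn)
    total = tp + fn + fp + tn
    ap, an = tp + fn, fp + tn            # actual positives / negatives
    pp, pn = tp + fp, fn + tn            # predicted positives / negatives
    pe = (ap * pp + an * pn) / total^2   # agreement expected by chance
    num = (tp + tn) / total - pe         # observed agreement minus chance
    den = 1 - pe
    return num / den
end

println(kappa_counts(20, 5, 10, 15))   # made-up counts
```

With tp = 20, fn = 5, fp = 10 and tn = 15, observed agreement is 0.7 and chance agreement is 0.5, so kappa ≈ 0.4. When every label and prediction is positive, `pe` is 1 and the plain formula returns `NaN` (0/0). This is one degenerate case that a metric definition has to handle somehow.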

Trade-off between two metrics? AdversarialPrediction.jl can handle this as well. For example, we can optimize precision subject to recall being at least some threshold:

```julia
# Precision given recall metric: maximize precision subject to recall >= th
@metric PrecisionGvRecall th
function define(::Type{PrecisionGvRecall}, C::ConfusionMatrix, th)
    return C.tp / C.pp          # precision
end
function constraint(::Type{PrecisionGvRecall}, C::ConfusionMatrix, th)
    return C.tp / C.ap >= th    # recall must be at least th
end

precision_gv_recall = PrecisionGvRecall(0.8)
special_case_positive!(precision_gv_recall)
cs_special_case_positive!(precision_gv_recall, true)
```
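To see what "precision given recall ≥ 0.8" asks of a classifier, the plain-Julia sketch below sweeps decision thresholds over a fixed set of scores. It keeps the threshold with the best precision among those whose recall meets the constraint. This brute-force sweep is not the package's adversarial training method. It only evaluates the constrained target on a hypothetical scorer, and the function name and data are illustrative assumptions.

```julia
# Hypothetical brute-force evaluation (not AdversarialPrediction's algorithm):
# best precision over thresholds t whose recall is at least th.
function best_precision_given_recall(scores, y, th)
    best_prec, best_t = -Inf, NaN
    ap = sum(y .== 1)                    # actual positives
    for t in sort(unique(scores))
        yhat = scores .>= t              # predict positive at or above t
        tp = sum(yhat .& (y .== 1))
        pp = sum(yhat)
        pp == 0 && continue              # precision undefined with no positive predictions
        prec, rec = tp / pp, tp / ap
        if rec >= th && prec > best_prec
            best_prec, best_t = prec, t
        end
    end
    return best_prec, best_t
end

scores = [0.9, 0.8, 0.7, 0.6, 0.4, 0.3, 0.2, 0.1]   # made-up model scores
y      = [1,   1,   0,   1,   1,   0,   0,   0]     # made-up labels
println(best_precision_given_recall(scores, y, 0.8))  # prints (0.8, 0.4)
```

Loosening the constraint to recall ≥ 0.5 lets a stricter threshold (0.8) reach precision 1.0. This trade-off is the one the constrained metric encodes.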

Note: It currently supports `Flux v0.9` (Tracker-based autodiff).

– Rizal
