I am excited to announce a new approach for optimizing generic performance metrics (including non-decomposable metrics) in differentiable learning. AdversarialPrediction.jl provides a Julia implementation of the framework.
Paper: [1912.00965] AP-Perf: Incorporating Generic Performance Metrics in Differentiable Learning
Code: https://github.com/rizalzaf/AdversarialPrediction.jl
Example: https://github.com/rizalzaf/AP-examples
Here is a preview of what the framework can do. Suppose we want to optimize the F-beta score of a CNN model:
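```julia
using Flux
using AdversarialPrediction
import AdversarialPrediction: define, constraint

# CNN for 28×28 single-channel images, one real-valued output per sample
model = Chain(
    Conv((5, 5), 1=>20, relu), MaxPool((2,2)),
    Conv((5, 5), 20=>50, relu), MaxPool((2,2)),
    x -> reshape(x, :, size(x, 4)),
    Dense(4*4*50, 500), Dense(500, 1), vec
)

# Declare the metric type FBeta with one parameter, beta
@metric FBeta beta
# Define the metric in terms of the confusion matrix entries
function define(::Type{FBeta}, C::ConfusionMatrix, beta)
    return ((1 + beta^2) * C.tp) / (beta^2 * C.ap + C.pp)
end

f2_score = FBeta(2)
# Special case with no actual or predicted positives (the formula above is 0/0)
special_case_positive!(f2_score)

# Use the AP-Perf objective in place of cross-entropy;
# train_set is assumed to be an iterable of (x, y) mini-batches with binary labels y
objective(x, y) = ap_objective(model(x), y, f2_score)
Flux.train!(objective, params(model), train_set, ADAM(1e-3))
```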
We only need to write the definition of the F-beta metric and plug it into the network's training objective with `ap_objective`, instead of the standard cross-entropy loss.
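To see what the `define` method above computes, here is a minimal plain-Julia sketch that evaluates the same F-beta formula on explicit confusion counts. It does not use AdversarialPrediction.jl: `BinaryCounts`, `actual_pos`, `pred_pos`, and `fbeta` are hypothetical names, and we assume `C.ap` means actual positives (tp + fn) and `C.pp` means predicted positives (tp + fp). Returning 1.0 when both are zero is our own assumption about what `special_case_positive!` encodes.

```julia
# Hypothetical stand-in for confusion-matrix counts; NOT the package's ConfusionMatrix.
struct BinaryCounts
    tp::Int  # true positives
    fp::Int  # false positives
    fn::Int  # false negatives
    tn::Int  # true negatives
end

actual_pos(c::BinaryCounts) = c.tp + c.fn   # our reading of C.ap
pred_pos(c::BinaryCounts)   = c.tp + c.fp   # our reading of C.pp

# Same formula as the `define` method for FBeta.
function fbeta(c::BinaryCounts, beta)
    ap, pp = actual_pos(c), pred_pos(c)
    # With no actual and no predicted positives the formula is 0/0;
    # returning 1.0 here is an assumption about the special case.
    ap == 0 && pp == 0 && return 1.0
    return ((1 + beta^2) * c.tp) / (beta^2 * ap + pp)
end

c = BinaryCounts(8, 4, 2, 86)   # made-up counts: precision 8/12, recall 8/10
println(fbeta(c, 2))   # ≈ 0.769
println(fbeta(c, 1))   # ≈ 0.727
```

With beta = 2 recall is weighted more heavily than precision, so here F2 exceeds F1 because recall (0.8) is higher than precision (≈ 0.667).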
More complicated metric? No problem: just write its formula inside the `define` function. For example, Cohen's kappa score:
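```julia
# Cohen's kappa, a metric with no parameters
@metric Kappa
function define(::Type{Kappa}, C::ConfusionMatrix)
    pe = (C.ap * C.pp + C.an * C.pn) / C.all^2   # agreement expected by chance
    num = (C.tp + C.tn) / C.all - pe             # observed minus chance agreement
    den = 1 - pe
    return num / den
end

kappa = Kappa()
# The formula is 0/0 when labels and predictions are all negative or all positive,
# so both special cases are declared
special_case_positive!(kappa)
special_case_negative!(kappa)
```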
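As a sanity check on the kappa definition, this plain-Julia sketch follows the same steps on a 2×2 confusion matrix. It is independent of the package: `kappa_from_counts` is a hypothetical helper, and the mapping of `ap`/`an`/`pp`/`pn` to actual/predicted positives/negatives is our reading of the field names. It returns `NaN` where the formula is 0/0, which we take to be the situations the special cases above are meant to cover.

```julia
# Cohen's kappa from 2x2 confusion counts, mirroring the `define` method for Kappa.
function kappa_from_counts(tp, fp, fn, tn)
    total = tp + fp + fn + tn
    ap, an = tp + fn, fp + tn            # actual positives / negatives
    pp, pn = tp + fp, fn + tn            # predicted positives / negatives
    pe = (ap * pp + an * pn) / total^2   # agreement expected by chance
    po = (tp + tn) / total               # observed agreement (accuracy)
    den = 1 - pe
    # den is zero only when labels and predictions are all negative or all positive
    iszero(den) && return NaN
    return (po - pe) / den
end

println(kappa_from_counts(40, 5, 10, 45))   # ≈ 0.7 (observed 0.85, chance 0.5)
```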
Trade-off between two metrics? AdversarialPrediction.jl can handle this as well. For example, we can optimize precision subject to recall being at least some threshold:
```julia
# Precision given recall metric
@metric PrecisionGvRecall th
# Objective: precision
function define(::Type{PrecisionGvRecall}, C::ConfusionMatrix, th)
    return C.tp / C.pp
end
# Constraint: recall must be at least th
function constraint(::Type{PrecisionGvRecall}, C::ConfusionMatrix, th)
    return C.tp / C.ap >= th
end

precision_gv_recall = PrecisionGvRecall(0.8)   # require recall >= 0.8
special_case_positive!(precision_gv_recall)
# Special case for the constraint as well (recall is 0/0 with no actual positives)
cs_special_case_positive!(precision_gv_recall, true)
```
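To make the split between `define` (what is maximized) and `constraint` (what must hold) concrete, here is a plain-Julia sketch that evaluates both on a few operating points, such as confusion counts at different decision thresholds, and keeps the feasible one with the highest precision. This only illustrates what the constrained metric prefers; it is not how AP-Perf optimizes it. The helper names and the counts are hypothetical.

```julia
precision_score(tp, fp) = tp / (tp + fp)   # C.tp / C.pp (NaN if nothing predicted positive)
recall_score(tp, fn)    = tp / (tp + fn)   # C.tp / C.ap

# Index of the candidate with the highest precision among those whose
# recall is at least `th`, or `nothing` if no candidate satisfies the constraint.
function best_precision_given_recall(candidates, th)
    best, best_prec = nothing, -Inf
    for (i, c) in enumerate(candidates)
        recall_score(c.tp, c.fn) >= th || continue   # the `constraint` part
        p = precision_score(c.tp, c.fp)              # the `define` part
        if p > best_prec
            best, best_prec = i, p
        end
    end
    return best
end

# Hypothetical confusion counts at three decision thresholds (made-up numbers)
candidates = [
    (tp = 95, fp = 60, fn = 5),    # recall 0.95, precision ≈ 0.61
    (tp = 85, fp = 25, fn = 15),   # recall 0.85, precision ≈ 0.77
    (tp = 60, fp = 5,  fn = 40),   # recall 0.60, precision ≈ 0.92 (infeasible at 0.8)
]
println(best_precision_given_recall(candidates, 0.8))   # prints 2
```

The third operating point has the best precision overall but violates the recall constraint, so the second one wins at a threshold of 0.8.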
Note: it currently supports Flux v0.9 (Tracker-based autodiff).
– Rizal