metrax.AUCPR#

class metrax.AUCPR(true_positives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int)#

Bases: Metric

Computes area under the precision-recall curve for binary classification given predictions and labels.

The Precision-Recall curve shows the tradeoff between precision and recall at different classification thresholds. The area under this curve (AUC-PR) provides a single score that represents the model’s ability to identify positive cases across all possible classification thresholds, particularly in imbalanced datasets.

For each threshold \(t\), precision and recall are calculated as:

\[ \begin{align}\begin{aligned}\text{Precision}(t) = \frac{TP(t)}{TP(t) + FP(t)}\\\text{Recall}(t) = \frac{TP(t)}{TP(t) + FN(t)}\end{aligned}\end{align} \]
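
As an illustration of the two formulas above, here is a minimal pure-Python sketch that counts TP, FP and FN at a single threshold and derives precision and recall. The helper names `confusion_counts` and `precision_recall` are hypothetical, and so are two conventions chosen for this sketch: scores strictly above the threshold count as positive, and an empty denominator yields 0.0. It is not the metrax implementation.

```python
def confusion_counts(predictions, labels, threshold):
    """Count TP, FP, FN when scores strictly above `threshold` are predicted positive."""
    tp = fp = fn = 0
    for score, label in zip(predictions, labels):
        predicted_positive = score > threshold
        if predicted_positive and label == 1:
            tp += 1
        elif predicted_positive and label == 0:
            fp += 1
        elif not predicted_positive and label == 1:
            fn += 1
    return tp, fp, fn


def precision_recall(tp, fp, fn):
    """Apply Precision = TP / (TP + FP) and Recall = TP / (TP + FN)."""
    # Returning 0.0 for an empty denominator is a convention of this sketch only.
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return precision, recall


predictions = [0.9, 0.8, 0.6, 0.4, 0.2]
labels = [1, 0, 1, 1, 0]
tp, fp, fn = confusion_counts(predictions, labels, threshold=0.5)
precision, recall = precision_recall(tp, fp, fn)
```

With these five samples and a threshold of 0.5, three scores are predicted positive (two of them correctly) and one positive label is missed, so both precision and recall come out to 2/3.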

The AUC-PR is then computed using interpolation:

\[\text{AUC-PR} = \sum_{i=1}^{n-1} (R_{i+1} - R_i) \cdot \frac{P_i + P_{i+1}}{2}\]
where:
  • \(P_i\) is precision at threshold i

  • \(R_i\) is recall at threshold i

  • \(n\) is the number of thresholds
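
The summation above can be sketched directly in pure Python. The function name `trapezoidal_auc_pr` is hypothetical, and the sketch assumes the points are already ordered by increasing recall; it only evaluates the formula as written and makes no claim about how metrax orders or generates its thresholds.

```python
def trapezoidal_auc_pr(precisions, recalls):
    """Sum (R_{i+1} - R_i) * (P_i + P_{i+1}) / 2 over consecutive points."""
    area = 0.0
    for i in range(len(recalls) - 1):
        width = recalls[i + 1] - recalls[i]                    # R_{i+1} - R_i
        mean_height = (precisions[i] + precisions[i + 1]) / 2  # (P_i + P_{i+1}) / 2
        area += width * mean_height
    return area


# Three points on a toy PR curve, ordered by increasing recall.
recalls = [0.0, 0.5, 1.0]
precisions = [1.0, 1.0, 0.5]
auc = trapezoidal_auc_pr(precisions, recalls)
```

For this toy curve the two trapezoids contribute 0.5 and 0.375, giving an area of 0.875.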

The AUC-PR metric has a number of known issues, so use it with caution:
  • PR curves are highly sensitive to class balance.

  • Precision is a non-monotonic function of recall, so the “area” under the curve is not directly proportional to performance.

  • PR-AUC has no standard implementation, and different libraries will give different results. Some libraries interpolate between points; others assume a step function or use the trapezoidal rule (as sklearn does). Some libraries compute the convex hull of the PR curve; others do not. Because the PR curve is non-monotonic, its area is sensitive to the number of samples along the curve (more so than ROC-AUC).
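
To make the last point concrete, the sketch below applies two common area rules to the same PR points and gets two different numbers. Both function names are hypothetical, and the step rule shown (weighting each recall increment by the precision at its right endpoint) is just one of several conventions in use; neither function is taken from any particular library.

```python
def step_area(precisions, recalls):
    """Step rule: each recall increment is weighted by the right-endpoint precision."""
    return sum(
        (recalls[i + 1] - recalls[i]) * precisions[i + 1]
        for i in range(len(recalls) - 1)
    )


def trapezoid_area(precisions, recalls):
    """Trapezoidal rule: each recall increment is weighted by the mean endpoint precision."""
    return sum(
        (recalls[i + 1] - recalls[i]) * (precisions[i] + precisions[i + 1]) / 2
        for i in range(len(recalls) - 1)
    )


# The same four PR points, ordered by increasing recall, fed to both rules.
recalls = [0.0, 0.25, 0.75, 1.0]
precisions = [1.0, 0.8, 0.6, 0.4]
step = step_area(precisions, recalls)
trapezoid = trapezoid_area(precisions, recalls)
```

On this data the step rule gives about 0.6 and the trapezoidal rule about 0.7, so scores from different libraries are only comparable when the integration rule is known to match.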

true_positives#

The count of true positive instances from the given data and label at each threshold.

Type:

jax.Array

false_positives#

The count of false positive instances from the given data and label at each threshold.

Type:

jax.Array

false_negatives#

The count of false negative instances from the given data and label at each threshold.

Type:

jax.Array

__init__(true_positives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int) None#

Methods

__init__(true_positives, false_positives, ...)

compute()

Computes final metrics from intermediate values.

compute_value()

Wraps compute() and returns a values.Value.

empty()

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

from_fun(fun)

Calls cls.from_model_output with the return value from fun.

from_model_output(predictions, labels[, ...])

Updates the metric.

from_output(name)

Calls cls.from_model_output with model output named name.

interpolate_pr_auc()

Interpolation formula inspired by section 4 of Davis & Goadrich 2006.

merge(other)

Returns Metric that is the accumulation of self and other.

reduce()

Reduces the metric along its first axis by calling _reduce_merge().

replace(**updates)

Returns a new object replacing the specified fields with new values.

Attributes

true_positives: Array#
false_positives: Array#
false_negatives: Array#
num_thresholds: int#
classmethod empty() AUCPR#

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None, num_thresholds: int = 200) AUCPR#

Updates the metric.

Parameters:
  • predictions – A floating point 1D vector whose values are in the range [0, 1]. The shape should be (batch_size,).

  • labels – True value. The value is expected to be 0 or 1. The shape should be (batch_size,).

  • sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).

  • num_thresholds – The number of thresholds to use. Default is 200.

Returns:

An AUCPR instance holding the per-threshold counts. Calling compute() on it yields the area under the precision-recall curve as a single scalar.

Raises:
  • ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.

merge(other: AUCPR) AUCPR#

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.
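
The accumulation described above can be sketched with a small stand-in class. Everything here is hypothetical: `ThresholdCounts` is not the metrax class, and the sketch assumes that merging adds the per-threshold counts elementwise and that scores strictly above a threshold count as positive. Under those assumptions, merging the states of two batches gives the same counts as processing both batches at once.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ThresholdCounts:
    """Hypothetical stand-in for the per-threshold state; not the metrax class."""
    true_positives: tuple
    false_positives: tuple
    false_negatives: tuple

    @classmethod
    def from_batch(cls, predictions, labels, thresholds):
        pairs = list(zip(predictions, labels))
        tp, fp, fn = [], [], []
        for t in thresholds:
            tp.append(sum(1 for p, y in pairs if p > t and y == 1))
            fp.append(sum(1 for p, y in pairs if p > t and y == 0))
            fn.append(sum(1 for p, y in pairs if p <= t and y == 1))
        return cls(tuple(tp), tuple(fp), tuple(fn))

    def merge(self, other):
        # Assumed accumulation rule: add the counts threshold by threshold.
        def add(a, b):
            return tuple(x + y for x, y in zip(a, b))

        return ThresholdCounts(
            add(self.true_positives, other.true_positives),
            add(self.false_positives, other.false_positives),
            add(self.false_negatives, other.false_negatives),
        )


thresholds = (0.25, 0.5, 0.75)
preds_a, labels_a = [0.9, 0.3], [1, 0]
preds_b, labels_b = [0.6, 0.1], [1, 1]

# Merge the states of two batches, then compare with a single pass over all samples.
merged = ThresholdCounts.from_batch(preds_a, labels_a, thresholds).merge(
    ThresholdCounts.from_batch(preds_b, labels_b, thresholds)
)
combined = ThresholdCounts.from_batch(preds_a + preds_b, labels_a + labels_b, thresholds)
```

Because the state consists of counts rather than a finished score, `merged` and `combined` compare equal, which is what allows a metric of this kind to be accumulated batch by batch before the final area is computed.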

interpolate_pr_auc() Array#

Interpolation formula inspired by section 4 of Davis & Goadrich 2006.

https://minds.wisconsin.edu/handle/1793/60482

Note here we derive & use a closed formula not present in the paper as follows:

Precision = TP / (TP + FP) = TP / P

Modeling all of TP (true positive), FP (false positive) and their sum P = TP + FP (predicted positive) as varying linearly within each interval [A, B] between successive thresholds, we get

Precision slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A) = (TP - TP_A) / (P - P_A)

Precision = (TP_A + slope * (P - P_A)) / P

The area within the interval is (slope / total_pos_weight) times

int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}

int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}

where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in

int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)

Bringing back the factor (slope / total_pos_weight) we’d put aside, we get

slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight

where dTP == TP_B - TP_A.

Note that when P_A == 0 the above calculation simplifies into

int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)

which is really equivalent to imputing constant precision throughout the first bucket having >0 true positives.
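
The closed formula above can be checked numerically for a single interval. The sketch below is pure Python with hypothetical function names; it evaluates slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight and compares it against a midpoint-rule integration of precision over recall along the same linear segment. It covers one interval only and is not the metrax implementation.

```python
import math


def interval_area(tp_a, p_a, tp_b, p_b, total_pos):
    """Closed-form area of one interval, with TP modeled as linear in P."""
    d_tp = tp_b - tp_a
    d_p = p_b - p_a
    if d_p == 0:
        return 0.0  # P unchanged means TP unchanged, so the interval has zero width in recall
    slope = d_tp / d_p
    intercept = tp_a - slope * p_a
    # When P_A == 0, TP_A is also 0, so the intercept vanishes and the log term is dropped.
    log_ratio = math.log(p_b / p_a) if p_a > 0 else 0.0
    return slope * (d_tp + intercept * log_ratio) / total_pos


def numeric_interval_area(tp_a, p_a, tp_b, p_b, total_pos, steps=100_000):
    """Midpoint-rule integration of Precision d(Recall) along the same segment."""
    slope = (tp_b - tp_a) / (p_b - p_a)
    d_p = (p_b - p_a) / steps
    area = 0.0
    for k in range(steps):
        p = p_a + (k + 0.5) * d_p
        tp = tp_a + slope * (p - p_a)
        # Precision is TP / P; d(Recall) is dTP / total_pos = slope * dP / total_pos.
        area += (tp / p) * (slope * d_p / total_pos)
    return area


closed = interval_area(tp_a=2.0, p_a=2.0, tp_b=5.0, p_b=8.0, total_pos=5.0)
numeric = numeric_interval_area(2.0, 2.0, 5.0, 8.0, 5.0)
first_bucket = interval_area(tp_a=0.0, p_a=0.0, tp_b=2.0, p_b=4.0, total_pos=5.0)
```

In this example slope is 0.5 and intercept is 1, so the closed form evaluates to 0.5 * (3 + log 4) / 5, and the numeric integral agrees closely. The `first_bucket` call exercises the P_A == 0 case, where the result reduces to slope * dTP / total_pos_weight.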

Returns:

pr_auc – A float scalar jax.Array that is an approximation of the area under the P-R curve.

compute() Array#

Computes final metrics from intermediate values.

__init__(true_positives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.