metrax.AUCROC#

class metrax.AUCROC(true_positives: Array, true_negatives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int)#

Bases: Metric

Computes Area Under the ROC curve (AUC-ROC).

Computes the area under the receiver operating characteristic curve for binary classification, given predictions and labels.

The ROC curve shows the tradeoff between the true positive rate (TPR) and false positive rate (FPR) at different classification thresholds. The area under this curve (AUC-ROC) provides a single score that represents the model’s ability to discriminate between positive and negative cases across all possible classification thresholds, regardless of class imbalance.

For each threshold \(t\), TPR and FPR are calculated as:

\[ \begin{align}\begin{aligned}TPR(t) = \frac{TP(t)}{TP(t) + FN(t)}\\FPR(t) = \frac{FP(t)}{FP(t) + TN(t)}\end{aligned}\end{align} \]

The AUC-ROC is the area under this curve, approximated over the sampled thresholds using the trapezoidal rule:

\[\text{AUC-ROC} = \int_{0}^{1} TPR(FPR^{-1}(x)) \, dx\]

A score of 1 represents perfect classification, while 0.5 represents random guessing.
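The threshold sweep and trapezoidal integration described above can be sketched in plain NumPy. This is an illustrative approximation, not the metrax implementation: the function name `auc_roc_sketch`, the evenly spaced threshold grid, and the small `eps` padding are all assumptions made for the example.

```python
import numpy as np


def auc_roc_sketch(predictions, labels, num_thresholds=200):
    """Approximates AUC-ROC from per-threshold confusion counts (hypothetical helper)."""
    predictions = np.asarray(predictions, dtype=np.float64)
    labels = np.asarray(labels)

    # Assumption: evenly spaced thresholds, padded slightly beyond [0, 1] so the
    # curve always contains the (0, 0) and (1, 1) endpoints.
    eps = 1e-7
    thresholds = np.linspace(0.0 - eps, 1.0 + eps, num_thresholds)

    # Shape (num_thresholds, batch_size): is each sample predicted positive?
    pred_pos = predictions[None, :] > thresholds[:, None]
    is_pos = labels == 1
    is_neg = labels == 0

    # Confusion counts at every threshold.
    tp = np.sum(pred_pos & is_pos, axis=1)
    fn = np.sum(~pred_pos & is_pos, axis=1)
    fp = np.sum(pred_pos & is_neg, axis=1)
    tn = np.sum(~pred_pos & is_neg, axis=1)

    # TPR(t) and FPR(t); the maximum() guards against division by zero
    # when one of the classes is absent from the batch.
    tpr = tp / np.maximum(tp + fn, 1)
    fpr = fp / np.maximum(fp + tn, 1)

    # FPR falls as the threshold rises, so reverse both arrays to integrate
    # along increasing FPR, then apply the trapezoidal rule by hand.
    fpr_inc, tpr_inc = fpr[::-1], tpr[::-1]
    return float(np.sum(np.diff(fpr_inc) * (tpr_inc[1:] + tpr_inc[:-1]) / 2.0))


perfect = auc_roc_sketch([0.1, 0.2, 0.8, 0.9], [0, 0, 1, 1])
inverted = auc_roc_sketch([0.1, 0.2, 0.8, 0.9], [1, 1, 0, 0])
```

With perfectly separated scores the sketch returns 1.0, and with the labels inverted it returns 0.0, which matches the interpretation of the score given above.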

true_positives#

The count of true positive instances at each threshold, computed from the given predictions and labels.

Type:

jax.Array

false_positives#

The count of false positive instances at each threshold, computed from the given predictions and labels.

Type:

jax.Array

total_count#

The count of every data point.

__init__(true_positives: Array, true_negatives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int) None#

Methods

__init__(true_positives, true_negatives, ...)

compute()

Computes final metrics from intermediate values.

compute_value()

Wraps compute() and returns a values.Value.

empty()

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

from_fun(fun)

Calls cls.from_model_output with the return value from fun.

from_model_output(predictions, labels[, ...])

Creates an AUCROC instance from a batch of predictions and labels.

from_output(name)

Calls cls.from_model_output with model output named name.

merge(other)

Returns Metric that is the accumulation of self and other.

reduce()

Reduces the metric along its first axis by calling _reduce_merge().

replace(**updates)

Returns a new object replacing the specified fields with new values.

Attributes

true_positives: Array#
true_negatives: Array#
false_positives: Array#
false_negatives: Array#
num_thresholds: int#
classmethod empty() AUCROC#

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None, num_thresholds: int = 200) AUCROC#

Creates an AUCROC instance from a batch of predictions and labels.

Parameters:
  • predictions – A floating point 1D vector whose values are in the range [0, 1]. The shape should be (batch_size,).

  • labels – The ground-truth labels. Each value is expected to be 0 or 1. The shape should be (batch_size,).

  • sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).

  • num_thresholds – The number of thresholds to use. Default is 200.

Returns:

An AUCROC instance holding the confusion-matrix counts at each threshold. Call compute() on it to obtain the area under the receiver operating characteristic curve as a single scalar.

Raises:
  • ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.

merge(other: AUCROC) AUCROC#

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

compute() Array#

Computes final metrics from intermediate values.

replace(**updates)#

Returns a new object replacing the specified fields with new values.