Metrax Metrics

class metrax.AUCPR(true_positives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int)#

Bases: Metric

Computes area under the precision-recall curve for binary classification given predictions and labels.

The Precision-Recall curve shows the tradeoff between precision and recall at different classification thresholds. The area under this curve (AUC-PR) provides a single score that represents the model’s ability to identify positive cases across all possible classification thresholds, particularly in imbalanced datasets.

For each threshold \(t\), precision and recall are calculated as:

\[ \begin{align}\begin{aligned}Precision(t) = \frac{TP(t)}{TP(t) + FP(t)}\\Recall(t) = \frac{TP(t)}{TP(t) + FN(t)}\end{aligned}\end{align} \]

The AUC-PR is then computed using interpolation:

\[AUC-PR = \sum_{i=1}^{n-1} (R_{i+1} - R_i) \cdot \frac{P_i + P_{i+1}}{2}\]
where:
  • \(P_i\) is precision at threshold i

  • \(R_i\) is recall at threshold i

  • \(n\) is the number of thresholds
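
As an illustration of the sum above, the following is a minimal pure-Python sketch, not the metrax implementation. The helper names `pr_points` and `trapezoidal_auc_pr`, the `>=` thresholding rule, and the convention that precision is 1.0 when nothing is predicted positive are all assumptions made for this example.

```python
def pr_points(predictions, labels, thresholds):
    """Return one (precision, recall) pair per threshold."""
    points = []
    for t in thresholds:
        tp = sum(1 for p, y in zip(predictions, labels) if p >= t and y == 1)
        fp = sum(1 for p, y in zip(predictions, labels) if p >= t and y == 0)
        fn = sum(1 for p, y in zip(predictions, labels) if p < t and y == 1)
        # Assumed convention: precision is 1.0 when no sample is predicted positive.
        precision = tp / (tp + fp) if (tp + fp) > 0 else 1.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        points.append((precision, recall))
    return points


def trapezoidal_auc_pr(points):
    """Sum (R_{i+1} - R_i) * (P_i + P_{i+1}) / 2 over consecutive points."""
    area = 0.0
    for (p_i, r_i), (p_next, r_next) in zip(points, points[1:]):
        area += (r_next - r_i) * (p_i + p_next) / 2.0
    return area


predictions = [0.9, 0.8, 0.3, 0.2]
labels = [1, 1, 0, 0]
# Thresholds run from high to low so that recall is non-decreasing.
thresholds = [1.0, 0.75, 0.5, 0.25, 0.0]
auc_pr = trapezoidal_auc_pr(pr_points(predictions, labels, thresholds))
```

Here the two positives receive the two highest scores, so the curve reaches recall 1 at precision 1 and the area is 1.0.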

The AUC-PR metric has a number of known issues, so use it with caution:
  • PR curves are highly sensitive to class balance.

  • PR is a non-monotonic function, so its “area” is not directly proportional to performance.

  • PR-AUC has no standard implementation, and different libraries will give different results. Some libraries interpolate between points, while others assume a step function (or a trapezoid, as sklearn does). Some libraries compute the convex hull of the PR curve; others do not. Because PR is non-monotonic, its value is sensitive to the number of samples along the curve (more so than ROC-AUC).

true_positives#

The count of true positive instances from the given data and label at each threshold.

Type:

jax.Array

false_positives#

The count of false positive instances from the given data and label at each threshold.

Type:

jax.Array

false_negatives#

The count of false negative instances from the given data and label at each threshold.

Type:

jax.Array

true_positives: Array#
false_positives: Array#
false_negatives: Array#
num_thresholds: int#
classmethod empty() AUCPR#

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None, num_thresholds: int = 200) AUCPR#

Updates the metric.

Parameters:
  • predictions – A floating point 1D vector whose values are in the range [0, 1]. The shape should be (batch_size,).

  • labels – True value. The value is expected to be 0 or 1. The shape should be (batch_size,).

  • sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).

  • num_thresholds – The number of thresholds to use. Default is 200.

Returns:

An AUCPR instance holding the per-threshold counts for this batch. Calling compute() on it yields the area under the precision-recall curve as a single scalar.

Raises:
  • ValueError – If type of labels is wrong or the shapes of predictions and labels are incompatible.

merge(other: AUCPR) AUCPR#

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

interpolate_pr_auc() Array#

Interpolation formula inspired by section 4 of Davis & Goadrich 2006.

https://minds.wisconsin.edu/handle/1793/60482

Note here we derive & use a closed formula not present in the paper as follows:

Precision = TP / (TP + FP) = TP / P

Modeling all of TP (true positive), FP (false positive) and their sum P = TP + FP (predicted positive) as varying linearly within each interval [A, B] between successive thresholds, we get

Precision slope = dTP / dP = (TP_B - TP_A) / (P_B - P_A) = (TP - TP_A) / (P - P_A)

Precision = (TP_A + slope * (P - P_A)) / P

The area within the interval is (slope / total_pos_weight) times

int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}

int_A^B{Precision.dP} = int_A^B{slope * dP + intercept * dP / P}

where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in

int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)

Bringing back the factor (slope / total_pos_weight) we’d put aside, we get

slope * [dTP + intercept * log(P_B / P_A)] / total_pos_weight

where dTP == TP_B - TP_A.

Note that when P_A == 0 the above calculation simplifies into

int_A^B{Precision.dTP} = int_A^B{slope * dTP} = slope * (TP_B - TP_A)

which is really equivalent to imputing constant precision throughout the first bucket having >0 true positives.
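
The closed formula can be checked numerically. The sketch below is an illustration under stated assumptions, not the metrax code: the function names and sample numbers are hypothetical. It evaluates the per-interval area with the formula above and compares it against a midpoint-rule integral of Precision with respect to Recall.

```python
import math


def interval_area_closed_form(tp_a, p_a, tp_b, p_b, total_pos_weight):
    """Area contributed by one interval [A, B], using the closed formula."""
    d_tp = tp_b - tp_a
    slope = d_tp / (p_b - p_a)
    intercept = tp_a - slope * p_a
    if p_a == 0:
        # Constant precision (equal to slope) throughout the first bucket.
        return slope * d_tp / total_pos_weight
    return slope * (d_tp + intercept * math.log(p_b / p_a)) / total_pos_weight


def interval_area_numeric(tp_a, p_a, tp_b, p_b, total_pos_weight, steps=10_000):
    """Midpoint-rule integral of Precision dRecall over the same interval."""
    slope = (tp_b - tp_a) / (p_b - p_a)
    d_p = (p_b - p_a) / steps
    area = 0.0
    for i in range(steps):
        p = p_a + (i + 0.5) * d_p
        tp = tp_a + slope * (p - p_a)
        # dRecall = dTP / total_pos_weight = slope * dP / total_pos_weight
        area += (tp / p) * slope * d_p / total_pos_weight
    return area


closed = interval_area_closed_form(3.0, 4.0, 5.0, 10.0, 8.0)
numeric = interval_area_numeric(3.0, 4.0, 5.0, 10.0, 8.0)
first_bucket = interval_area_closed_form(0.0, 0.0, 2.0, 4.0, 8.0)
```

The two values agree to within the integration error, and the `p_a == 0` branch reduces to `slope * (TP_B - TP_A) / total_pos_weight`.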

Returns:

pr_auc, a float scalar jax.Array that is an approximation of the area under the P-R curve.

compute() Array#

Computes final metrics from intermediate values.

__init__(true_positives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.AUCROC(true_positives: Array, true_negatives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int)#

Bases: Metric

Computes Area Under the ROC curve (AUC-ROC).

Computes area under the receiver operating characteristic curve for binary classification given predictions and labels.

The ROC curve shows the tradeoff between the true positive rate (TPR) and false positive rate (FPR) at different classification thresholds. The area under this curve (AUC-ROC) provides a single score that represents the model’s ability to discriminate between positive and negative cases across all possible classification thresholds, regardless of class imbalance.

For each threshold \(t\), TPR and FPR are calculated as:

\[ \begin{align}\begin{aligned}TPR(t) = \frac{TP(t)}{TP(t) + FN(t)}\\FPR(t) = \frac{FP(t)}{FP(t) + TN(t)}\end{aligned}\end{align} \]

The AUC-ROC is then computed using the trapezoidal rule:

\[AUC-ROC = \int_{0}^{1} TPR(FPR^{-1}(x)) dx\]

A score of 1 represents perfect classification, while 0.5 represents random guessing.
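
A minimal pure-Python sketch of the trapezoidal computation follows. It is not the metrax implementation. The helper names and the `>=` thresholding rule are assumptions for illustration.

```python
def roc_points(predictions, labels, thresholds):
    """Return one (FPR, TPR) pair per threshold."""
    points = []
    for t in thresholds:
        tp = sum(1 for p, y in zip(predictions, labels) if p >= t and y == 1)
        fp = sum(1 for p, y in zip(predictions, labels) if p >= t and y == 0)
        fn = sum(1 for p, y in zip(predictions, labels) if p < t and y == 1)
        tn = sum(1 for p, y in zip(predictions, labels) if p < t and y == 0)
        tpr = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        fpr = fp / (fp + tn) if (fp + tn) > 0 else 0.0
        points.append((fpr, tpr))
    return points


def trapezoidal_auc_roc(points):
    """Integrate TPR over FPR with the trapezoidal rule."""
    points = sorted(points)  # order by increasing FPR, then TPR
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2.0
    return area


predictions = [0.9, 0.4, 0.6, 0.1]
labels = [1, 1, 0, 0]
thresholds = [1.0, 0.75, 0.5, 0.25, 0.0]
auc_roc = trapezoidal_auc_roc(roc_points(predictions, labels, thresholds))
```

In this example three of the four (positive, negative) pairs are ranked correctly, and the trapezoidal area is 0.75.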

true_positives#

The count of true positive instances from the given data and label at each threshold.

Type:

jax.Array

false_positives#

The count of false positive instances from the given data and label at each threshold.

Type:

jax.Array

total_count#

The count of every data point.

true_positives: Array#
true_negatives: Array#
false_positives: Array#
false_negatives: Array#
num_thresholds: int#
classmethod empty() AUCROC#

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None, num_thresholds: int = 200) AUCROC#

Updates the metric.

Parameters:
  • predictions – A floating point 1D vector whose values are in the range [0, 1]. The shape should be (batch_size,).

  • labels – True value. The value is expected to be 0 or 1. The shape should be (batch_size,).

  • sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).

  • num_thresholds – The number of thresholds to use. Default is 200.

Returns:

An AUCROC instance holding the per-threshold counts for this batch. Calling compute() on it yields the area under the receiver operating characteristic curve as a single scalar.

Raises:
  • ValueError – If type of labels is wrong or the shapes of predictions and labels are incompatible.

merge(other: AUCROC) AUCROC#

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

compute() Array#

Computes final metrics from intermediate values.

__init__(true_positives: Array, true_negatives: Array, false_positives: Array, false_negatives: Array, num_thresholds: int) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.Accuracy(total: Array, count: Array)#

Bases: Average

Computes accuracy, which is the frequency with which predictions match labels.

This metric calculates the proportion of correct predictions by comparing predictions and labels element-wise. It is the ratio of the sum of weighted correct predictions to the sum of all corresponding weights. If no sample_weights are provided, weights default to 1 for each element.

The calculation is as follows:

\[\text{Accuracy} = \frac{\sum (\text{weight} \times \text{correct})}{\sum \text{weight}}\]

where correct is 1 if prediction == label for an element, and 0 otherwise. weight is the sample_weight for that element, or 1 if no weights are given.
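
The weighted ratio can be sketched in plain Python as follows. The function name `weighted_accuracy` and the choice to return 0.0 when the weights sum to zero are assumptions for this example, not the library's API.

```python
def weighted_accuracy(predictions, labels, sample_weights=None):
    """Sum of weight * correct divided by the sum of weights."""
    if sample_weights is None:
        sample_weights = [1.0] * len(labels)
    total = sum(
        w * (1.0 if p == y else 0.0)
        for p, y, w in zip(predictions, labels, sample_weights)
    )
    count = sum(sample_weights)
    # Assumed convention: an empty or zero-weight batch yields 0.0.
    return total / count if count else 0.0


unweighted = weighted_accuracy([1, 0, 1, 1], [1, 1, 1, 0])
weighted = weighted_accuracy([1, 0, 1, 1], [1, 1, 1, 0], [3.0, 1.0, 1.0, 1.0])
```

With unit weights, two of four predictions match, giving 0.5. Up-weighting the first (correct) element to 3 raises the result to 4/6.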

classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None) Accuracy#

Updates the metric state with new predictions and labels.

This method computes element-wise equality between predictions and labels. The result of this comparison (a boolean array, treated as 1 for True and 0 for False) is then used to update the metric’s total and count.

Parameters:
  • predictions – JAX array of predicted values. Expected to have a shape compatible with labels for element-wise comparison (e.g., (batch_size,), (batch_size, num_classes), (batch_size, sequence_length, num_features)).

  • labels – JAX array of true values. Expected to have a shape compatible with predictions for element-wise comparison.

  • sample_weights – Optional JAX array of weights. If provided, it must be broadcastable to the shape of labels (which should also be compatible with predictions’ shape).

Returns:

An updated instance of Accuracy metric.

Raises:

ValueError – If JAX operations (like broadcasting or arithmetic) fail due to incompatible shapes or types among predictions, labels, and sample_weights. For instance, if predictions and labels shapes are not identical and not broadcastable to a common shape for comparison, or if sample_weights cannot be broadcast to labels’ shape.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.Average(total: Array, count: Array)#

Bases: Average

Average Metric inherits clu.metrics.Average and performs safe division.

classmethod from_model_output(values: Array, sample_weights: Array | None = None) Average#

Updates the metric.

Parameters:
  • values – A floating point 1D vector representing the values. The shape should be (batch_size,).

  • sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).

Returns:

Updated Average metric.

compute() Array#

Computes final metrics from intermediate values.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.AveragePrecisionAtK(total: Array, count: Array)#

Bases: Average

Computes AP@k (average precision at k) metrics.

Average precision at k (AP@k) is a metric used to evaluate the performance of ranking models. It measures the sum of precision at k where the item at the kth rank is relevant, divided by the total number of relevant items.

Given the top \(K\) recommendations, AP@K is calculated as:

\[\begin{split}AP@K = \frac{1}{r} \sum_{k=1}^{K} Precision@k * rel(k) \\ rel(k) = \begin{cases} 1 & \text{if the item at rank } k \text{ is relevant} \\ 0 & \text{otherwise} \end{cases}\end{split}\]
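
For a single query, the formula can be sketched as below. This is a plain-Python illustration rather than the library's vectorized code. The function name and the choice to return 0.0 when there are no relevant items are assumptions.

```python
def average_precision_at_k(scores, relevance, k):
    """AP@K for one query; `relevance` is a multi-hot list aligned with `scores`."""
    r = sum(relevance)
    if r == 0:
        return 0.0  # assumed convention when no item is relevant
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits = 0
    total = 0.0
    for rank, idx in enumerate(ranked[:k], start=1):
        if relevance[idx]:
            hits += 1
            total += hits / rank  # Precision@rank, counted only at relevant ranks
    return total / r


ap_at_3 = average_precision_at_k([0.9, 0.1, 0.8, 0.4], [1, 0, 0, 1], k=3)
```

The relevant items land at ranks 1 and 3, so the sum is 1/1 + 2/3, and dividing by r = 2 gives 5/6.
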
classmethod average_precision_at_ks(predictions: Array, labels: Array, ks: Array)#

Computes AP@k (average precision at k) metrics for each of k in ks.

Parameters:
  • predictions – A floating point 2D vector representing the prediction generated from the model. The shape should be (batch_size, vocab_size).

  • labels – A multi-hot encoding of the true label. The shape should be (batch_size, vocab_size).

  • ks – A 1D vector of integers representing the k’s to compute the MAP@k metrics. The shape should be (|ks|).

Returns:

Rank-2 tensor of shape (batch, |ks|) containing AP@k metrics.

classmethod from_model_output(predictions: Array, labels: Array, ks: Array) AveragePrecisionAtK#

Updates the metric.

Parameters:
  • predictions – A floating point 2D vector representing the prediction generated from the model. The shape should be (batch_size, vocab_size).

  • labels – A multi-hot encoding of the true label. The shape should be (batch_size, vocab_size).

  • ks – A 1D vector of integers representing the k’s to compute the MAP@k metrics. The shape should be (|ks|).

Returns:

The AveragePrecisionAtK metric. The shape should be (|ks|).

Raises:
  • ValueError – If type of labels is wrong or the shapes of predictions and labels are incompatible.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.BLEU(max_order: int, matches_by_order: Array, possible_matches_by_order: Array, translation_length: Array, reference_length: Array)#

Bases: Metric

Computes the BLEU score for sequence generation.

BLEU measures the similarity between a machine-generated candidate translation and one or more human reference translations, focusing on matching n-grams.

It’s calculated as:

\[BLEU = \text{BP} \cdot \exp\left( \sum_{n=1}^{N} w_n \log p_n \right)\]
where:
  • \(p_n\) is the modified n-gram precision for n-grams of order n.

  • \(N\) is the maximum n-gram order considered (typically 4).

  • \(w_n\) are weights for each order (typically uniform, 1/N).

  • \(\text{BP}\) is the Brevity Penalty.

This implementation uses uniform weights and calculates statistics incrementally.
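
To make the terms concrete, here is a small single-sentence BLEU sketch in plain Python. It is an illustration only: whitespace tokenization, the absence of smoothing, and the function names are assumptions, and it computes one score directly rather than accumulating statistics as the class does.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Count the n-grams of order n in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def sentence_bleu(candidate, references, max_order=4):
    cand = candidate.split()
    refs = [r.split() for r in references]
    log_precision_sum = 0.0
    for n in range(1, max_order + 1):
        cand_counts = ngrams(cand, n)
        # Clip each candidate n-gram count by its maximum count in any reference.
        max_ref = Counter()
        for ref in refs:
            for gram, c in ngrams(ref, n).items():
                max_ref[gram] = max(max_ref[gram], c)
        matches = sum(min(c, max_ref[g]) for g, c in cand_counts.items())
        possible = sum(cand_counts.values())
        if matches == 0 or possible == 0:
            return 0.0  # no smoothing in this sketch
        log_precision_sum += math.log(matches / possible) / max_order  # w_n = 1/N
    # Effective reference length: the reference length closest to the candidate.
    ref_len = min((len(r) for r in refs), key=lambda n_ref: (abs(n_ref - len(cand)), n_ref))
    bp = 1.0 if len(cand) > ref_len else math.exp(1.0 - ref_len / len(cand))
    return bp * math.exp(log_precision_sum)


exact = sentence_bleu("the cat sat on the mat", ["the cat sat on the mat"])
clipped = sentence_bleu("the the the", ["the cat"], max_order=1)
```

An exact match scores 1.0. In the second call, clipping limits the three occurrences of "the" to one match, so the unigram precision is 1/3.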

max_order#

Maximum n-gram order to consider.

Type:

int

matches_by_order#

Accumulated sum of clipped n-gram matches for each order.

Type:

jax.Array

possible_matches_by_order#

Accumulated sum of total n-grams in predictions for each order.

Type:

jax.Array

translation_length#

Accumulated total length of predictions.

Type:

jax.Array

reference_length#

Accumulated total ‘effective’ reference length (closest length match for each prediction).

Type:

jax.Array

max_order: int#
matches_by_order: Array#
possible_matches_by_order: Array#
translation_length: Array#
reference_length: Array#
classmethod empty() BLEU#

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: list[str], references: list[list[str]], max_order: int = 4) BLEU#

Computes BLEU statistics for a batch of predictions and references.

Parameters:
  • predictions – A list of predicted strings. The shape should be (batch_size, ).

  • references – A list of lists of reference strings. The shape should be (batch_size, num_references).

  • max_order – The maximum order of n-grams to consider.

Returns:

A BLEU metric instance containing the statistics for this batch.

Raises:
  • ValueError – If the shapes of predictions and references are incompatible.

merge(other: BLEU) BLEU#

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

compute() Array#

Computes final metrics from intermediate values.

__init__(max_order: int, matches_by_order: Array, possible_matches_by_order: Array, translation_length: Array, reference_length: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.CosineSimilarity(total: Array, count: Array)#

Bases: Average

Calculates the Cosine Similarity between two arrays.

The Cosine Similarity is defined as the dot product of the vectors divided by the product of their magnitudes (norms).

\[cos_{sim}(x,y) = \frac{x \cdot y}{||x|| * ||y||}\]
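
A minimal sketch of the formula for two plain Python vectors (the function name is hypothetical and zero-norm inputs are not handled):

```python
import math


def cosine_similarity(x, y):
    """Dot product divided by the product of the Euclidean norms."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    return dot / (norm_x * norm_y)


parallel = cosine_similarity([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 1.0])
```

Parallel vectors give a similarity of 1 (up to floating point rounding), and orthogonal vectors give 0.
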
classmethod from_model_output(predictions: Array, targets: Array, axis: int = -1) CosineSimilarity#

Creates a CosineSimilarity instance.

Parameters:
  • predictions – A floating point array of the predictions. The shape should be (batch_size,).

  • targets – A floating point array of the targets. The shape should be (batch_size,).

  • axis – The axis to compute the norm over.

Returns:

A CosineSimilarity instance.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.DCGAtK(total: Array, count: Array)#

Bases: Average

Computes Discounted Cumulative Gain at k metric.

DCG tells how good a list of search results or recommendations is, based on the relevance of the items and their positions in the list.

This implementation calculates \(DCG@k\) based on the following formula:

\[DCG@K(y, s) = \sum_{i=1}^{K} \text{gain}(y_i) \times \text{rank_discount}(\text{rank}(s_i))\]

where

  • \(y_i\) is the relevance label from the labels,

  • \(s_i\) is its score from the prediction,

  • \(\text{rank}(s_i)\) is the 1-based rank of item \(i\).

  • \(\text{gain}(y_i) = 2^{y_i} - 1\).

  • \(\text{rank_discount}(\text{rank}(s_i)) = \frac{1}{\log_2(\text{rank}(s_i) + 1)}\).

We get the final formula:

\[DCG@K(y, s) = \sum_{i=1}^{K} \frac{2^{y_i} - 1}{\log_2(\text{rank}(s_i) + 1)}\]
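
The final formula can be sketched for a single query as follows. This is an assumption-labeled illustration in plain Python, with a hypothetical function name, not the library's batched implementation.

```python
import math


def dcg_at_k(scores, relevance, k):
    """DCG@K with gain 2^y - 1 and discount 1 / log2(rank + 1)."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sum(
        (2 ** relevance[idx] - 1) / math.log2(rank + 1)
        for rank, idx in enumerate(ranked[:k], start=1)
    )


dcg = dcg_at_k([0.2, 0.9, 0.5], [3, 1, 2], k=2)
```

The top two items by score have relevance 1 and 2, so the result is (2^1 - 1)/log2(2) + (2^2 - 1)/log2(3).
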
classmethod from_model_output(predictions: Array, labels: Array, ks: Array) DCGAtK#

Creates a DCGAtK metric instance from model output.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.Dice(intersection: Array, sum_pred: Array, sum_true: Array)#

Bases: Metric

Computes the Dice coefficient between y_true and y_pred.

Dice is a similarity measure used to measure overlap between two samples. A Dice score of 1 indicates perfect overlap; 0 indicates no overlap.

The formula is:

\[\text{Dice} = \frac{2 \cdot \sum (y_{\text{true}} \cdot y_{\text{pred}})} {\sum y_{\text{true}} + \sum y_{\text{pred}} + \epsilon}\]
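
A short plain-Python sketch of the formula is shown below. The function name and the default `epsilon=1e-7` are assumptions for illustration, since the value of epsilon is not stated above.

```python
def dice(y_true, y_pred, epsilon=1e-7):
    """2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred) + epsilon)."""
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    return 2.0 * intersection / (sum(y_true) + sum(y_pred) + epsilon)


half_overlap = dice([1, 1, 0, 0], [1, 0, 1, 0])
no_positives = dice([0, 0], [0, 0])
```

The epsilon term keeps the denominator non-zero, so two all-zero inputs yield 0 rather than a division error.
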
intersection#

Sum of element-wise product between y_true and y_pred.

Type:

jax.Array

sum_true#

Sum of y_true across all examples.

Type:

jax.Array

sum_pred#

Sum of y_pred across all examples.

Type:

jax.Array

intersection: Array#
sum_pred: Array#
sum_true: Array#
classmethod empty() Dice#

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array) Dice#

Updates the metric.

Parameters:
  • predictions – A floating point vector whose values are in the range [0, 1]. The shape should be (batch_size,).

  • labels – True value. The value is expected to be 0 or 1. The shape should be (batch_size,).

Returns:

Updated Dice metric.

merge(other: Dice) Dice#

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

compute() Array#

Computes final metrics from intermediate values.

__init__(intersection: Array, sum_pred: Array, sum_true: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.FBetaScore(true_positives: Array, false_positives: Array, false_negatives: Array, beta: float = 1.0)#

Bases: Metric

F-Beta score Metric class.

Computes the F-Beta score for the binary classification given ‘predictions’ and ‘labels’.

Formula for F-Beta Score:

b2 = beta ** 2

f_beta_score = ((1 + b2) * (precision * recall)) / (precision * b2 + recall)

F-Beta turns into the F1 Score when beta = 1.0
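
The formula can be exercised directly from confusion counts. The sketch below uses a hypothetical `f_beta` helper and returns 0.0 whenever a denominator is zero, which is an assumption of this example.

```python
def f_beta(tp, fp, fn, beta=1.0):
    """F-Beta from true positive, false positive, and false negative counts."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    b2 = beta ** 2
    denom = precision * b2 + recall
    return (1 + b2) * precision * recall / denom if denom else 0.0


f1 = f_beta(tp=6, fp=2, fn=4)            # precision 0.75, recall 0.6
f2 = f_beta(tp=6, fp=2, fn=4, beta=2.0)  # weights recall more heavily
```

With precision 0.75 and recall 0.6, F1 is 2/3, and F2 drops to 0.625 because recall is the lower of the two.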

true_positives#

The count of true positive instances from the given data, label, and threshold.

Type:

jax.Array

false_positives#

The count of false positive instances from the given data, label, and threshold.

Type:

jax.Array

false_negatives#

The count of false negative instances from the given data, label, and threshold.

Type:

jax.Array

beta#

The beta value used in the F-Score metric.

Type:

float

true_positives: Array#
false_positives: Array#
false_negatives: Array#
beta: float = 1.0#
classmethod empty() FBetaScore#

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array, beta: float = 1.0, threshold: float = 0.5) FBetaScore#

Updates the metric.

Note: When only predictions and labels are given, the score calculated is the F1 score if the FBetaScore beta value has not been previously modified.

Parameters:
  • predictions – A floating point 1D vector whose values are in the range [0, 1]. The shape should be (batch_size,).

  • labels – True value. The value is expected to be 0 or 1. The shape should be (batch_size,).

  • beta – beta value to use in the F-Score metric. A floating number.

  • threshold – threshold value to use in the F-Score metric. A floating number.

Returns:

The updated FBetaScore object.

Raises:
  • ValueError – If type of labels is wrong or the shapes of predictions and labels are incompatible. If the beta or threshold are invalid values, an error is raised as well.

compute() Array#

Computes final metrics from intermediate values.

__init__(true_positives: Array, false_positives: Array, false_negatives: Array, beta: float = 1.0) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.IoU(total: Array, count: Array)#

Bases: Average

Measures Intersection over Union (IoU) for semantic segmentation.

The general formula for IoU for a single class is \(IoU_{class} = \frac{TP}{TP + FP + FN}\), where TP, FP, and FN are True Positives, False Positives, and False Negatives.

Per-Batch Processing: For each input batch, a mean IoU is calculated. This involves:

  1. Aggregating TP, FP, and FN pixel counts for each specified target class (from the required target_class_ids list) across all samples within the batch.

  2. Computing IoU for each of these classes using the batch-aggregated counts: \(IoU_{class} = \frac{TP}{TP + FP + FN + \epsilon}\).

  3. Averaging these per-class IoU scores to get a single value for the batch.

    • If target_class_ids is empty, an array of zeros of shape (B,) (where B is batch size) is produced by _calculate_iou.

    • Otherwise, a scalar jnp.ndarray (shape ()) representing the mean IoU is produced.

Accumulation & Final Metric: This class inherits from base.Average. It accumulates the results from per-batch processing and compute() returns the final mean IoU as a scalar jnp.ndarray (shape ()).
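
The per-batch steps can be sketched in plain Python as below. This is an illustration, not the `_calculate_iou` helper itself. It assumes the masks have already been flattened into lists of integer class labels and that target_class_ids is non-empty. The function name is hypothetical.

```python
def mean_iou(pred_mask, target_mask, target_class_ids, epsilon=1e-7):
    """Mean of per-class IoU over the requested classes for one batch."""
    ious = []
    for c in target_class_ids:
        pairs = list(zip(pred_mask, target_mask))
        tp = sum(1 for p, t in pairs if p == c and t == c)
        fp = sum(1 for p, t in pairs if p == c and t != c)
        fn = sum(1 for p, t in pairs if p != c and t == c)
        ious.append(tp / (tp + fp + fn + epsilon))
    return sum(ious) / len(ious)


batch_iou = mean_iou(
    pred_mask=[0, 0, 1, 1],
    target_mask=[0, 1, 1, 1],
    target_class_ids=[0, 1],
)
```

Class 0 has TP = 1 and FP = 1 (IoU 1/2), class 1 has TP = 2 and FN = 1 (IoU 2/3), so the batch mean is about 0.583.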

classmethod from_model_output(predictions: Array, targets: Array, num_classes: int, target_class_ids: Array, from_logits: bool = False, epsilon: float = 1e-07) IoU#

Creates an IoU instance from a batch of model outputs.

Per-batch processing:

  1. Preprocesses predictions and targets into integer label masks of shape (B, H, W) (B: batch size, H: height, W: width).

  2. Calls _calculate_iou using the provided target_class_ids to compute the batch’s mean IoU.

Parameters:
  • predictions – jax.Array. Model predictions. If from_logits is True, the shape is (B, H, W, C) (C: num_classes). If from_logits is False, the shape is (B, H, W) or (B, H, W, 1).

  • targets – jax.Array. Ground truth segmentation masks. Shape (B, H, W) or (B, H, W, 1), integer class labels.

  • num_classes – Total number of distinct classes (C). Integer.

  • target_class_ids – An array of integer class IDs for which to compute IoU.

  • from_logits – bool. If True, predictions are logits and argmax is applied. Default is False.

  • epsilon – float. Small value for stable IoU calculation. Default is 1e-7.

Returns:

An IoU metric instance updated with the IoU score from this batch.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.MAE(total: Array, count: Array)#

Bases: Average

Computes the mean absolute error for regression problems given predictions and labels.

The mean absolute error without sample weights is defined as:

\[MAE = \frac{1}{N} \sum_{i=1}^{N} |y_i - \hat{y}_i|\]

When sample weights \(w_i\) are provided, the weighted mean absolute error is:

\[MAE = \frac{\sum_{i=1}^{N} w_i|y_i - \hat{y}_i|}{\sum_{i=1}^{N} w_i}\]
where:
  • \(y_i\) are true values

  • \(\hat{y}_i\) are predictions

  • \(w_i\) are sample weights

  • \(N\) is the number of samples
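
Both forms of the definition can be sketched with one plain-Python helper. The name `weighted_mae` is hypothetical, and the unweighted case is treated as all weights equal to 1.

```python
def weighted_mae(labels, predictions, sample_weights=None):
    """Weighted mean of |y_i - y_hat_i|; unit weights when none are given."""
    if sample_weights is None:
        sample_weights = [1.0] * len(labels)
    numerator = sum(
        w * abs(y - y_hat)
        for y, y_hat, w in zip(labels, predictions, sample_weights)
    )
    return numerator / sum(sample_weights)


mae = weighted_mae([1.0, 2.0, 3.0], [2.0, 2.0, 5.0])
mae_weighted = weighted_mae([1.0, 2.0, 3.0], [2.0, 2.0, 5.0], [1.0, 1.0, 2.0])
```

The absolute errors are 1, 0, and 2, so the unweighted mean is 1.0. Doubling the weight of the last sample gives (1 + 0 + 4) / 4 = 1.25.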

classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None) MAE#

Updates the metric.

Parameters:
  • predictions – A floating point 1D vector representing the prediction generated from the model. The shape should be (batch_size,).

  • labels – True value. The shape should be (batch_size,).

  • sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).

Returns:

Updated MAE metric. The shape should be a single scalar.

Raises:
  • ValueError – If type of labels is wrong or the shapes of predictions and labels are incompatible.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.MRR(total: Array, count: Array)#

Bases: Average

Computes Mean Reciprocal Rank (MRR), supporting MRR@k for multiple k values.

MRR is the average of the reciprocal ranks of the first relevant item for a set of queries.

The mean reciprocal rank for a group of queries \(q\) in \(Q\) is defined as follows:

\[MRR = \frac{1}{|Q|} \sum_{q \in Q} RR_q\]

Where \(RR_q\) is the reciprocal rank for query \(q\), defined as:

\[\begin{split}RR_q = \begin{cases} \frac{1}{\text{rank}} & \text{if a relevant item is found} \\ 0 & \text{if no relevant item is found.} \end{cases}\end{split}\]

This implementation assumes binary relevance labels (1 for relevant, 0 for not relevant).
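
As a sketch of the definition with a cutoff k, the plain-Python helpers below (hypothetical names, not the library's code) compute the reciprocal rank per query and average it over the batch.

```python
def reciprocal_rank_at_k(scores, relevance, k):
    """1 / rank of the first relevant item within the top k, else 0."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    for rank, idx in enumerate(ranked[:k], start=1):
        if relevance[idx] == 1:
            return 1.0 / rank
    return 0.0


def mrr_at_k(batch_scores, batch_relevance, k):
    """Mean of the per-query reciprocal ranks."""
    rrs = [
        reciprocal_rank_at_k(s, r, k)
        for s, r in zip(batch_scores, batch_relevance)
    ]
    return sum(rrs) / len(rrs)


batch_scores = [[0.9, 0.5, 0.1], [0.2, 0.8, 0.6]]
batch_relevance = [[0, 1, 0], [1, 0, 0]]
mrr_at_3 = mrr_at_k(batch_scores, batch_relevance, k=3)
mrr_at_2 = mrr_at_k(batch_scores, batch_relevance, k=2)
```

The first relevant items sit at ranks 2 and 3, so MRR@3 is (1/2 + 1/3) / 2. With k = 2 the second query finds no relevant item, and the mean falls to 0.25.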

classmethod from_model_output(predictions: Array, labels: Array, ks: Array) MRR#

Creates an MRR metric instance from model output, calculating MRR@k for each k.

Parameters:
  • predictions – A 2D array of prediction scores. Higher scores indicate higher rank. The shape should be (batch_size, vocab_size).

  • labels – A 2D array of binary relevance labels (0 or 1). The shape should be (batch_size, vocab_size).

  • ks – A 1D array of integers representing the k cutoffs. The shape should be (|ks|, ).

Returns:

An MRR metric object. The ‘total’ field will be an array of shape (|ks|, ).

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.MSE(total: Array, count: Array)#

Bases: Average

Computes the mean squared error for regression problems given predictions and labels.

The mean squared error without sample weights is defined as:

\[MSE = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2\]

When sample weights \(w_i\) are provided, the weighted mean squared error is:

\[MSE = \frac{\sum_{i=1}^{N} w_i(y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} w_i}\]
where:
  • \(y_i\) are true values

  • \(\hat{y}_i\) are predictions

  • \(w_i\) are sample weights

  • \(N\) is the number of samples

classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None) MSE#

Updates the metric.

Parameters:
  • predictions – A floating point 1D vector representing the prediction generated from the model. The shape should be (batch_size,).

  • labels – True value. The shape should be (batch_size,).

  • sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).

Returns:

Updated MSE metric. The shape should be a single scalar.

Raises:
  • ValueError – If type of labels is wrong or the shapes of predictions and labels are incompatible.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.MSLE(total: Array, count: Array)#

Bases: Average

Computes the mean squared logarithmic error for regression problems given predictions and labels.

The mean squared logarithmic error is defined as:

\[MSLE = \frac{1}{N} \sum_{i=1}^{N} (\ln(y_i + 1) - \ln(\hat{y}_i + 1))^2\]
where:
  • \(y_i\) are true values

  • \(\hat{y}_i\) are predictions

  • \(N\) is the number of samples
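
The unweighted definition can be sketched in plain Python using `math.log1p`, which computes ln(1 + x). The function name `msle` is hypothetical, and sample weights are left out of this illustration.

```python
import math


def msle(labels, predictions):
    """Mean of (ln(y + 1) - ln(y_hat + 1))^2 over all samples."""
    terms = [
        (math.log1p(y) - math.log1p(y_hat)) ** 2
        for y, y_hat in zip(labels, predictions)
    ]
    return sum(terms) / len(terms)


# ln(1 + (e - 1)) = 1, so the first term is 1 and the second is 0.
error = msle([math.e - 1.0, 0.0], [0.0, 0.0])
```

Because the error is measured on a log scale, the metric penalizes relative rather than absolute differences.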

classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None) MSLE#

Updates the metric.

Parameters:
  • predictions – A floating point 1D vector representing the prediction generated from the model. The shape should be (batch_size,).

  • labels – True value. The shape should be (batch_size,).

  • sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).

Returns:

Updated MSLE metric. The shape should be a single scalar.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.NDCGAtK(total: Array, count: Array)#

Bases: DCGAtK

Computes Normalized Discounted Cumulative Gain at k (NDCG@k) metrics.

NDCG@k normalizes DCG@k by the Ideal DCG@k (IDCG@k), which is the DCG@k score of a perfectly ranked list (items sorted by their true relevance).

This implementation calculates \(NDCG@k\) based on the following formula:

\[NDCG@k = \frac{DCG@k}{IDCG@k}\]

where

  • If \(IDCG@k\) is 0, then \(NDCG@k\) is defined as 0.

  • The \(DCG@k\) calculation uses \(exp2\) gain (\(2^{\text{relevance}} - 1\)) and standard logarithmic discount (\(\frac{1}{\log_2(\text{rank} + 1)}\)).
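
A single-query sketch of the normalization follows. It is plain Python with hypothetical helper names, shown only to illustrate the ratio and the IDCG = 0 rule stated above.

```python
import math


def dcg(relevance_in_rank_order, k):
    """DCG@k for relevance values already arranged in ranked order."""
    return sum(
        (2 ** rel - 1) / math.log2(rank + 1)
        for rank, rel in enumerate(relevance_in_rank_order[:k], start=1)
    )


def ndcg_at_k(scores, relevance, k):
    by_score = [
        rel
        for _, rel in sorted(zip(scores, relevance), key=lambda p: p[0], reverse=True)
    ]
    ideal = sorted(relevance, reverse=True)  # perfectly ranked list
    idcg = dcg(ideal, k)
    return dcg(by_score, k) / idcg if idcg > 0 else 0.0


perfect = ndcg_at_k([0.9, 0.5, 0.1], [3, 2, 1], k=3)
imperfect = ndcg_at_k([0.2, 0.9, 0.5], [3, 1, 2], k=2)
no_relevant = ndcg_at_k([0.1, 0.2], [0, 0], k=2)
```

A ranking that already matches the ideal order scores 1, and a query with no relevant items scores 0 by definition.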

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.Perplexity(aggregate_crossentropy: Array, num_samples: Array)#

Bases: Metric

Computes perplexity for sequence generation.

Perplexity is a measurement of how well a probability distribution predicts a sample. It is defined as the exponentiation of the cross-entropy. A low perplexity indicates the probability distribution is good at predicting the sample.

For language models, it can be interpreted as the weighted average branching factor of the model - how many equally likely words can be selected at each step.

Given a sequence of \(N\) tokens, perplexity is calculated as:

\[Perplexity = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log P(x_i|x_{<i})\right)\]

When sample weights \(w_i\) are provided:

\[Perplexity = \exp\left(-\frac{\sum_{i=1}^{N} w_i\log P(x_i|x_{<i})}{\sum_{i=1}^{N} w_i}\right)\]
where:
  • \(P(x_i|x_{<i})\) is the predicted probability of token \(x_i\) given previous tokens

  • \(w_i\) are sample weights

  • \(N\) is the sequence length

Lower perplexity indicates better prediction - the model is less “perplexed” by the data.
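
The two formulas can be sketched in plain Python as follows. The helper name `perplexity` is hypothetical, and the input is assumed to be the probability the model assigned to each true token, rather than the full (batch_size, seq_len, vocab_size) tensor.

```python
import math


def perplexity(token_probs, sample_weights=None):
    """exp of the weighted mean negative log-probability of the true tokens."""
    if sample_weights is None:
        sample_weights = [1.0] * len(token_probs)
    weighted_nll = sum(
        -w * math.log(p) for p, w in zip(token_probs, sample_weights)
    )
    return math.exp(weighted_nll / sum(sample_weights))


uniform_over_four = perplexity([0.25, 0.25, 0.25, 0.25])
masked = perplexity([0.5, 0.125], sample_weights=[1.0, 0.0])
```

A model that assigns probability 1/4 to every true token has perplexity 4, matching the branching factor interpretation. A zero weight removes a token from the average, so the second call depends only on the first token.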

aggregate_crossentropy: Array#
num_samples: Array#
classmethod empty() Perplexity#

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None, from_logits: bool = False) Perplexity#

Updates the metric.

Parameters:
  • predictions – A floating point tensor representing the prediction generated from the model. The shape should be (batch_size, seq_len, vocab_size).

  • labels – True value. The shape should be (batch_size, seq_len).

  • sample_weights – An optional tensor representing the weight of each token. The shape should be (batch_size, seq_len).

  • from_logits – Whether the predictions are logits. If True, the predictions are converted to probabilities using a softmax. If False, all values outside of [0, 1] are clipped to 0 or 1.

Returns:

Updated Perplexity metric.

Raises:
  • ValueError – If type of labels is wrong or the shapes of predictions and labels are incompatible.

merge(other: Perplexity) Perplexity#

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

compute() Array#

Computes final metrics from intermediate values.

__init__(aggregate_crossentropy: Array, num_samples: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.Precision(true_positives: Array, false_positives: Array)#

Bases: Metric

Computes precision for binary classification given predictions and labels.

It is calculated as:

\[Precision = \frac{TP}{TP + FP}\]
where:
  • TP (True Positives): Number of correctly predicted positive cases

  • FP (False Positives): Number of incorrectly predicted positive cases

A threshold parameter (default 0.5) is used to convert probability predictions to binary predictions.
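The sketch below illustrates how thresholded counts lead to precision and why the counts, rather than the ratio, are what gets accumulated across batches. It is an illustration only: the helper names are hypothetical, treating a score equal to the threshold as positive is an assumption, and so is returning 0 when nothing is predicted positive.

```python
def precision_counts(predictions, labels, threshold=0.5):
    """Return (true_positives, false_positives) for one batch."""
    tp = fp = 0
    for p, y in zip(predictions, labels):
        if p >= threshold:  # assumption: a score equal to the threshold is positive
            if y == 1:
                tp += 1
            else:
                fp += 1
    return tp, fp


def precision_from_counts(tp, fp):
    # Assumption: report 0.0 when no example was predicted positive.
    return tp / (tp + fp) if (tp + fp) > 0 else 0.0


# Counts from separate batches are summed before the final division.
tp1, fp1 = precision_counts([0.9, 0.2, 0.7], [1, 0, 0])
tp2, fp2 = precision_counts([0.6, 0.4], [1, 1])
overall = precision_from_counts(tp1 + tp2, fp1 + fp2)
```

Summing counts first gives the precision over all examples, which generally differs from averaging per-batch precision values.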

true_positives#

The count of true positive instances from the given data, label, and threshold.

Type:

jax.Array

false_positives#

The count of false positive instances from the given data, label, and threshold.

Type:

jax.Array

true_positives: Array#
false_positives: Array#
classmethod empty() Precision#

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array, threshold: float = 0.5) Precision#

Updates the metric.

Parameters:
  • predictions – A floating point 1D vector whose values are in the range [0, 1]. The shape should be (batch_size,).

  • labels – True value. The value is expected to be 0 or 1. The shape should be (batch_size,).

  • threshold – The threshold to use for the binary classification.

Returns:

Updated Precision metric. The shape should be a single scalar.

Raises:
  • ValueError – If type of labels is wrong or the shapes of predictions and labels are incompatible.

merge(other: Precision) Precision#

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

compute() Array#

Computes final metrics from intermediate values.

__init__(true_positives: Array, false_positives: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.PrecisionAtK(total: Array, count: Array)#

Bases: TopKRankingMetric

Computes P@k (precision at k) metrics.

Precision at k (P@k) is a metric that measures the proportion of relevant items found in the top k recommendations. It answers the question: “Out of the K items recommended, how many are actually relevant?”

Given the top \(K\) recommendations, P@K is calculated as:

\[Precision@K = \frac{\text{Number of relevant items in top K}}{K}\]
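
For a single query, the formula above reduces to a short computation. The function below is a hypothetical helper written for illustration; it assumes the relevance flags are already sorted by predicted rank.

```python
def precision_at_k(ranked_relevance, k):
    """ranked_relevance[i] is 1 if the item at rank i + 1 is relevant, else 0."""
    # The denominator is always k, even if fewer than k items were returned.
    return sum(ranked_relevance[:k]) / k
```

With relevance flags `[1, 0, 1, 0, 0]` and `k = 3`, two of the top three items are relevant, so P@3 is 2/3.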
__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.PSNR(total: Array, count: Array)#

Bases: Average

PSNR (Peak Signal-to-Noise Ratio) Metric.

This class calculates the Peak Signal-to-Noise Ratio (PSNR) between two images to measure the quality of a reconstructed image compared to a reference.

\[\text{PSNR}(I, J) = 10 \cdot \log_{10} \left( \frac{\max(I)^2}{\text{MSE}(I, J)} \right)\]
Where:
  • \(\max(I)\) is the maximum possible pixel value of the input image.

  • \(\text{MSE}(I, J)\) is the mean squared error between images \(I\) and \(J\).
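
A minimal sketch of the PSNR formula for one image pair, using flat lists of pixel values. The function name `psnr` and the choice to return infinity for identical images are assumptions for illustration, not a statement about the library's behavior.

```python
import math


def psnr(image, reference, max_val):
    """PSNR in dB; `image` and `reference` are flat sequences of pixel values."""
    mse = sum((a - b) ** 2 for a, b in zip(image, reference)) / len(image)
    if mse == 0.0:
        # Assumption: identical images are reported as infinite PSNR.
        return math.inf
    return 10.0 * math.log10(max_val ** 2 / mse)
```

For float images in [0, 1] with a mean squared error of 0.01, the ratio is 100 and the PSNR is 20 dB.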

classmethod from_model_output(predictions: Array, targets: Array, max_val: float) PSNR#

Computes PSNR for a batch of images and creates a PSNR metric instance.

Parameters:
  • predictions – A JAX array of predicted images, with shape (batch, H, W, C).

  • targets – A JAX array of ground truth images, with shape (batch, H, W, C).

  • max_val – The maximum possible pixel value (dynamic range) of the images (e.g., 1.0 for float images in [0,1], 255 for uint8 images).

Returns:

A PSNR instance containing per‑image PSNR values.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.RMSE(total: Array, count: Array)#

Bases: MSE

Computes the root mean squared error for regression problems given predictions and labels.

The root mean squared error without sample weights is defined as:

\[RMSE = \sqrt{\frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2}\]

When sample weights \(w_i\) are provided, the weighted root mean squared error is:

\[RMSE = \sqrt{\frac{\sum_{i=1}^{N} w_i(y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} w_i}}\]
where:
  • \(y_i\) are true values

  • \(\hat{y}_i\) are predictions

  • \(w_i\) are sample weights

  • \(N\) is the number of samples
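
Both RMSE formulas above can be covered by one small function, since the unweighted case is the weighted case with all weights equal to 1. The name `rmse` and the list inputs are assumptions for this sketch.

```python
import math


def rmse(labels, predictions, sample_weights=None):
    """Weighted root mean squared error; unit weights when none are given."""
    if sample_weights is None:
        sample_weights = [1.0] * len(labels)
    weighted_sq = sum(
        w * (y - p) ** 2 for y, p, w in zip(labels, predictions, sample_weights)
    )
    return math.sqrt(weighted_sq / sum(sample_weights))
```

Setting a sample's weight to 0 removes it from both the numerator and the denominator, so it no longer affects the result.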

compute() Array#

Computes final metrics from intermediate values.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.RMSLE(total: Array, count: Array)#

Bases: MSLE

Computes the root mean squared logarithmic error for regression problems given predictions and labels.

The root mean squared logarithmic error is defined as:

\[RMSLE = \sqrt{\frac{1}{N} \sum_{i=1}^{N} (\ln(y_i + 1) - \ln(\hat{y}_i + 1))^2 }\]
where:
  • \(y_i\) are true values

  • \(\hat{y}_i\) are predictions

  • \(N\) is the number of samples
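
The RMSLE formula can be sketched with `math.log1p`, which evaluates ln(1 + x) directly. The function name `rmsle` is hypothetical, and inputs are assumed to be greater than -1 so the logarithm is defined.

```python
import math


def rmsle(labels, predictions):
    """Root mean squared logarithmic error over paired values greater than -1."""
    squared = [
        (math.log1p(y) - math.log1p(p)) ** 2 for y, p in zip(labels, predictions)
    ]
    return math.sqrt(sum(squared) / len(squared))
```

Because the error is taken between logarithms, predicting 0 for a true value of e - 1 yields an error of 1 regardless of the absolute scale.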

compute() Array#

Computes final metrics from intermediate values.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.RSQUARED(total: Array, count: Array, sum_of_squared_error: Array, sum_of_squared_label: Array)#

Bases: Metric

Computes the r-squared score of a scalar or a batch of tensors.

R-squared is a measure of how well the regression model fits the data. It measures the proportion of the variance in the dependent variable that is explained by the independent variable(s). It is defined as 1 - SSE / SST, where SSE is the sum of squared errors and SST is the total sum of squares.

\[R^2 = 1 - \frac{SSE}{SST}\]
where:
\[SSE = \sum_{i=1}^{N} (y_i - \hat{y}_i)^2\]
\[SST = \sum_{i=1}^{N} (y_i - \bar{y})^2\]

When sample weights \(w_i\) are provided:

\[R^2 = 1 - \frac{\sum_{i=1}^{N} w_i(y_i - \hat{y}_i)^2}{\sum_{i=1}^{N} w_i(y_i - \bar{y})^2}\]
where:
  • \(y_i\) are true values

  • \(\hat{y}_i\) are predictions

  • \(\bar{y}\) is the mean of true values

  • \(w_i\) are sample weights

  • \(N\) is the number of samples

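The score ranges from -∞ to 1, where 1 indicates perfect prediction, 0 indicates that the model performs no better than always predicting the mean of the labels (a horizontal line), and negative values indicate that it performs worse than that baseline.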

total: Array#
count: Array#
sum_of_squared_error: Array#
sum_of_squared_label: Array#
classmethod empty() RSQUARED#

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None) RSQUARED#

Updates the metric.

Parameters:
  • predictions – A floating point 1D vector representing the prediction generated from the model. The shape should be (batch_size,).

  • labels – True value. The shape should be (batch_size,).

  • sample_weights – An optional floating point 1D vector representing the weight of each sample. The shape should be (batch_size,).

Returns:

Updated RSQUARED metric. The shape should be a single scalar.

Raises:
  • ValueError – If type of labels is wrong or the shapes of predictions and labels are incompatible.

merge(other: RSQUARED) RSQUARED#

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

compute() Array#

Computes the r-squared score.

Since we don’t know the mean of the labels before we aggregate all of the data, we will manipulate the formula to be:

sst = sum_i (x_i - mean)^2
    = sum_i (x_i^2 - 2 x_i mean + mean^2)
    = sum_i x_i^2 - 2 mean sum_i x_i + N * mean^2
    = sum_i x_i^2 - 2 mean * N * mean + N * mean^2
    = sum_i x_i^2 - N * mean^2

Returns:

The r-squared score.
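
The identity sst = sum_i x_i^2 - N * mean^2 is what allows the score to be computed from running sums. The sketch below shows the idea for the unweighted case. The helper names are hypothetical, and reading `total` as the sum of labels is an assumption based on the field names.

```python
def r2_state(labels, predictions):
    """Return (total, count, sum_of_squared_error, sum_of_squared_label)."""
    total = sum(labels)  # assumption: `total` holds the sum of the labels
    count = len(labels)
    sse = sum((y - p) ** 2 for y, p in zip(labels, predictions))
    ssl = sum(y * y for y in labels)
    return total, count, sse, ssl


def merge_states(a, b):
    """Accumulate two states by adding their fields element-wise."""
    return tuple(x + y for x, y in zip(a, b))


def r2_compute(state):
    total, count, sse, ssl = state
    mean = total / count
    sst = ssl - count * mean ** 2  # sum_i x_i^2 - N * mean^2
    return 1.0 - sse / sst
```

Merging the states of two batches and then calling `r2_compute` gives the same score as processing all examples at once, because every field is a plain sum.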

__init__(total: Array, count: Array, sum_of_squared_error: Array, sum_of_squared_label: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.SpearmanRankCorrelation(predictions: Array, labels: Array)#

Bases: Metric

Computes the Spearman rank correlation coefficient.

The Spearman rank correlation coefficient measures the monotonic relationship between two variables. It is defined as the Pearson correlation coefficient between the ranked variables.

\[\rho = 1 - \frac{6 \sum d_i^2}{n(n^2 - 1)}\]
where:
  • \(d_i\) is the difference between the ranks of each observation

  • \(n\) is the number of observations

This implementation accumulates all predictions and labels to compute the exact ranks upon calling compute().

Warning

For very large datasets, this may lead to Out-of-Memory (OOM) errors.
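
The definition above (Pearson correlation of the ranks) can be sketched in plain Python. The closed form with \(d_i\) agrees with it only when there are no tied values. The helper names below are hypothetical, and giving tied values the average of their positions is an assumption, not a statement about how the library handles ties.

```python
import math


def average_ranks(values):
    """1-based ranks; tied values share the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values equal to values[order[i]].
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        shared = (i + j) / 2.0 + 1.0
        for idx in order[i:j + 1]:
            ranks[idx] = shared
        i = j + 1
    return ranks


def pearson(x, y):
    # Undefined (division by zero) if either input is constant.
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    var_x = sum((a - mx) ** 2 for a in x)
    var_y = sum((b - my) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)


def spearman(predictions, labels):
    return pearson(average_ranks(predictions), average_ranks(labels))
```

Ranking needs every value at once, which is why the metric stores all predictions and labels instead of a fixed-size summary.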

predictions#

Accumulated predictions.

Type:

jax.Array

labels#

Accumulated labels.

Type:

jax.Array

predictions: Array#
labels: Array#
classmethod empty() SpearmanRankCorrelation#

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array, **kwargs) SpearmanRankCorrelation#

Creates a Metric from model outputs.

merge(other: SpearmanRankCorrelation) SpearmanRankCorrelation#

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

compute() Array#

Computes final metrics from intermediate values.

__init__(predictions: Array, labels: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.Recall(true_positives: Array, false_negatives: Array)#

Bases: Metric

Computes recall for binary classification given predictions and labels.

It is calculated as:

\[Recall = \frac{TP}{TP + FN}\]
where:
  • TP (True Positives): Number of correctly predicted positive cases

  • FN (False Negatives): Number of incorrectly predicted negative cases

A threshold parameter (default 0.5) is used to convert probability predictions to binary predictions.

true_positives#

The count of true positive instances from the given data, label, and threshold.

Type:

jax.Array

false_negatives#

The count of false negative instances from the given data, label, and threshold.

Type:

jax.Array

true_positives: Array#
false_negatives: Array#
classmethod empty() Recall#

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array, threshold: float = 0.5) Recall#

Updates the metric.

Parameters:
  • predictions – A floating point 1D vector whose values are in the range [0, 1]. The shape should be (batch_size,).

  • labels – True value. The value is expected to be 0 or 1. The shape should be (batch_size,).

  • threshold – The threshold to use for the binary classification.

Returns:

Updated Recall metric. The shape should be a single scalar.

Raises:
  • ValueError – If type of labels is wrong or the shapes of predictions and labels are incompatible.

merge(other: Recall) Recall#

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.

compute() Array#

Computes final metrics from intermediate values.

__init__(true_positives: Array, false_negatives: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.RecallAtK(total: Array, count: Array)#

Bases: TopKRankingMetric

Computes R@k (recall at k) metrics.

Recall at k (R@k) is a metric that measures the proportion of relevant items that are found in the top k recommendations, out of the total number of relevant items for a given user/query. It answers the question: “Out of all the items that are truly relevant, how many did we find in the top K?”

Given the top \(K\) recommendations, R@K is calculated as:

\[Recall@K = \frac{\text{Number of relevant items in top K}}{\text{Total number of relevant items}}\]
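
For a single query, the formula above can be sketched as follows. The function is a hypothetical helper; it assumes the relevance flags cover every candidate item in ranked order, and returning 0 when the query has no relevant items is an assumption.

```python
def recall_at_k(ranked_relevance, k):
    """ranked_relevance[i] is 1 if the item at rank i + 1 is relevant, else 0."""
    total_relevant = sum(ranked_relevance)
    if total_relevant == 0:
        # Assumption: a query with no relevant items scores 0.
        return 0.0
    return sum(ranked_relevance[:k]) / total_relevant
```

Unlike P@K, the denominator here is the number of relevant items, so the score can reach 1 only if all of them appear in the top K.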
__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.RougeL(total_precision: Array, total_recall: Array, total_f1: Array, num_examples: Array)#

Bases: RougeBase

Computes macro-averaged ROUGE-L recall, precision, and F1-score.

ROUGE-L measures the longest common subsequence (LCS) between a prediction and a reference. This metric calculates ROUGE-L precision, recall, and F1-score for each individual prediction compared against its single corresponding reference. These per-instance scores are then averaged.

How ROUGE-L scores are calculated for each individual prediction-reference pair:

For a single prediction P and reference R:

\[LCS(P, R) = \text{length of the Longest Common Subsequence}\]
\[\text{Recall}_{\text{LCS}} = \frac{LCS(P, R)}{|R|}\]
\[\text{Precision}_{\text{LCS}} = \frac{LCS(P, R)}{|P|}\]
\[\text{F1}_{\text{LCS}} = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}}\]

Final Macro-Averaged Metrics:

\[\text{MacroAvgPrecision} = \frac{\text{total_precision}}{\text{num_examples}}\]
\[\text{MacroAvgRecall} = \frac{\text{total_recall}}{\text{num_examples}}\]
\[\text{MacroAvgF1} = \frac{\text{total_f1}}{\text{num_examples}}\]
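
The per-pair scores can be sketched with a standard dynamic-programming LCS. This is an illustration only: the function names are hypothetical, whitespace tokenization is an assumption, and so is scoring empty inputs as 0.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            if x == y:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l(prediction, reference):
    """Return (precision, recall, f1) for one prediction-reference pair."""
    pred_tokens, ref_tokens = prediction.split(), reference.split()
    lcs = lcs_length(pred_tokens, ref_tokens)
    precision = lcs / len(pred_tokens) if pred_tokens else 0.0
    recall = lcs / len(ref_tokens) if ref_tokens else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return precision, recall, f1
```

Summing these three values over all pairs and dividing by the number of pairs gives the macro averages defined above.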
total_precision#

Accumulated sum of LCS precision scores from each instance.

Type:

jax.Array

total_recall#

Accumulated sum of LCS recall scores from each instance.

Type:

jax.Array

total_f1#

Accumulated sum of LCS F1 scores from each instance.

Type:

jax.Array

num_examples#

The number of instances (prediction-reference pairs) processed.

Type:

jax.Array

classmethod empty(**kwargs) RougeL#

Creates an empty Rouge metric. Implemented by subclasses.

__init__(total_precision: Array, total_recall: Array, total_f1: Array, num_examples: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.RougeN(total_precision: Array, total_recall: Array, total_f1: Array, num_examples: Array, order: int)#

Bases: RougeBase

Computes macro-averaged ROUGE-N recall, precision, and F1-score.

This metric first calculates ROUGE-N precision, recall, and F1-score for each individual prediction compared against its single corresponding reference. ROUGE-N scores are based on the number of overlapping n-grams (sequences of n words) between the prediction and the reference text. These per-instance precision, recall, and F1-scores are then averaged across all instances in the dataset/batch.

How ROUGE-N scores are calculated for each individual prediction-reference pair:

\[\text{Precision} = \frac{N_o}{N_p}\]
\[\text{Recall} = \frac{N_o}{N_r}\]
\[\text{F1} = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}}\]
where:
  • \(N_o\) is the number of n-grams that overlap between the prediction and the reference.

  • \(N_p\) is the total number of n-grams in the prediction.

  • \(N_r\) is the total number of n-grams in the reference.

Final Macro-Averaged Metrics:

\[\text{MacroAvgPrecision} = \frac{\text{total_precision}}{\text{num_examples}}\]
\[\text{MacroAvgRecall} = \frac{\text{total_recall}}{\text{num_examples}}\]
\[\text{MacroAvgF1} = \frac{\text{total_f1}}{\text{num_examples}}\]
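
A minimal sketch of the per-pair ROUGE-N scores using `collections.Counter`. The function names are hypothetical, whitespace tokenization is an assumption, and counting overlap with clipped counts (each n-gram matches at most as often as it appears in both texts) is the usual ROUGE convention but is assumed here rather than taken from the source.

```python
from collections import Counter


def ngrams(tokens, n):
    """Multiset of n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def rouge_n(prediction, reference, order=2):
    """Return (precision, recall, f1) for one prediction-reference pair."""
    pred = ngrams(prediction.split(), order)
    ref = ngrams(reference.split(), order)
    overlap = sum((pred & ref).values())  # assumption: clipped overlap counts
    n_pred, n_ref = sum(pred.values()), sum(ref.values())
    precision = overlap / n_pred if n_pred else 0.0
    recall = overlap / n_ref if n_ref else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return precision, recall, f1
```

For "the cat sat on the mat" against "the cat lay on the mat" with order 2, three of the five bigrams overlap on each side, so precision, recall, and F1 are all 0.6.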
order#

The specific ‘N’ in ROUGE-N (e.g., 1 for ROUGE-1, 2 for ROUGE-2).

Type:

int

total_precision#

Accumulated sum of precision scores from each instance.

Type:

jax.Array

total_recall#

Accumulated sum of recall scores from each instance.

Type:

jax.Array

total_f1#

Accumulated sum of f1 scores from each instance.

Type:

jax.Array

num_examples#

The number of instances (prediction-reference pairs) processed.

Type:

jax.Array

order: int#
classmethod empty(order: int = 2) RougeN#

Creates an empty Rouge metric. Implemented by subclasses.

__init__(total_precision: Array, total_recall: Array, total_f1: Array, num_examples: Array, order: int) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.SNR(total: Array, count: Array)#

Bases: Average

SNR (Signal-to-Noise Ratio) Metric for audio.

This class calculates the Signal-to-Noise Ratio (SNR) in decibels (dB) between a predicted audio signal and a ground truth audio signal, and averages it over a dataset.

The SNR is defined as:

\[SNR_{dB} = 10 \cdot \log_{10} \left( \frac{P_{signal}}{P_{noise}} \right)\]
Where:
  • \(P_{signal}\) is the power of the ground truth signal (target). By default (zero_mean=False), this is the mean of the squared target values. If zero_mean=True, it’s the variance of the target values.

  • \(P_{noise}\) is the power of the noise component, which is defined as the difference between the target and preds (target - preds). By default (zero_mean=False), this is the mean of the squared noise values. If zero_mean=True, it’s the variance of the noise values.
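
The two power definitions can be sketched for a single signal as follows. The function name `snr_db` and the list inputs are assumptions for illustration, and the sketch does not guard against zero noise power.

```python
import math


def snr_db(predictions, targets, zero_mean=False):
    """SNR in dB for one signal; noise is defined as target - prediction."""
    signal = list(targets)
    noise = [t - p for t, p in zip(targets, predictions)]
    if zero_mean:
        # Subtracting the mean turns the mean square into the variance.
        s_mean = sum(signal) / len(signal)
        n_mean = sum(noise) / len(noise)
        signal = [s - s_mean for s in signal]
        noise = [e - n_mean for e in noise]
    p_signal = sum(s * s for s in signal) / len(signal)
    p_noise = sum(e * e for e in noise) / len(noise)
    return 10.0 * math.log10(p_signal / p_noise)
```

A prediction that is 90% of the target amplitude leaves noise with 1% of the signal power, which corresponds to 20 dB.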

classmethod from_model_output(predictions: Array, targets: Array, zero_mean: bool = False) SNR#

Computes SNR for a batch of audio signals and creates an SNR metric instance.

Parameters:
  • predictions – A JAX array of predicted audio signals.

  • targets – A JAX array of ground truth audio signals.

  • zero_mean – If True, subtracts the mean from the signal and noise before calculating their respective powers.

Returns:

An SNR instance containing the SNR value for the current batch, ready for averaging.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.SSIM(total: Array, count: Array)#

Bases: Average

SSIM (Structural Similarity Index Measure) Metric.

This class calculates the structural similarity between predicted and target images and averages it over a dataset. SSIM is a perception-based model that considers changes in structural information, luminance, and contrast.

The general SSIM formula considers three components: luminance (l), contrast (c), and structure (s):

\[SSIM(x, y) = [l(x, y)]^\alpha \cdot [c(x, y)]^\beta \cdot [s(x, y)]^\gamma\]
Where:
  • Luminance comparison: \(l(x, y) = \frac{2\mu_x\mu_y + c_1}{\mu_x^2 + \mu_y^2 + c_1}\)

  • Contrast comparison: \(c(x, y) = \frac{2\sigma_x\sigma_y + c_2}{\sigma_x^2 + \sigma_y^2 + c_2}\)

  • Structure comparison: \(s(x, y) = \frac{\sigma_{xy} + c_3}{\sigma_x\sigma_y + c_3}\)

This implementation uses a common simplified form where \(\alpha = \beta = \gamma = 1\) and \(c_3 = c_2 / 2\).

This leads to the combined formula:

\[SSIM(x, y) = \frac{(2\mu_x\mu_y + c_1)(2\sigma_{xy} + c_2)}{(\mu_x^2 + \mu_y^2 + c_1)(\sigma_x^2 + \sigma_y^2 + c_2)}\]
In these formulas:
  • \(\mu_x\) and \(\mu_y\) are the local means of \(x\) and \(y\).

  • \(\sigma_x^2\) and \(\sigma_y^2\) are the local variances of \(x\) and \(y\).

  • \(\sigma_{xy}\) is the local covariance of \(x\) and \(y\).

  • \(c_1 = (K_1 L)^2\) and \(c_2 = (K_2 L)^2\) are stabilization constants, where \(L\) is the dynamic range of pixel values, and \(K_1, K_2\) are small constants (e.g., 0.01 and 0.03).
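
To make the combined formula concrete, the sketch below evaluates it once over a whole image, treating the image as a single window. This is a deliberate simplification: the formulas above use local statistics, and the hypothetical `global_ssim` function ignores windowing entirely.

```python
def global_ssim(x, y, max_val, k1=0.01, k2=0.03):
    """Combined SSIM formula over flat pixel lists, using global statistics."""
    n = len(x)
    mu_x, mu_y = sum(x) / n, sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    c1, c2 = (k1 * max_val) ** 2, (k2 * max_val) ** 2
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return numerator / denominator
```

Identical inputs give a score of 1, while an intensity-reversed copy has negative covariance and therefore a negative score.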

classmethod from_model_output(predictions: Array, targets: Array, max_val: float, filter_size: int = 11, filter_sigma: float = 1.5, k1: float = 0.01, k2: float = 0.03) SSIM#

Computes SSIM for a batch of images and creates an SSIM metric instance.

This method takes batches of predicted and target images, calculates their SSIM values, and then initializes an SSIM metric object suitable for aggregation across multiple batches.

Parameters:
  • predictions – A JAX array of predicted images, with shape (batch, height, width, channels).

  • targets – A JAX array of ground truth images, with shape (batch, height, width, channels).

  • max_val – The maximum possible pixel value (dynamic range) of the images (e.g., 1.0 for float images in [0,1], 255 for uint8 images).

  • filter_size – The size of the Gaussian filter window used in SSIM calculation (default is 11).

  • filter_sigma – The standard deviation of the Gaussian filter (default is 1.5).

  • k1 – SSIM stability constant for the luminance term (default is 0.01).

  • k2 – SSIM stability constant for the contrast/structure term (default is 0.03).

Returns:

An SSIM instance containing the SSIM values for the current batch, ready for averaging.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

class metrax.WER(total: Array, count: Array)#

Bases: Average

Computes Word Error Rate (WER) for speech recognition or text generation tasks.

Word Error Rate measures the edit distance between reference texts and predictions, normalized by the length of the reference texts. It is calculated as:

\[WER = \frac{S + D + I}{N}\]
where:
  • S is the number of substitutions

  • D is the number of deletions

  • I is the number of insertions

  • N is the number of words in the reference

A lower WER indicates better performance, with 0 being perfect.

This implementation accepts both pre-tokenized inputs (lists of tokens) and untokenized strings. When strings are provided, they are tokenized by splitting on whitespace.
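
The numerator S + D + I is the word-level edit (Levenshtein) distance, which can be sketched with dynamic programming. The function names below are hypothetical, and the sketch covers only a single prediction-reference pair of whitespace-tokenized strings.

```python
def edit_distance(ref_tokens, hyp_tokens):
    """Minimum number of substitutions, deletions, and insertions."""
    prev = list(range(len(hyp_tokens) + 1))
    for i, r in enumerate(ref_tokens, start=1):
        curr = [i]
        for j, h in enumerate(hyp_tokens, start=1):
            cost = 0 if r == h else 1
            curr.append(min(
                prev[j] + 1,         # deletion of a reference word
                curr[j - 1] + 1,     # insertion of a predicted word
                prev[j - 1] + cost,  # match or substitution
            ))
        prev = curr
    return prev[-1]


def wer(prediction, reference):
    """Word error rate for one pair of whitespace-tokenized strings."""
    ref_tokens = reference.split()
    if not ref_tokens:
        raise ValueError("reference must contain at least one word")
    return edit_distance(ref_tokens, prediction.split()) / len(ref_tokens)
```

Because insertions count as errors, WER can exceed 1 when the prediction is much longer than the reference.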

classmethod from_model_output(predictions: list[str], references: list[str]) WER#

Updates the metric.

Parameters:
  • predictions – Either a string or a list of tokens in the predicted sequence.

  • references – Either a string or a list of tokens in the reference sequence.

Returns:

New WER metric instance.

Raises:

ValueError – If inputs are not properly formatted or are empty.

__init__(total: Array, count: Array) None#
replace(**updates)#

Returns a new object replacing the specified fields with new values.

AUCPR(true_positives, false_positives, ...)

Computes area under the precision-recall curve for binary classification given predictions and labels.

AUCROC(true_positives, true_negatives, ...)

Computes Area Under the ROC curve (AUC-ROC).

Accuracy(total, count)

Computes accuracy, which is the frequency with which predictions match labels.

Average(total, count)

Average Metric inherits clu.metrics.Average and performs safe division.

AveragePrecisionAtK(total, count)

Computes AP@k (average precision at k) metrics.

BLEU(max_order, matches_by_order, ...)

Computes the BLEU score for sequence generation.

DCGAtK(total, count)

Computes Discounted Cumulative Gain at k metric.

Dice(intersection, sum_pred, sum_true)

Computes the Dice coefficient between y_true and y_pred.

FBetaScore(true_positives, false_positives, ...)

F-Beta score Metric class.

IoU(total, count)

Measures Intersection over Union (IoU) for semantic segmentation.

MAE(total, count)

Computes the mean absolute error for regression problems given predictions and labels.

MRR(total, count)

Computes Mean Reciprocal Rank (MRR), supporting MRR@k for multiple k values.

MSE(total, count)

Computes the mean squared error for regression problems given predictions and labels.

NDCGAtK(total, count)

Computes Normalized Discounted Cumulative Gain at k (NDCG@k) metrics.

Perplexity(aggregate_crossentropy, num_samples)

Computes perplexity for sequence generation.

Precision(true_positives, false_positives)

Computes precision for binary classification given predictions and labels.

PrecisionAtK(total, count)

Computes P@k (precision at k) metrics.

PSNR(total, count)

PSNR (Peak Signal-to-Noise Ratio) Metric.

RMSE(total, count)

Computes the root mean squared error for regression problems given predictions and labels.

RSQUARED(total, count, sum_of_squared_error, ...)

Computes the r-squared score of a scalar or a batch of tensors.

Recall(true_positives, false_negatives)

Computes recall for binary classification given predictions and labels.

RecallAtK(total, count)

Computes R@k (recall at k) metrics.

RougeL(total_precision, total_recall, ...)

Computes macro-averaged ROUGE-L recall, precision, and F1-score.

RougeN(total_precision, total_recall, ...)

Computes macro-averaged ROUGE-N recall, precision, and F1-score.

SNR(total, count)

SNR (Signal-to-Noise Ratio) Metric for audio.

SSIM(total, count)

SSIM (Structural Similarity Index Measure) Metric.

WER(total, count)

Computes Word Error Rate (WER) for speech recognition or text generation tasks.