metrax.Accuracy

class metrax.Accuracy(total: Array, count: Array)

Bases: Average

Computes accuracy, which is the frequency with which predictions match labels.

This metric calculates the proportion of correct predictions by comparing predictions and labels element-wise. It is the ratio of the sum of weighted correct predictions to the sum of all corresponding weights. If no sample_weights are provided, weights default to 1 for each element.

The calculation is as follows:

\[\text{Accuracy} = \frac{\sum (\text{weight} \times \text{correct})}{\sum \text{weight}}\]

where correct is 1 if prediction == label for an element and 0 otherwise, and weight is that element's value in sample_weights (or 1 if no weights are given).
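As a small worked example with made-up values, take four elements with predictions (1, 0, 2, 2), labels (1, 1, 2, 0), and sample_weights (1, 2, 1, 1). The correct indicators are (1, 0, 1, 0), so:

\[\text{Accuracy} = \frac{1 \cdot 1 + 2 \cdot 0 + 1 \cdot 1 + 1 \cdot 0}{1 + 2 + 1 + 1} = \frac{2}{5} = 0.4\]

Without sample_weights every weight is 1, and the same predictions give an accuracy of 2/4 = 0.5.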

__init__(total: Array, count: Array) -> None

Methods

__init__(total, count)

compute()

Computes final metrics from intermediate values.

compute_value()

Wraps compute() and returns a values.Value.

empty()

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

from_fun(fun)

Calls cls.from_model_output with the return value from fun.

from_model_output(predictions, labels[, ...])

Updates the metric state with new predictions and labels.

from_output(name)

Calls cls.from_model_output with model output named name.

merge(other)

Returns Metric that is the accumulation of self and other.

reduce()

Reduces the metric along its first axis by calling _reduce_merge().

replace(**updates)

Returns a new object replacing the specified fields with new values.

Attributes

total

count
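The total and count attributes hold running sums, which is what lets partial results be combined exactly. The sketch below is a hypothetical, simplified stand-in written with NumPy and dataclasses: AccuracyState is not a metrax class, and the sketch assumes the Average base follows the usual pattern in which merge() adds total and count, reduce() sums them along the first axis, and compute() divides total by count.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class AccuracyState:
    """Hypothetical stand-in for an Average-style metric state (not metrax code)."""

    total: np.ndarray  # sum of weighted correct predictions
    count: np.ndarray  # sum of weights

    @classmethod
    def empty(cls) -> "AccuracyState":
        # Zero sums: merging with an empty state leaves the other state unchanged.
        return cls(total=np.array(0.0), count=np.array(0.0))

    def merge(self, other: "AccuracyState") -> "AccuracyState":
        # Accumulate running sums; no averaging happens here.
        return replace(self, total=self.total + other.total,
                       count=self.count + other.count)

    def reduce(self) -> "AccuracyState":
        # Collapse a stacked state (e.g. one row per device) along axis 0.
        return replace(self, total=self.total.sum(axis=0),
                       count=self.count.sum(axis=0))

    def compute(self) -> np.ndarray:
        # Divide only once, at the end.
        return self.total / self.count


batch_a = AccuracyState(total=np.array(2.0), count=np.array(5.0))  # 2 of 5 correct
batch_b = AccuracyState(total=np.array(3.0), count=np.array(3.0))  # 3 of 3 correct
merged = AccuracyState.empty().merge(batch_a).merge(batch_b)
print(merged.compute())  # 0.625

# The same two batches stacked along a leading axis, then reduced.
stacked = AccuracyState(total=np.array([2.0, 3.0]), count=np.array([5.0, 3.0]))
print(stacked.reduce().compute())  # 0.625
```

Keeping sums instead of a running mean is what makes merging exact: averaging the two batch accuracies (0.4 and 1.0) would give 0.7 rather than the correct 5/8 = 0.625.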

classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None) -> Accuracy

Updates the metric state with new predictions and labels.

This method computes element-wise equality between predictions and labels. The result of this comparison (a boolean array, treated as 1 for True and 0 for False) is then used to update the metric’s total and count.
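Following the formula above, the two quantities this method stores can be sketched as plain sums. The helper below is hypothetical (accuracy_state is not a metrax function) and assumes total holds the weighted count of correct elements while count holds the summed weights. It uses NumPy, whose asarray, ones_like, broadcast_to, and sum calls have jax.numpy counterparts with the same semantics.

```python
import numpy as np


def accuracy_state(predictions, labels, sample_weights=None):
    """Hypothetical sketch: return (total, count) for one batch."""
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    # Element-wise comparison; True/False become 1.0/0.0 after the cast.
    correct = (predictions == labels).astype(float)
    if sample_weights is None:
        # Default weight of 1 for every element.
        weights = np.ones_like(correct)
    else:
        # Weights must broadcast to the shape of the comparison result.
        weights = np.broadcast_to(np.asarray(sample_weights, dtype=float), correct.shape)
    total = np.sum(correct * weights)
    count = np.sum(weights)
    return total, count


total, count = accuracy_state([1, 0, 2, 2], [1, 1, 2, 0], sample_weights=[1, 2, 1, 1])
print(total, count, total / count)  # 2.0 5.0 0.4
```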

Parameters:
  • predictions – JAX array of predicted values. Expected to have a shape compatible with labels for element-wise comparison (e.g., (batch_size,), (batch_size, num_classes), (batch_size, sequence_length, num_features)).

  • labels – JAX array of true values. Expected to have a shape compatible with predictions for element-wise comparison.

  • sample_weights – Optional JAX array of weights. If provided, it must be broadcastable to the shape of labels (which should also be compatible with predictions’ shape).
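Because the comparison is element-wise, predictions must already be in the same form as labels. The NumPy sketch below uses made-up data to illustrate two common setups (NumPy stands in for jax.numpy; argmax, broadcast_to, and == follow the same semantics in both). First, classifier logits of shape (batch_size, num_classes) are reduced to class indices of shape (batch_size,) before being compared with integer labels. Second, per-sequence weights of shape (batch_size, 1) broadcast across the sequence axis of labels shaped (batch_size, sequence_length).

```python
import numpy as np

# Made-up logits for 3 examples and 4 classes.
logits = np.array([
    [0.1, 2.0, 0.3, 0.0],
    [1.5, 0.2, 0.1, 0.0],
    [0.0, 0.1, 0.2, 3.0],
])
labels = np.array([1, 0, 2])

# Reduce the class axis so predictions have the same shape as labels: (3,).
predictions = np.argmax(logits, axis=-1)
print(predictions)  # [1 0 3]
correct = predictions == labels  # 2 of 3 elements match

# Sequence case: labels of shape (batch_size, sequence_length) = (2, 3) and
# one weight per sequence with shape (2, 1), which broadcasts over positions.
seq_predictions = np.array([[1, 2, 3], [4, 5, 6]])
seq_labels = np.array([[1, 2, 0], [4, 0, 0]])
seq_weights = np.array([[1.0], [0.0]])  # weight 0 masks out the second sequence
w = np.broadcast_to(seq_weights, seq_labels.shape)
seq_correct = (seq_predictions == seq_labels) * w
seq_accuracy = seq_correct.sum() / w.sum()  # only the first sequence counts: 2 / 3
```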

Returns:

An updated instance of the Accuracy metric.

Raises:

ValueError – If JAX operations (such as broadcasting or arithmetic) fail because predictions, labels, and sample_weights have incompatible shapes or types. For instance, this happens if the shapes of predictions and labels are neither identical nor broadcastable to a common shape for comparison, or if sample_weights cannot be broadcast to the shape of labels.
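When such an error is hard to trace inside a compiled training step, checking shapes up front can help. The helper below is a hypothetical sketch (check_accuracy_shapes is not part of metrax). It relies on NumPy's np.broadcast_shapes, which raises ValueError when shapes cannot be broadcast together, and it only mirrors the conditions listed above rather than reproducing the library's own checks or messages.

```python
import numpy as np


def check_accuracy_shapes(predictions_shape, labels_shape, weights_shape=None):
    """Hypothetical pre-check mirroring the documented ValueError conditions."""
    labels_shape = tuple(labels_shape)
    # Raises ValueError if predictions and labels cannot share a common shape.
    common = np.broadcast_shapes(tuple(predictions_shape), labels_shape)
    if weights_shape is not None:
        # Weights must broadcast *to* the labels' shape without enlarging it;
        # np.broadcast_shapes itself raises ValueError if they are incompatible.
        if np.broadcast_shapes(tuple(weights_shape), labels_shape) != labels_shape:
            raise ValueError(
                f"sample_weights shape {tuple(weights_shape)} cannot be "
                f"broadcast to labels shape {labels_shape}"
            )
    return common


print(check_accuracy_shapes((8,), (8,), (8,)))  # (8,)
for bad in [((8,), (7,)), ((8, 4), (8, 4), (3,))]:
    try:
        check_accuracy_shapes(*bad)
    except ValueError:
        print("rejected:", bad)
```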

replace(**updates)

Returns a new object replacing the specified fields with new values.