metrax.Accuracy
- class metrax.Accuracy(total: Array, count: Array)
Bases: Average
Computes accuracy, which is the frequency with which predictions match labels.
This metric calculates the proportion of correct predictions by comparing predictions and labels element-wise. It is the ratio of the sum of weighted correct predictions to the sum of all corresponding weights. If no sample_weights are provided, weights default to 1 for each element.
The calculation is as follows:
\[\text{Accuracy} = \frac{\sum (\text{weight} \times \text{correct})}{\sum \text{weight}}\]
where correct is 1 if prediction == label for an element and 0 otherwise, and weight is that element's entry in sample_weights, or 1 if no weights are given.
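To make the formula concrete, here is a small worked example written directly with jax.numpy. It is a sketch of the arithmetic only, not the metrax implementation. The arrays and weights are made-up illustrative values.

```python
import jax.numpy as jnp

# Illustrative batch of 4 elements with per-element weights.
predictions = jnp.array([1, 2, 0, 2])
labels = jnp.array([1, 0, 0, 2])
weights = jnp.array([1.0, 2.0, 0.5, 1.5])

# correct is 1.0 where prediction == label and 0.0 otherwise: [1, 0, 1, 1].
correct = (predictions == labels).astype(jnp.float32)

# Weighted accuracy: sum(weight * correct) / sum(weight)
# = (1.0 + 0.0 + 0.5 + 1.5) / (1.0 + 2.0 + 0.5 + 1.5) = 3.0 / 5.0
weighted_acc = jnp.sum(weights * correct) / jnp.sum(weights)

# With no weights, every element has weight 1: 3 correct out of 4.
unweighted_acc = jnp.sum(correct) / correct.size
```

Setting every weight to 1 reduces the formula to the plain fraction of matching elements, which is why the unweighted result is 3/4 while the weighted one is 3/5.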
Methods
__init__(total, count)
compute() – Computes final metrics from intermediate values.
compute_value() – Wraps compute() and returns a values.Value.
empty() – Returns an empty instance (i.e., .merge(Metric.empty()) is a no-op).
from_fun(fun) – Calls cls.from_model_output with the return value from fun.
from_model_output(predictions, labels[, ...]) – Updates the metric state with new predictions and labels.
from_output(name) – Calls cls.from_model_output with the model output named name.
merge(other) – Returns a Metric that is the accumulation of self and other.
reduce() – Reduces the metric along its first axis by calling _reduce_merge().
replace(**updates) – Returns a new object replacing the specified fields with new values.
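The methods above follow a pattern in which per-batch state is combined with merge() and turned into a number with compute(). The following is a hypothetical, self-contained stand-in (AccuracySketch is not part of metrax). It assumes the state is exactly the numerator and denominator of the formula above, stored in total and count, and that merging adds them. Under that assumption, accumulating over batches gives the same result as scoring all the data at once, and reduce() on a stacked state amounts to summing along the first axis.

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class AccuracySketch:
    """Hypothetical stand-in for an Average-style accuracy state (not metrax code)."""

    total: jnp.ndarray  # assumed: running sum of weight * correct
    count: jnp.ndarray  # assumed: running sum of weights

    @classmethod
    def empty(cls) -> "AccuracySketch":
        # Zero state, so empty().merge(m) leaves m's totals unchanged.
        return cls(total=jnp.array(0.0), count=jnp.array(0.0))

    @classmethod
    def from_model_output(cls, predictions, labels, sample_weights=None) -> "AccuracySketch":
        # Element-wise comparison; True/False become 1.0/0.0.
        correct = (predictions == labels).astype(jnp.float32)
        if sample_weights is None:
            weights = jnp.ones_like(correct)  # default weight 1 per element
        else:
            weights = jnp.broadcast_to(sample_weights, correct.shape).astype(jnp.float32)
        return cls(total=jnp.sum(weights * correct), count=jnp.sum(weights))

    def merge(self, other: "AccuracySketch") -> "AccuracySketch":
        # Combine two states by adding numerators and denominators.
        return AccuracySketch(total=self.total + other.total, count=self.count + other.count)

    def reduce(self) -> "AccuracySketch":
        # Collapse a stacked state (leading axis, e.g. one entry per device)
        # by summing along axis 0, which equals merging the entries one by one.
        return AccuracySketch(total=jnp.sum(self.total, axis=0), count=jnp.sum(self.count, axis=0))

    def compute(self) -> jnp.ndarray:
        # Final metric: weighted correct / total weight.
        return self.total / self.count


# Stream two batches of different sizes through merge().
batches = [
    (jnp.array([1, 0, 1]), jnp.array([1, 1, 1])),  # 2 of 3 correct
    (jnp.array([0, 0]), jnp.array([0, 1])),        # 1 of 2 correct
]
metric = AccuracySketch.empty()
for preds, labs in batches:
    metric = metric.merge(AccuracySketch.from_model_output(preds, labs))
acc = metric.compute()  # 3 correct / 5 elements

# A stacked state holding the same two batch states along axis 0.
stacked = AccuracySketch(total=jnp.array([2.0, 1.0]), count=jnp.array([3.0, 2.0]))
stacked_acc = stacked.reduce().compute()
```

Because only sums are stored, batches of different sizes combine correctly: the merged result is 3/5 = 0.6, whereas averaging the per-batch accuracies (2/3 and 1/2) would give about 0.583.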
Attributes
total
count
- classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None) → Accuracy
Updates the metric state with new predictions and labels.
This method computes element-wise equality between predictions and labels. The result of this comparison (a boolean array, treated as 1 for True and 0 for False) is then used to update the metric’s total and count.
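Because the comparison is element-wise, inputs with more than one dimension are scored per element rather than per example. The plain jax.numpy sketch below (illustrative values, not the metrax source) shows how the boolean match array turns into total and count with unit weights. It also shows how this differs from requiring a whole row to match when predictions and labels are one-hot arrays of shape (batch_size, num_classes).

```python
import jax.numpy as jnp

# One-hot predictions and labels, shape (batch_size=2, num_classes=3).
predictions = jnp.array([[1, 0, 0],   # predicts class 0
                         [0, 1, 0]])  # predicts class 1
labels = jnp.array([[1, 0, 0],        # true class 0 (correct)
                    [0, 0, 1]])       # true class 2 (wrong)

# Element-wise equality: a boolean array with the same shape as labels,
# here [[True, True, True], [True, False, False]].
matches = predictions == labels

# Booleans count as 1 (True) and 0 (False) when summed.
total = jnp.sum(matches)   # 4 matching elements
count = matches.size       # 6 elements, each with unit weight
element_acc = total / count  # 4 / 6

# Per-example accuracy would instead require the whole row to match: 1 / 2.
row_acc = jnp.mean(jnp.all(matches, axis=-1))
```

If per-example accuracy is what you need, one option is to convert one-hot arrays or scores to class indices first (for example with jnp.argmax(..., axis=-1)) and compare those.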
- Parameters:
predictions – JAX array of predicted values. Expected to have a shape compatible with labels for element-wise comparison (e.g., (batch_size,), (batch_size, num_classes), (batch_size, sequence_length, num_features)).
labels – JAX array of true values. Expected to have a shape compatible with predictions for element-wise comparison.
sample_weights – Optional JAX array of weights. If provided, it must be broadcastable to the shape of labels (which should also be compatible with predictions’ shape).
- Returns:
An updated instance of the Accuracy metric.
- Raises:
ValueError – If JAX operations (such as broadcasting or arithmetic) fail because predictions, labels, and sample_weights have incompatible shapes or types. For instance, this occurs if the shapes of predictions and labels are neither identical nor broadcastable to a common shape, or if sample_weights cannot be broadcast to the shape of labels.
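A common use of sample_weights is masking out padded positions in sequence data. The sketch below reproduces the documented broadcasting rule in plain jax.numpy: a hypothetical padding mask of shape (batch_size, sequence_length) is combined with hypothetical per-example weights of shape (batch_size, 1), and the product broadcasts to the shape of labels. The values are illustrative, and this is not the metrax implementation.

```python
import jax.numpy as jnp

# Sequence predictions and labels, shape (batch_size=2, sequence_length=4).
predictions = jnp.array([[3, 1, 2, 0],
                         [5, 5, 0, 0]])
labels = jnp.array([[3, 1, 1, 9],
                    [5, 4, 7, 0]])

# Hypothetical padding mask: 1.0 for real tokens, 0.0 for padding.
mask = jnp.array([[1.0, 1.0, 1.0, 0.0],
                  [1.0, 1.0, 0.0, 0.0]])

# Hypothetical per-example weights, shape (2, 1); broadcasting against the
# (2, 4) mask yields one weight per element, matching labels.shape.
example_weights = jnp.array([[1.0], [2.0]])
weights = mask * example_weights

# Weighted accuracy as in the formula: sum(weight * correct) / sum(weight).
correct = (predictions == labels).astype(jnp.float32)
acc = jnp.sum(weights * correct) / jnp.sum(weights)  # (2 + 2) / (3 + 4)
```

Under the documented signature, the same weights array would be passed as sample_weights to Accuracy.from_model_output(predictions, labels, sample_weights=weights). The padded position at the end of the second row, where prediction and label both happen to be 0, contributes nothing because its weight is 0.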
- replace(**updates)
Returns a new object replacing the specified fields with new values.