metrax.AveragePrecisionAtK
- class metrax.AveragePrecisionAtK(total: Array, count: Array)
Bases: Average
Computes AP@k (average precision at k) metrics.
Average precision at k (AP@k) is a metric used to evaluate the performance of ranking models. Within the top K results, it sums the precision at every rank k at which a relevant item appears, then divides by the total number of relevant items.
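To make the definition concrete, the sketch below computes AP@k for a single ranked list in plain Python. It illustrates the formula only and is not metrax's implementation; the function name `average_precision_at_k`, the 0/1 relevance-list input, and returning 0.0 when there are no relevant items are assumptions. Here r is the number of relevant items in the whole list, not only those in the top k.

```python
def average_precision_at_k(ranked_relevance: list[int], k: int) -> float:
    """Hypothetical AP@k for one ranked list (illustration only).

    ranked_relevance[i] is 1 if the item at rank i + 1 is relevant, else 0.
    Assumes every relevant item appears somewhere in the list, so r is
    simply sum(ranked_relevance).
    """
    r = sum(ranked_relevance)
    if r == 0:
        return 0.0  # assumption: no relevant items gives AP@k = 0
    hits = 0
    total = 0.0
    for rank, rel in enumerate(ranked_relevance[:k], start=1):
        if rel:
            hits += 1
            total += hits / rank  # Precision@rank, added only where rel(rank) = 1
    return total / r


# Relevant items sit at ranks 1, 3 and 4, so r = 3.
ranking = [1, 0, 1, 1, 0]
print(average_precision_at_k(ranking, 3))  # equals (1/1 + 2/3) / 3
print(average_precision_at_k(ranking, 5))  # equals (1/1 + 2/3 + 3/4) / 3
```

The non-relevant item at rank 2 contributes nothing directly, but it lowers the precision credited to every relevant item ranked after it.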
Given the top \(K\) recommendations, AP@K is calculated as:
\[\begin{split}AP@K = \frac{1}{r} \sum_{k=1}^{K} Precision@k \cdot rel(k) \\ rel(k) = \begin{cases} 1 & \text{if the item at rank } k \text{ is relevant} \\ 0 & \text{otherwise} \end{cases}\end{split}\]
where \(r\) is the total number of relevant items.
Methods
__init__(total, count)
average_precision_at_ks(predictions, labels, ks) – Computes AP@k (average precision at k) metrics for each k in ks.
compute() – Computes final metrics from intermediate values.
compute_value() – Wraps compute() and returns a values.Value.
empty() – Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).
from_fun(fun) – Calls cls.from_model_output with the return value from fun.
from_model_output(predictions, labels, ks) – Creates the metric from a batch of model outputs.
from_output(name) – Calls cls.from_model_output with the model output named name.
merge(other) – Returns a Metric that is the accumulation of self and other.
reduce() – Reduces the metric along its first axis by calling _reduce_merge().
replace(**updates) – Returns a new object replacing the specified fields with new values.
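Because AveragePrecisionAtK derives from Average and carries total and count fields, results from different batches are combined by adding sums and counts rather than by averaging averages. The hypothetical `AverageSketch` class below illustrates that pattern using the empty, merge and compute roles listed above; it is not metrax's Average class, and keeping one running total per cutoff alongside a scalar example count is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AverageSketch:
    """Hypothetical stand-in for an Average-style metric (not metrax's class)."""

    total: tuple[float, ...]  # running sum of AP@k per cutoff k (assumption)
    count: float  # number of examples seen so far (assumption)

    @classmethod
    def empty(cls, num_ks: int) -> "AverageSketch":
        # Identity element: merging with it leaves the other metric unchanged.
        return cls(total=(0.0,) * num_ks, count=0.0)

    @classmethod
    def from_per_example(cls, ap_rows: list[list[float]]) -> "AverageSketch":
        # ap_rows has shape (batch_size, |ks|): one row of AP@k values per example.
        num_ks = len(ap_rows[0])
        total = tuple(sum(row[j] for row in ap_rows) for j in range(num_ks))
        return cls(total=total, count=float(len(ap_rows)))

    def merge(self, other: "AverageSketch") -> "AverageSketch":
        # Accumulate sums and counts separately; never average the averages.
        return AverageSketch(
            total=tuple(a + b for a, b in zip(self.total, other.total)),
            count=self.count + other.count,
        )

    def compute(self) -> tuple[float, ...]:
        # Mean AP@k per cutoff, i.e. a result of shape (|ks|,).
        return tuple(t / self.count for t in self.total)


batch1 = AverageSketch.from_per_example([[1.0, 0.5], [0.0, 0.25]])
batch2 = AverageSketch.from_per_example([[0.5, 0.75]])
merged = AverageSketch.empty(2).merge(batch1).merge(batch2)
print(merged.compute())
```

Averaging the two per-batch means for the second cutoff would give 0.5625 instead of 0.5, because the batches have different sizes; keeping the sum and the count apart until compute() avoids that bias.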
Attributes
total
count
- classmethod average_precision_at_ks(predictions: Array, labels: Array, ks: Array)
Computes AP@k (average precision at k) metrics for each k in ks.
- Parameters:
predictions – A 2D floating-point array of prediction scores generated by the model, with shape (batch_size, vocab_size).
labels – A multi-hot encoding of the true labels, with shape (batch_size, vocab_size).
ks – A 1D array of integers giving the cutoffs k at which to compute AP@k, with shape (|ks|,).
- Returns:
Rank-2 array of shape (batch_size, |ks|) containing the AP@k value of each example at each cutoff in ks.
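A batched version of the formula can be sketched with jax.numpy: sort each row by score, look up rel(k) at every rank, take prefix sums of Precision@k · rel(k), and gather the requested cutoffs. `ap_at_ks_sketch` is a hypothetical name and this is not metrax's source; treating any positive label as relevant, returning 0 for rows with no relevant items, and requiring 1 ≤ k ≤ vocab_size are assumptions. Ties in predictions are broken by jnp.argsort's ordering, which may differ from what the library does.

```python
import jax
import jax.numpy as jnp


def ap_at_ks_sketch(predictions: jax.Array, labels: jax.Array, ks: jax.Array) -> jax.Array:
    """Hypothetical batched AP@k; returns an array of shape (batch_size, |ks|)."""
    # Treat any positive multi-hot entry as relevant.
    labels = (labels > 0).astype(jnp.float32)
    # Item indices per row, highest score first.
    order = jnp.argsort(-predictions, axis=1)
    # rel(k) for ranks 1..vocab_size.
    rel = jnp.take_along_axis(labels, order, axis=1)
    ranks = jnp.arange(1, predictions.shape[1] + 1)
    # Precision@k at every rank k.
    precision = jnp.cumsum(rel, axis=1) / ranks
    # Column k - 1 holds the sum of Precision@i * rel(i) for i <= k.
    summed = jnp.cumsum(precision * rel, axis=1)
    # r: number of relevant items in each row.
    r = labels.sum(axis=1, keepdims=True)
    # Gather the requested cutoffs (assumes 1 <= k <= vocab_size).
    per_k = summed[:, ks - 1]
    # Rows without relevant items get 0 instead of a division by zero.
    return jnp.where(r > 0, per_k / jnp.maximum(r, 1.0), 0.0)


preds = jnp.array([[0.9, 0.8, 0.7, 0.6, 0.5],
                   [0.1, 0.2, 0.3, 0.4, 0.5]])
labels = jnp.array([[1, 0, 1, 1, 0],
                    [1, 0, 0, 0, 0]])
print(ap_at_ks_sketch(preds, labels, jnp.array([3, 5])))
```

Computing every rank once with cumsum and then indexing the cutoffs keeps the cost independent of how many values ks contains.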
- classmethod from_model_output(predictions: Array, labels: Array, ks: Array) → AveragePrecisionAtK
Computes AP@k for a batch of model outputs and returns the result as an AveragePrecisionAtK metric.
- Parameters:
predictions – A 2D floating-point array of prediction scores generated by the model, with shape (batch_size, vocab_size).
labels – A multi-hot encoding of the true labels, with shape (batch_size, vocab_size).
ks – A 1D array of integers giving the cutoffs k at which to compute AP@k, with shape (|ks|,).
- Returns:
An AveragePrecisionAtK metric whose computed value has shape (|ks|,), one value per cutoff in ks.
- Raises:
ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.
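The documented ValueError covers a wrong label type and incompatible shapes. The hypothetical `check_ap_inputs_sketch` below shows one way such checks could look; which dtypes count as valid labels, and the extra checks on predictions and ks, are assumptions rather than metrax's actual rules. Because it only inspects shapes and dtypes, which are static, the same checks also work on the tracers seen inside jax.jit.

```python
import jax
import jax.numpy as jnp


def check_ap_inputs_sketch(predictions: jax.Array, labels: jax.Array, ks: jax.Array) -> None:
    """Hypothetical validation; raises ValueError on malformed AP@k inputs."""
    # Only shapes and dtypes are inspected, so nothing here needs concrete values.
    if predictions.ndim != 2:
        raise ValueError(f"predictions must be 2D, got shape {predictions.shape}")
    if labels.shape != predictions.shape:
        raise ValueError(
            f"labels shape {labels.shape} is incompatible with "
            f"predictions shape {predictions.shape}"
        )
    # Assumption: multi-hot labels must be boolean or numeric.
    if not (jnp.issubdtype(labels.dtype, jnp.bool_) or jnp.issubdtype(labels.dtype, jnp.number)):
        raise ValueError(f"labels must be boolean or numeric, got {labels.dtype}")
    # Assumption: ks must be a 1D integer array of cutoffs.
    if ks.ndim != 1 or not jnp.issubdtype(ks.dtype, jnp.integer):
        raise ValueError(f"ks must be a 1D integer array, got {ks.dtype} with shape {ks.shape}")


# Well-formed inputs pass silently.
check_ap_inputs_sketch(jnp.zeros((2, 5)), jnp.zeros((2, 5), dtype=jnp.int32), jnp.array([1, 3]))
```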
- replace(**updates)
Returns a new object replacing the specified fields with new values.