metrax.MRR

class metrax.MRR(total: Array, count: Array)

Bases: Average

Computes Mean Reciprocal Rank (MRR), supporting MRR@k for multiple k values.

MRR is the average of the reciprocal ranks of the first relevant item for a set of queries.

For a set of queries \(Q\), the mean reciprocal rank is defined as:

\[MRR = \frac{1}{|Q|} \sum_{q \in Q} RR_q\]

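Here \(RR_q\) is the reciprocal rank for query \(q\), and \(\text{rank}_q\) is the position of the first relevant item in the ranked results for \(q\):

\[RR_q = \begin{cases} \frac{1}{\text{rank}_q} & \text{if a relevant item is found} \\ 0 & \text{if no relevant item is found.} \end{cases}\]

For MRR@k, only the top \(k\) positions are searched, so \(RR_q = 0\) whenever the first relevant item is ranked below position \(k\).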

This implementation assumes binary relevance labels (1 for relevant, 0 for not relevant).
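To make the definition concrete, here is a minimal pure-Python sketch that computes \(RR_q\) from binary labels listed in ranked order and then averages over \(Q\). The helpers `reciprocal_rank` and `mean_reciprocal_rank` are hypothetical names, not part of metrax; the sketch only illustrates the formula, not the library's implementation.

```python
def reciprocal_rank(ranked_labels):
    """Return 1/rank of the first relevant item (1-based), or 0.0 if none.

    `ranked_labels` holds binary relevance labels (1 = relevant) in ranked
    order, best-ranked item first.
    """
    for position, label in enumerate(ranked_labels, start=1):
        if label == 1:
            return 1.0 / position
    return 0.0


def mean_reciprocal_rank(queries):
    """Average the reciprocal ranks over all queries in Q."""
    return sum(reciprocal_rank(q) for q in queries) / len(queries)


# Three queries: first relevant item at rank 1, at rank 3, and none at all.
queries = [
    [1, 0, 0],
    [0, 0, 1],
    [0, 0, 0],
]
print(mean_reciprocal_rank(queries))  # (1 + 1/3 + 0) / 3 ≈ 0.4444
```

A query with no relevant item still counts toward \(|Q|\); it simply contributes 0 to the sum.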

__init__(total: Array, count: Array) → None

Methods

__init__(total, count)

compute()

Computes final metrics from intermediate values.

compute_value()

Wraps compute() and returns a values.Value.

empty()

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

from_fun(fun)

Calls cls.from_model_output with the return value from fun.

from_model_output(predictions, labels, ks)

Creates an MRR metric instance from model output, calculating MRR@k for each k.

from_output(name)

Calls cls.from_model_output with model output named name.

merge(other)

Returns Metric that is the accumulation of self and other.

reduce()

Reduces the metric along its first axis by calling _reduce_merge().

replace(**updates)

Returns a new object replacing the specified fields with new values.
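Because MRR derives from Average and stores only `total` and `count`, the interplay of empty(), merge(), and compute() can be understood through a running-sum pattern. The sketch below assumes that Average adds `total` and `count` on merge and returns `total / count` from compute(); `ToyAverage` is a hypothetical stand-in, not the metrax or CLU class, and treating `count` as a scalar is a simplifying assumption.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class ToyAverage:
    """Hypothetical total/count averaging metric (not the metrax class)."""

    total: np.ndarray  # per-k running sums of reciprocal ranks
    count: float       # assumption: a scalar number of accumulated queries

    @classmethod
    def empty(cls, size):
        # Zero sums and a zero count: merging with this instance is a no-op.
        return cls(total=np.zeros(size), count=0.0)

    def merge(self, other):
        # Accumulate sums and counts so batches of any size combine exactly.
        return type(self)(self.total + other.total, self.count + other.count)

    def compute(self):
        # Final value is the ratio of the accumulated sums.
        return self.total / self.count


batch_a = ToyAverage(total=np.array([1.5, 2.0]), count=2.0)
batch_b = ToyAverage(total=np.array([0.5, 1.0]), count=1.0)
merged = ToyAverage.empty(2).merge(batch_a).merge(batch_b)
print(merged.compute())  # ≈ [0.6667, 1.0]
```

Summing totals and counts, rather than averaging per-batch averages, keeps the result exact when batches contain different numbers of queries.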

Attributes

total

count

classmethod from_model_output(predictions: Array, labels: Array, ks: Array) → MRR

Creates an MRR metric instance from model output, calculating MRR@k for each k.

Parameters:
  • predictions – A 2D array of prediction scores. A higher score places an item closer to the top of the ranking (rank 1). The shape should be (batch_size, vocab_size).

  • labels – A 2D array of binary relevance labels (0 or 1). The shape should be (batch_size, vocab_size).

  • ks – A 1D array of integers representing the k cutoffs. The shape should be (|ks|,).

Returns:

An MRR metric object. The ‘total’ field will be an array of shape (|ks|,).
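The following NumPy sketch shows one way to turn inputs of the shapes described above into a per-k `total` of shape (|ks|,). It is an illustration under stated assumptions, not the metrax source: the helper `mrr_at_ks` is hypothetical, ties in `predictions` are broken by column index via a stable sort, and `count` is assumed to be the number of queries in the batch. Dividing `total` by `count` then yields MRR@k for each cutoff.

```python
import numpy as np


def mrr_at_ks(predictions, labels, ks):
    """Hypothetical helper: per-k reciprocal-rank sums and a query count.

    predictions, labels: arrays of shape (batch_size, vocab_size).
    ks: 1D integer array of cutoffs, shape (|ks|,).
    Returns (total, count); total has shape (|ks|,).
    """
    predictions = np.asarray(predictions, dtype=float)
    labels = np.asarray(labels)
    ks = np.asarray(ks)

    # Order items by descending score; the stable sort breaks ties by column index.
    order = np.argsort(-predictions, axis=1, kind="stable")
    sorted_labels = np.take_along_axis(labels, order, axis=1)

    # 1-based rank of the first relevant item per row (argmax returns the first 1).
    found = sorted_labels.max(axis=1) > 0
    rank = np.argmax(sorted_labels, axis=1) + 1

    # A query scores 1/rank only if its first relevant item lies within the top k.
    hit = found[None, :] & (rank[None, :] <= ks[:, None])  # (|ks|, batch_size)
    rr = np.where(hit, 1.0 / rank[None, :], 0.0)

    total = rr.sum(axis=1)               # shape (|ks|,)
    count = float(predictions.shape[0])  # assumption: number of queries
    return total, count


predictions = np.array([[0.9, 0.1, 0.5],
                        [0.2, 0.8, 0.6]])
labels = np.array([[0, 0, 1],
                   [1, 0, 0]])
ks = np.array([1, 2, 3])

total, count = mrr_at_ks(predictions, labels, ks)
print(total / count)  # MRR@1, MRR@2, MRR@3 ≈ [0.0, 0.25, 0.4167]
```

In this example the first relevant item sits at rank 2 for the first query and rank 3 for the second, so neither counts at k = 1, only the first counts at k = 2, and both count at k = 3.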

replace(**updates)

Returns a new object replacing the specified fields with new values.