metrax.Perplexity
- class metrax.Perplexity(aggregate_crossentropy: Array, num_samples: Array)
Bases: Metric
Computes perplexity for sequence generation.
Perplexity is a measurement of how well a probability distribution predicts a sample. It is defined as the exponentiation of the cross-entropy. A low perplexity indicates the probability distribution is good at predicting the sample.
For language models, it can be interpreted as the weighted average branching factor of the model - how many equally likely words can be selected at each step.
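The branching-factor reading can be checked with a small, library-free sketch. It assumes a hypothetical model that assigns the same probability `1 / vocab_size` to every token; `uniform_perplexity` is an illustrative helper, not part of metrax.

```python
import math


def uniform_perplexity(vocab_size: int, seq_len: int) -> float:
    """Perplexity of a model that gives every token probability 1/vocab_size."""
    log_prob = math.log(1.0 / vocab_size)
    # Mean negative log-probability over the sequence (the cross-entropy).
    mean_neg_log_prob = -sum(log_prob for _ in range(seq_len)) / seq_len
    return math.exp(mean_neg_log_prob)


# A model choosing uniformly among 8 tokens has a branching factor of 8,
# regardless of the sequence length.
ppl_uniform = uniform_perplexity(vocab_size=8, seq_len=5)
```

Under this assumption the perplexity equals the number of equally likely choices at each step.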
Given a sequence of \(N\) tokens, perplexity is calculated as:
\[Perplexity = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log P(x_i|x_{<i})\right)\]
When sample weights \(w_i\) are provided:
\[Perplexity = \exp\left(-\frac{\sum_{i=1}^{N} w_i\log P(x_i|x_{<i})}{\sum_{i=1}^{N} w_i}\right)\]
where:
- \(P(x_i|x_{<i})\) is the predicted probability of token \(x_i\) given previous tokens
- \(w_i\) are sample weights
- \(N\) is the sequence length
Lower perplexity indicates better prediction - the model is less “perplexed” by the data.
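As a rough illustration of the two formulas, the sketch below evaluates them in plain Python for a single sequence. The `perplexity` helper and the example probabilities are made up for illustration; the sketch takes the per-token probabilities \(P(x_i|x_{<i})\) as given rather than deriving them from model output.

```python
import math


def perplexity(token_probs, weights=None):
    """exp of the (weighted) mean negative log-probability of the true tokens."""
    if weights is None:
        # Unweighted case: every token counts once, so sum(weights) == N.
        weights = [1.0] * len(token_probs)
    weighted_nll = -sum(w * math.log(p) for p, w in zip(token_probs, weights))
    return math.exp(weighted_nll / sum(weights))


probs = [0.5, 0.25, 0.125]  # P(x_i | x_<i) for three tokens
ppl_plain = perplexity(probs)
# A weight of 0 removes a token (for example padding) from both sums.
ppl_masked = perplexity(probs, weights=[1.0, 1.0, 0.0])
```

With these inputs the unweighted mean negative log-probability is \(2\ln 2\), giving a perplexity of 4; masking the last token leaves \(1.5\ln 2\), giving \(2^{1.5}\).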
Methods
__init__(aggregate_crossentropy, num_samples)
compute() – Computes final metrics from intermediate values.
compute_value() – Wraps compute() and returns a values.Value.
empty() – Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).
from_fun(fun) – Calls cls.from_model_output with the return value from fun.
from_model_output(predictions, labels[, ...]) – Updates the metric.
from_output(name) – Calls cls.from_model_output with model output named name.
merge(other) – Returns Metric that is the accumulation of self and other.
reduce() – Reduces the metric along its first axis by calling _reduce_merge().
replace(**updates) – Returns a new object replacing the specified fields with new values.
Attributes
- aggregate_crossentropy: Array
- num_samples: Array
- classmethod empty() → Perplexity
Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).
- classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None, from_logits: bool = False) → Perplexity
Updates the metric.
- Parameters:
predictions – A floating point tensor representing the prediction generated from the model. The shape should be (batch_size, seq_len, vocab_size).
labels – True value. The shape should be (batch_size, seq_len).
sample_weights – An optional tensor representing the weight of each token. The shape should be (batch_size, seq_len).
from_logits – Whether the predictions are logits. If True, the predictions are converted to probabilities using a softmax. If False, all values outside of [0, 1] are clipped to 0 or 1.
- Returns:
Updated Perplexity metric.
- Raises:
ValueError – If the type of labels is wrong or the shapes of predictions and labels are incompatible.
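The sketch below illustrates, in plain Python on nested lists, how predictions of shape (batch_size, seq_len, vocab_size) and labels of shape (batch_size, seq_len) could be turned into per-token log-probabilities. It is not the metrax implementation: `token_log_probs` is a hypothetical helper, and the small `eps` floor is an assumption of this sketch to avoid `log(0)`, not something the documentation above states.

```python
import math


def token_log_probs(predictions, labels, from_logits=False, eps=1e-7):
    """predictions: [batch][seq][vocab] floats; labels: [batch][seq] ints."""
    out = []
    for pred_seq, label_seq in zip(predictions, labels):
        if len(pred_seq) != len(label_seq):
            raise ValueError("predictions and labels have incompatible shapes")
        row = []
        for scores, label in zip(pred_seq, label_seq):
            if from_logits:
                # Numerically stable log-softmax, read off at the true label.
                m = max(scores)
                log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
                row.append(scores[label] - log_norm)
            else:
                # Treat scores as probabilities; clip into [eps, 1] (assumed floor).
                p = min(max(scores[label], eps), 1.0)
                row.append(math.log(p))
        out.append(row)
    return out


labels = [[0, 0]]  # batch_size=1, seq_len=2
lp_logits = token_log_probs([[[0.0, 0.0], [math.log(3), 0.0]]], labels, from_logits=True)
lp_probs = token_log_probs([[[0.5, 0.5], [0.75, 0.25]]], labels)
```

Both calls describe the same distributions, once as logits and once as probabilities, so they yield the same log-probabilities for the labeled tokens.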
- merge(other: Perplexity) → Perplexity
Returns Metric that is the accumulation of self and other.
- Parameters:
other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().
- Returns:
A new Metric that accumulates the value from both self and other.
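One plausible reading of the accumulation, sketched with a hypothetical stand-in class rather than metrax.Perplexity itself: assume aggregate_crossentropy holds the weighted sum of negative log-probabilities, num_samples holds the sum of weights, merging adds both fields, and the final value is the exponential of their ratio. These semantics are inferred from the field names and the formulas above and are not confirmed by this page.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class PerplexityState:  # hypothetical stand-in, not metrax.Perplexity
    aggregate_crossentropy: float  # assumed: sum of w_i * -log P(x_i | x_<i)
    num_samples: float             # assumed: sum of w_i

    def merge(self, other):
        # Accumulate numerator and denominator separately.
        return PerplexityState(
            self.aggregate_crossentropy + other.aggregate_crossentropy,
            self.num_samples + other.num_samples,
        )

    def compute(self):
        return math.exp(self.aggregate_crossentropy / self.num_samples)


batch_a = PerplexityState(aggregate_crossentropy=3 * math.log(2), num_samples=3.0)
batch_b = PerplexityState(aggregate_crossentropy=5 * math.log(2), num_samples=1.0)
merged = batch_a.merge(batch_b)
merged_ppl = merged.compute()
```

Under these assumptions the merged value is computed from the pooled sums and differs from the average of the per-batch perplexities, which is why the intermediate values are kept separate until the final computation.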
- compute() → Array
Computes final metrics from intermediate values.
- replace(**updates)
Returns a new object replacing the specified fields with new values.