metrax.Perplexity

class metrax.Perplexity(aggregate_crossentropy: Array, num_samples: Array)

Bases: Metric

Computes perplexity for sequence generation.

Perplexity is a measurement of how well a probability distribution predicts a sample. It is defined as the exponentiation of the cross-entropy. A low perplexity indicates the probability distribution is good at predicting the sample.

For language models, it can be interpreted as the weighted average branching factor of the model: the number of equally likely words the model is effectively choosing among at each step.

Given a sequence of \(N\) tokens, perplexity is calculated as:

\[Perplexity = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log P(x_i|x_{<i})\right)\]
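The formula can be checked by hand on a small case. The sketch below is plain Python rather than metrax code, and the helper name `perplexity` is hypothetical. It takes the per-token probabilities \(P(x_i|x_{<i})\) as already known and shows that a model spreading probability uniformly over 4 words has a perplexity of 4, which matches the branching-factor interpretation.

```python
import math


def perplexity(token_probs):
    """Perplexity from the probability assigned to each observed token."""
    n = len(token_probs)
    # Mean negative log-likelihood, i.e. the cross-entropy in nats.
    mean_nll = -sum(math.log(p) for p in token_probs) / n
    return math.exp(mean_nll)


# Uniform choice among 4 words at each of 10 steps.
uniform = perplexity([0.25] * 10)

# A model that assigns higher probability to the observed tokens.
confident = perplexity([0.9, 0.8, 0.95, 0.7])
```

Here `uniform` comes out at 4 (up to floating-point rounding), and `confident` falls between 1 and 4.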

When sample weights \(w_i\) are provided:

\[Perplexity = \exp\left(-\frac{\sum_{i=1}^{N} w_i\log P(x_i|x_{<i})}{\sum_{i=1}^{N} w_i}\right)\]
where:
  • \(P(x_i|x_{<i})\) is the predicted probability of token \(x_i\) given previous tokens

  • \(w_i\) are sample weights

  • \(N\) is the sequence length

Lower perplexity indicates better prediction: the model is less “perplexed” by the data.
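A common use of sample weights is masking padding tokens by giving them a weight of 0. The following is a minimal sketch of the weighted formula in plain Python. The function name `weighted_perplexity` and the zero-weight check are illustrative choices for this sketch, not documented metrax behavior.

```python
import math


def weighted_perplexity(token_probs, weights):
    """Weighted perplexity: exp(-sum(w_i * log p_i) / sum(w_i))."""
    total_weight = sum(weights)
    if total_weight == 0:
        # Choice made for this sketch only; the library may handle it differently.
        raise ValueError("at least one token must have a non-zero weight")
    weighted_nll = -sum(w * math.log(p) for p, w in zip(token_probs, weights))
    return math.exp(weighted_nll / total_weight)


probs = [0.5, 0.25, 0.125, 0.01]  # the last position stands in for padding
mask = [1.0, 1.0, 1.0, 0.0]       # weight 0 removes the padding position

masked = weighted_perplexity(probs, mask)
unmasked = weighted_perplexity(probs, [1.0] * 4)
```

With the mask applied, only the first three tokens contribute, so the low probability at the padding position no longer inflates the result.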

__init__(aggregate_crossentropy: Array, num_samples: Array) → None

Methods

__init__(aggregate_crossentropy, num_samples)

compute()

Computes final metrics from intermediate values.

compute_value()

Wraps compute() and returns a values.Value.

empty()

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

from_fun(fun)

Calls cls.from_model_output with the return value from fun.

from_model_output(predictions, labels[, ...])

Updates the metric.

from_output(name)

Calls cls.from_model_output with model output named name.

merge(other)

Returns Metric that is the accumulation of self and other.

reduce()

Reduces the metric along its first axis by calling _reduce_merge().

replace(**updates)

Returns a new object replacing the specified fields with new values.

Attributes

aggregate_crossentropy: Array

num_samples: Array

classmethod empty() → Perplexity

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: Array, labels: Array, sample_weights: Array | None = None, from_logits: bool = False) → Perplexity

Updates the metric.

Parameters:
  • predictions – A floating point tensor representing the prediction generated from the model. The shape should be (batch_size, seq_len, vocab_size).

  • labels – True value. The shape should be (batch_size, seq_len).

  • sample_weights – An optional tensor representing the weight of each token. The shape should be (batch_size, seq_len).

  • from_logits – Whether the predictions are logits. If True, the predictions are converted to probabilities using a softmax. If False, all values outside of [0, 1] are clipped to 0 or 1.

Returns:

Updated Perplexity metric.

Raises:
  • ValueError – If type of labels is wrong or the shapes of predictions and labels are incompatible.
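The two branches of from_logits can be illustrated for a single token position. This is a sketch in plain Python over one vocabulary-sized vector, not the library's implementation. The helper names `softmax` and `to_probabilities` are hypothetical, and metrax itself operates on arrays of shape (batch_size, seq_len, vocab_size).

```python
import math


def softmax(logits):
    """Numerically stable softmax over one vocabulary-sized vector."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def to_probabilities(predictions, from_logits=False):
    """Mirror the documented behavior: softmax for logits, clipping otherwise."""
    if from_logits:
        return softmax(predictions)
    # Values below 0 become 0 and values above 1 become 1.
    return [min(max(p, 0.0), 1.0) for p in predictions]


from_logits_probs = to_probabilities([2.0, 1.0, 0.0], from_logits=True)
clipped_probs = to_probabilities([1.2, -0.1, 0.3], from_logits=False)
```

Passing raw logits with from_logits=False would clip most values to 0 or 1 rather than normalize them, so the flag should match what the model actually outputs.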

merge(other: Perplexity) → Perplexity

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the value from both self and other.
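To see why accumulating two intermediate values is enough, the sketch below models the metric state with a hypothetical `PerplexityState` dataclass. It assumes that aggregate_crossentropy holds a summed negative log-likelihood and num_samples a token count, and that compute() exponentiates their ratio. That reading is consistent with the formula above, but it is an assumption about the internals rather than something this page states.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class PerplexityState:
    """Hypothetical stand-in for the metric's two accumulated fields."""

    aggregate_crossentropy: float = 0.0  # assumed: summed negative log-likelihood
    num_samples: float = 0.0             # assumed: number of tokens seen

    @classmethod
    def from_token_probs(cls, token_probs):
        return cls(
            aggregate_crossentropy=-sum(math.log(p) for p in token_probs),
            num_samples=float(len(token_probs)),
        )

    def merge(self, other):
        # Accumulation is field-wise addition, so an empty state is a no-op.
        return PerplexityState(
            self.aggregate_crossentropy + other.aggregate_crossentropy,
            self.num_samples + other.num_samples,
        )

    def compute(self):
        return math.exp(self.aggregate_crossentropy / self.num_samples)


batch_a = PerplexityState.from_token_probs([0.5, 0.25])
batch_b = PerplexityState.from_token_probs([0.125, 0.5])
merged = batch_a.merge(batch_b)
single_pass = PerplexityState.from_token_probs([0.5, 0.25, 0.125, 0.5])
```

Under these assumptions, merging per-batch states gives the same perplexity as processing all tokens at once, whereas averaging per-batch perplexities generally would not.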

compute() → Array

Computes final metrics from intermediate values.

replace(**updates)

Returns a new object replacing the specified fields with new values.