metrax.BLEU

class metrax.BLEU(max_order: int, matches_by_order: Array, possible_matches_by_order: Array, translation_length: Array, reference_length: Array)

Bases: Metric

Computes the BLEU score for sequence generation.

BLEU measures the similarity between a machine-generated candidate translation and one or more human reference translations, focusing on matching n-grams.

It’s calculated as:

\[BLEU = \text{BP} \cdot \exp\left( \sum_{n=1}^{N} w_n \log p_n \right)\]
where:
  • \(p_n\) is the modified n-gram precision for n-grams of order n.

  • \(N\) is the maximum n-gram order considered (typically 4).

  • \(w_n\) are weights for each order (typically uniform, 1/N).

  • \(\text{BP}\) is the Brevity Penalty.
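For reference, here are the standard definitions of the remaining quantities, written in terms of the accumulated attributes listed below. It is an assumption that metrax uses exactly these forms. Let \(m_n\) and \(t_n\) be the order-n entries of matches_by_order and possible_matches_by_order, let \(c\) be translation_length, and let \(r\) be reference_length:

\[p_n = \frac{m_n}{t_n}, \qquad \text{BP} = \begin{cases} 1 & \text{if } c > r \\ \exp(1 - r/c) & \text{if } c \le r \end{cases}\]

With uniform weights \(w_n = 1/N\), the exponential in the formula above reduces to a geometric mean of the precisions:

\[BLEU = \text{BP} \cdot \left( \prod_{n=1}^{N} p_n \right)^{1/N}\]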

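This implementation uses uniform weights (\(w_n = 1/N\)) and accumulates its statistics incrementally, so statistics from several batches can be combined with merge() before compute() is called.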

max_order

Maximum n-gram order to consider.

Type:

int

matches_by_order

Accumulated sum of clipped n-gram matches for each order.

Type:

jax.Array

possible_matches_by_order

Accumulated total number of n-grams in the predictions, for each order.

Type:

jax.Array
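The following is a minimal sketch of how the clipped matches and possible matches for a single prediction could be counted. It assumes whitespace tokenization, because this page does not specify metrax's tokenizer. The helpers ngram_counts and clipped_matches are hypothetical, and plain Python lists stand in for jax.Array. Each candidate n-gram count is clipped at the largest number of times that n-gram occurs in any single reference.

```python
from collections import Counter


def ngram_counts(tokens: list[str], n: int) -> Counter:
    """Counts every n-gram of order n in a token list (hypothetical helper)."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def clipped_matches(candidate: list[str], references: list[list[str]], max_order: int = 4):
    """Returns (matches_by_order, possible_matches_by_order) for one candidate.

    Hypothetical sketch: plain lists stand in for the metric's jax.Array fields.
    """
    matches = [0] * max_order
    possible = [0] * max_order
    for n in range(1, max_order + 1):
        cand = ngram_counts(candidate, n)
        # Counter `|` keeps the per-n-gram maximum count across all references.
        max_ref = Counter()
        for ref in references:
            max_ref |= ngram_counts(ref, n)
        # Counter `&` keeps the per-n-gram minimum, i.e. clips candidate counts.
        matches[n - 1] = sum((cand & max_ref).values())
        # Every candidate n-gram is a possible match.
        possible[n - 1] = sum(cand.values())
    return matches, possible


# Degenerate candidate that just repeats "the" (whitespace tokenization assumed).
cand = "the the the the the the the".split()
refs = ["the cat is on the mat".split()]
print(clipped_matches(cand, refs, max_order=2))  # ([2, 0], [7, 6])
```

Clipping is what stops the repeated "the" from scoring a unigram precision of 7/7. It earns only 2 matches, because "the" appears at most twice in the reference.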

translation_length

Accumulated total length of predictions.

Type:

jax.Array

reference_length

Accumulated total ‘effective’ reference length: for each prediction, the length of the reference whose length is closest to that prediction's length.

Type:

jax.Array
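Here is a sketch of how the effective reference length could be accumulated. It assumes that ties are broken toward the shorter reference, as several common corpus-BLEU implementations do; this page does not say whether metrax breaks ties the same way. closest_ref_length is a hypothetical helper.

```python
def closest_ref_length(candidate_len: int, ref_lens: list[int]) -> int:
    """Picks the reference length closest to the candidate length.

    Assumption: ties go to the shorter reference.
    """
    return min(ref_lens, key=lambda r: (abs(r - candidate_len), r))


# Illustrative batch: (candidate length, lengths of its references).
batch = [(4, [3, 6]), (9, [8, 10])]
translation_length = sum(c for c, _ in batch)
reference_length = sum(closest_ref_length(c, refs) for c, refs in batch)
print(translation_length, reference_length)  # 13 11
```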

__init__(max_order: int, matches_by_order: Array, possible_matches_by_order: Array, translation_length: Array, reference_length: Array) → None

Methods

__init__(max_order, matches_by_order, ...)

compute()

Computes final metrics from intermediate values.

compute_value()

Wraps compute() and returns a values.Value.

empty()

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

from_fun(fun)

Calls cls.from_model_output with the return value from fun.

from_model_output(predictions, references[, ...])

Computes BLEU statistics for a batch of predictions and references.

from_output(name)

Calls cls.from_model_output with model output named name.

merge(other)

Returns Metric that is the accumulation of self and other.

reduce()

Reduces the metric along its first axis by calling _reduce_merge().

replace(**updates)

Returns a new object replacing the specified fields with new values.

Attributes

max_order: int
matches_by_order: Array
possible_matches_by_order: Array
translation_length: Array
reference_length: Array
classmethod empty() → BLEU

Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).

classmethod from_model_output(predictions: list[str], references: list[list[str]], max_order: int = 4) → BLEU

Computes BLEU statistics for a batch of predictions and references.

Parameters:
  • predictions – A list of predicted strings, one per example, so its length is batch_size.

  • references – A list of lists of reference strings, conceptually of shape (batch_size, num_references): one inner list of references per prediction.

  • max_order – The maximum order of n-grams to consider (default 4).

Returns:

A BLEU metric instance containing the statistics for this batch.

Raises:
  • ValueError – If the shapes of predictions and references are incompatible.

merge(other: BLEU) → BLEU

Returns Metric that is the accumulation of self and other.

Parameters:

other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().

Returns:

A new Metric that accumulates the values from both self and other.
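Every documented field is an accumulated sum, so merging is presumably elementwise addition. The sketch below illustrates this with a hypothetical BleuStats dataclass that uses plain integers in place of the metric's jax.Array fields. It also illustrates the empty() property: merging with all-zero statistics is a no-op.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BleuStats:
    """Hypothetical stand-in for the metric's accumulated fields (ints, not jax.Array)."""

    matches_by_order: tuple[int, ...]
    possible_matches_by_order: tuple[int, ...]
    translation_length: int
    reference_length: int

    @classmethod
    def empty(cls, max_order: int = 4) -> "BleuStats":
        # All-zero statistics: merging with these leaves the other operand unchanged.
        return cls((0,) * max_order, (0,) * max_order, 0, 0)

    def merge(self, other: "BleuStats") -> "BleuStats":
        # Each field is a sum over examples, so merging adds field by field.
        return BleuStats(
            tuple(a + b for a, b in zip(self.matches_by_order, other.matches_by_order)),
            tuple(a + b for a, b in zip(self.possible_matches_by_order,
                                        other.possible_matches_by_order)),
            self.translation_length + other.translation_length,
            self.reference_length + other.reference_length,
        )


# Illustrative statistics for two batches with max_order = 2.
batch1 = BleuStats((5, 3), (6, 5), 6, 7)
batch2 = BleuStats((4, 2), (4, 3), 4, 4)
total = batch1.merge(batch2)
print(total)
# BleuStats(matches_by_order=(9, 5), possible_matches_by_order=(10, 8), translation_length=10, reference_length=11)
print(BleuStats.empty(max_order=2).merge(total) == total)  # True
```

The statistics are summed rather than the per-batch scores. As a result, computing after merging yields corpus-level BLEU over all batches, which generally differs from averaging per-batch BLEU scores.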

compute() → Array

Computes final metrics from intermediate values.
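Below is a sketch of the final computation from accumulated statistics. It follows the formula at the top of this page, with uniform weights and no smoothing. bleu_from_stats is a hypothetical function. This page does not say how metrax handles zero precisions, or whether it scales the score to the 0 to 100 range. The inputs are illustrative statistics for max_order = 2.

```python
import math


def bleu_from_stats(matches_by_order, possible_matches_by_order,
                    translation_length, reference_length):
    """Uniform-weight corpus BLEU from accumulated statistics (hypothetical, no smoothing)."""
    max_order = len(matches_by_order)
    # Modified precision p_n per order; an order with no candidate n-grams counts as 0.
    precisions = [m / t if t > 0 else 0.0
                  for m, t in zip(matches_by_order, possible_matches_by_order)]
    # A zero precision sends log(p_n) to -inf, so without smoothing the score is 0.
    if min(precisions) == 0.0:
        return 0.0
    # exp(sum_n (1/N) * log p_n) is the geometric mean of the precisions.
    geo_mean = math.exp(sum(math.log(p) for p in precisions) / max_order)
    # c > 0 here, because a nonzero precision implies at least one candidate n-gram.
    c, r = translation_length, reference_length
    # Brevity penalty: applies only when the predictions are not longer than the references.
    bp = 1.0 if c > r else math.exp(1.0 - r / c)
    return bp * geo_mean


score = bleu_from_stats((9, 5), (10, 8), translation_length=10, reference_length=11)
print(round(score, 4))  # 0.6786
```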

replace(**updates)

Returns a new object replacing the specified fields with new values.