metrax.BLEU
- class metrax.BLEU(max_order: int, matches_by_order: Array, possible_matches_by_order: Array, translation_length: Array, reference_length: Array)
Bases: Metric
Computes the BLEU score for sequence generation.
BLEU measures the similarity between a machine-generated candidate translation and one or more human reference translations, focusing on matching n-grams.
It’s calculated as:
\[BLEU = \text{BP} \cdot \exp\left( \sum_{n=1}^{N} w_n \log p_n \right)\]
where:
\(p_n\) is the modified n-gram precision for n-grams of order n.
\(N\) is the maximum n-gram order considered (typically 4).
\(w_n\) are weights for each order (typically uniform, 1/N).
\(\text{BP}\) is the Brevity Penalty.
This implementation uses uniform weights and calculates statistics incrementally.
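As a rough illustration of the formula above, the following sketch computes a uniform-weight BLEU score from the four accumulated statistics. The function name `bleu_from_stats` is hypothetical and not part of metrax. The brevity penalty uses the standard definition (1 when the candidate is longer than the reference, otherwise exp(1 - r/c)), which is assumed here because this page does not spell it out.

```python
import math


def bleu_from_stats(matches_by_order, possible_matches_by_order,
                    translation_length, reference_length):
    """Uniform-weight BLEU from accumulated statistics (illustrative sketch)."""
    max_order = len(matches_by_order)
    if translation_length == 0:
        return 0.0

    # Modified n-gram precision p_n for each order n.
    precisions = [
        m / p if p > 0 else 0.0
        for m, p in zip(matches_by_order, possible_matches_by_order)
    ]
    # log(0) is undefined; a zero precision makes the geometric mean zero.
    if min(precisions) <= 0.0:
        return 0.0

    # exp(sum_n w_n * log p_n) with uniform weights w_n = 1/N.
    geo_mean = math.exp(sum(math.log(p) for p in precisions) / max_order)

    # Brevity penalty (assumed standard form, not confirmed by this page).
    if translation_length > reference_length:
        bp = 1.0
    else:
        bp = math.exp(1.0 - reference_length / translation_length)
    return bp * geo_mean
```

With identical candidate and reference statistics every precision is 1 and the penalty is 1, so the score is 1.0. A shorter candidate is scaled down by the penalty.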
- matches_by_order
Accumulated sum of clipped n-gram matches for each order.
- Type:
jax.Array
- possible_matches_by_order
Accumulated sum of total n-grams in predictions for each order.
- Type:
jax.Array
- translation_length
Accumulated total length of predictions.
- Type:
jax.Array
- reference_length
Accumulated total ‘effective’ reference length (closest length match for each prediction).
- Type:
jax.Array
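To make the four statistics concrete, here is a minimal pure-Python sketch of how they could be gathered for a single prediction. It is not the metrax implementation: the helper names `ngram_counts` and `sentence_stats` are hypothetical, whitespace tokenization is assumed, and ties between equally close reference lengths are broken toward the shorter reference (also an assumption).

```python
from collections import Counter


def ngram_counts(tokens, max_order):
    """Counts every n-gram of order 1..max_order in a token list."""
    counts = Counter()
    for n in range(1, max_order + 1):
        for i in range(len(tokens) - n + 1):
            counts[tuple(tokens[i:i + n])] += 1
    return counts


def sentence_stats(prediction, references, max_order=4):
    """Returns (matches, possible, translation_length, reference_length)."""
    pred = prediction.split()  # assumed whitespace tokenization
    refs = [r.split() for r in references]

    # For each n-gram, the highest count seen in any single reference.
    max_ref = Counter()
    for ref in refs:
        max_ref |= ngram_counts(ref, max_order)

    # Clipping: a predicted n-gram is credited at most max_ref times.
    clipped = ngram_counts(pred, max_order) & max_ref
    matches = [0] * max_order
    for ngram, count in clipped.items():
        matches[len(ngram) - 1] += count

    # Total n-grams available in the prediction for each order.
    possible = [max(len(pred) - n + 1, 0) for n in range(1, max_order + 1)]

    # 'Effective' reference length: the reference length closest to the prediction.
    ref_len = min((len(r) for r in refs),
                  key=lambda length: (abs(length - len(pred)), length))
    return matches, possible, len(pred), ref_len


stats = sentence_stats(
    "the cat sat on the mat",
    ["the cat is on the mat", "there is a cat on the mat"],
    max_order=2,
)
# stats == ([5, 3], [6, 5], 6, 6)
```

Here "sat" is the only unmatched unigram, and three of the five predicted bigrams appear in a reference.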
- __init__(max_order: int, matches_by_order: Array, possible_matches_by_order: Array, translation_length: Array, reference_length: Array) -> None
Methods
__init__(max_order, matches_by_order, ...)
compute() – Computes final metrics from intermediate values.
compute_value() – Wraps compute() and returns a values.Value.
empty() – Returns an empty instance (i.e. .merge(Metric.empty()) is a no-op).
from_fun(fun) – Calls cls.from_model_output with the return value from fun.
from_model_output(predictions, references[, ...]) – Computes BLEU statistics for a batch of predictions and references.
from_output(name) – Calls cls.from_model_output with model output named name.
merge(other) – Returns Metric that is the accumulation of self and other.
reduce() – Reduces the metric along its first axis by calling _reduce_merge().
replace(**updates) – Returns a new object replacing the specified fields with new values.
Attributes
- matches_by_order: Array
- possible_matches_by_order: Array
- translation_length: Array
- reference_length: Array
- classmethod from_model_output(predictions: list[str], references: list[list[str]], max_order: int = 4) -> BLEU
Computes BLEU statistics for a batch of predictions and references.
- Parameters:
predictions – A list of predicted strings. The shape should be (batch_size, ).
references – A list of lists of reference strings. The shape should be (batch_size, num_references).
max_order – The maximum order of n-grams to consider.
- Returns:
A BLEU metric instance containing the statistics for this batch.
- Raises:
ValueError – If the shapes of predictions and references are incompatible.
- merge(other: BLEU) -> BLEU
Returns Metric that is the accumulation of self and other.
- Parameters:
other – A Metric whose intermediate values should be accumulated onto the values of self. Note that in a distributed setting, other will typically be the output of a jax.lax parallel operator and thus have a dimension added to the dataclass returned by .from_model_output().
- Returns:
A new Metric that accumulates the value from both self and other.
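Because every attribute is described as an accumulated sum, merging can be pictured as element-wise addition of the four statistics. The sketch below uses a hypothetical stand-in dataclass, `BleuStats`, with plain Python tuples in place of jax.Array. It illustrates that reading and is not the metrax class itself.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BleuStats:
    """Hypothetical stand-in for the accumulated BLEU statistics."""
    matches_by_order: tuple
    possible_matches_by_order: tuple
    translation_length: int
    reference_length: int

    def merge(self, other):
        # Accumulate by summing each statistic; per-order counts add element-wise.
        return BleuStats(
            tuple(a + b for a, b in
                  zip(self.matches_by_order, other.matches_by_order)),
            tuple(a + b for a, b in
                  zip(self.possible_matches_by_order,
                      other.possible_matches_by_order)),
            self.translation_length + other.translation_length,
            self.reference_length + other.reference_length,
        )


batch_1 = BleuStats((5, 3), (6, 5), 6, 6)
batch_2 = BleuStats((2, 1), (3, 2), 3, 4)
total = batch_1.merge(batch_2)
# total == BleuStats((7, 4), (9, 7), 9, 10)
```

Summing raw counts before computing the score gives a corpus-level BLEU, which generally differs from averaging per-batch scores.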
- compute() -> Array
Computes final metrics from intermediate values.
- replace(**updates)
Returns a new object replacing the specified fields with new values.