diff --git a/metrics/bleu/bleu.py b/metrics/bleu/bleu.py
index 38a10c3b..34f3e482 100644
--- a/metrics/bleu/bleu.py
+++ b/metrics/bleu/bleu.py
@@ -60,6 +60,7 @@
     references: list of lists of or just a list of references for each translation.
     tokenizer : approach used for tokenizing `predictions` and `references`.
         The default tokenizer is `tokenizer_13a`, a minimal tokenization approach that is equivalent to `mteval-v13a`, used by WMT.
+        You can also pass tokenizer="13a" to use the default WMT tokenizer.
         This can be replaced by any function that takes a string as input and returns a list of tokens as output.
     max_order: Maximum n-gram order to use when computing BLEU score.
     smooth: Whether or not to apply Lin et al. 2004 smoothing.
@@ -112,8 +113,27 @@ def _info(self):
             ],
         )
 
-    def _compute(self, predictions, references, tokenizer=Tokenizer13a(), max_order=4, smooth=False):
-        # if only one reference is provided make sure we still use list of lists
+    def _compute(self, predictions, references, tokenizer=None, max_order=4, smooth=False):
+        if len(predictions) != len(references):
+            raise ValueError(
+                f"Predictions and references must have the same length, got {len(predictions)} and {len(references)}."
+            )
+        if len(predictions) == 0:
+            raise ValueError("Predictions and references must be non-empty.")
+
+        if tokenizer is None:
+            tokenizer = Tokenizer13a()  # instantiated per call to avoid a shared default instance in the signature
+
+        if isinstance(tokenizer, str):
+            tok = tokenizer.strip().lower()
+            if tok in {"13a", "tokenizer_13a", "mteval-v13a"}:
+                tokenizer = Tokenizer13a()
+            else:
+                raise ValueError(
+                    f"Unknown tokenizer '{tokenizer}'. Supported string tokenizers: '13a', 'tokenizer_13a', 'mteval-v13a'. Otherwise pass a callable."
+                )
+
+        # if only one reference is provided make sure we still use list of lists
         if isinstance(references[0], str):
             references = [[ref] for ref in references]
 