From 55aeebff0096ae12e5ab44e65b441b2e55e7babe Mon Sep 17 00:00:00 2001 From: Jannis Vamvas Date: Fri, 14 Oct 2022 17:05:30 +0200 Subject: [PATCH 1/3] Bugfix: Provide source language to m2m100 models (#1) --- README.md | 4 +-- experiments/metrics/benchmark_metrics.py | 12 +++++++ experiments/metrics/nmtscore_metrics.py | 20 +++++++---- src/nmtscore/models/__init__.py | 38 +++++++++++++++++--- src/nmtscore/models/m2m100.py | 8 +++++ src/nmtscore/models/prism.py | 4 +++ src/nmtscore/scorer.py | 44 ++++++++++++++++++++---- tests/test_cache.py | 4 +++ tests/test_nmt_models.py | 19 ++++++---- tests/test_readme.py | 6 ++-- 10 files changed, 130 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 8a853ab..ea83aa2 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ This library currently supports three NMT models: - [m2m100_418M](https://huggingface.co/facebook/m2m100_418M) and [m2m100_1.2B](https://huggingface.co/facebook/m2m100_1.2B) by [Fan et al. (2021)](https://www.jmlr.org/papers/volume22/20-1307/) - [Prism](https://github.com/thompsonb/prism) by [Thompson and Post (2020)](https://aclanthology.org/2020.emnlp-main.8/) -By default, the leanest model (m2m100_418M) is loaded. The main results in the paper are based on the Prism model. +By default, the leanest model (m2m100_418M) is loaded. The main results in the paper are based on the Prism model, which has some extra requirements (see "Installation"), but is recommended due to its higher accuracy. ```python scorer = NMTScorer("m2m100_418M", device=None) # default @@ -118,7 +118,7 @@ The NMT models also provide a direct interface for translating and scoring. ```python from nmtscore.models import load_translation_model -model = load_translation_model("m2m100_418M") +model = load_translation_model("prism") model.translate("de", ["This is a test."]) # ["Das ist ein Test."] diff --git a/experiments/metrics/benchmark_metrics.py b/experiments/metrics/benchmark_metrics.py index 79a65b8..81c6702 100644 --- a/experiments/metrics/benchmark_metrics.py +++ b/experiments/metrics/benchmark_metrics.py @@ -101,6 +101,8 @@ def get_paraphrase_metrics(device=None) -> List[BenchmarkMetric]: title="Translation_Cross-Likelihood", metric_names=["nmtscore-cross"], load_func=lambda a_lang, b_lang, device=device: CrossLikelihoodNMTScoreMetric( + a_lang, + b_lang, scorer=NMTScorer("prism", device=device), both_directions=True, translate_kwargs=NMT_TRANSLATE_KWARGS, @@ -166,6 +168,8 @@ def get_nlg_evaluation_metrics(device=None) -> List[BenchmarkMetric]: title="Translation_Cross-Likelihood", metric_names=["nmtscore-cross", "nmtscore-cross-hyp|ref", "nmtscore-cross-ref|hyp"], load_func=lambda a_lang, b_lang, device=device: CrossLikelihoodNMTScoreMetric( + a_lang, + b_lang, scorer=NMTScorer("prism", device=device), both_directions=True, translate_kwargs=NMT_TRANSLATE_KWARGS, @@ -208,6 +212,8 @@ def get_paraphrase_metrics_m2m100(device=None) -> List[BenchmarkMetric]: title="Translation_Cross-Likelihood (m2m100_418M)", metric_names=["nmtscore-cross"], load_func=lambda a_lang, b_lang, device=device: CrossLikelihoodNMTScoreMetric( + a_lang, + b_lang, scorer=NMTScorer("m2m100_418M", device=device), both_directions=True, translate_kwargs={**NMT_TRANSLATE_KWARGS, **{"batch_size": batch_size}}, @@ -242,6 +248,8 @@ def get_paraphrase_metrics_m2m100(device=None) -> List[BenchmarkMetric]: title="Translation_Cross-Likelihood (m2m100_1.2B)", metric_names=["nmtscore-cross"], load_func=lambda a_lang, b_lang, device=device: CrossLikelihoodNMTScoreMetric( + a_lang, + b_lang, scorer=NMTScorer("m2m100_1.2B", device=device), both_directions=True, translate_kwargs={**NMT_TRANSLATE_KWARGS, **{"batch_size": batch_size}}, @@ -285,6 +293,8 @@ def get_normalization_ablation_metrics(device=None) -> List[BenchmarkMetric]: title="Translation_Cross-Likelihood (normalized)", metric_names=["nmtscore-cross"], load_func=lambda a_lang, b_lang, device=device: CrossLikelihoodNMTScoreMetric( + a_lang, + b_lang, scorer=NMTScorer("prism", device=device), normalize=True, both_directions=True, @@ -322,6 +332,8 @@ def get_normalization_ablation_metrics(device=None) -> List[BenchmarkMetric]: title="Translation_Cross-Likelihood (unnormalized)", metric_names=["nmtscore-cross"], load_func=lambda a_lang, b_lang, device=device: CrossLikelihoodNMTScoreMetric( + a_lang, + b_lang, scorer=NMTScorer("prism", device=device), normalize=False, both_directions=True, diff --git a/experiments/metrics/nmtscore_metrics.py b/experiments/metrics/nmtscore_metrics.py index 8da60cf..a7a9bd1 100644 --- a/experiments/metrics/nmtscore_metrics.py +++ b/experiments/metrics/nmtscore_metrics.py @@ -12,11 +12,15 @@ class NMTScoreMetric(ReferenceBasedMetric): def __init__(self, + summaries_lang: Optional[str] = None, + references_lang: Optional[str] = None, scorer: Union[NMTScorer, str] = "m2m100_418M", normalize: bool = True, both_directions: bool = True, ): super().__init__() + self.summaries_lang = summaries_lang + self.references_lang = references_lang if isinstance(scorer, str): self.scorer = NMTScorer(scorer) else: @@ -108,9 +112,7 @@ def __init__(self, ): if both_directions: assert references_lang is not None - super().__init__(scorer, normalize, both_directions) - self.summaries_lang = summaries_lang - self.references_lang = references_lang + super().__init__(summaries_lang, references_lang, scorer, normalize, both_directions) self.score_kwargs = score_kwargs @property @@ -156,9 +158,7 @@ def __init__(self, ): if both_directions: assert references_lang is not None - super().__init__(scorer, normalize, both_directions) - self.summaries_lang = summaries_lang - self.references_lang = references_lang + super().__init__(summaries_lang, references_lang, scorer, normalize, both_directions) self.pivot_lang = pivot_lang self.translate_kwargs = translate_kwargs self.score_kwargs = score_kwargs @@ -200,6 +200,8 @@ class CrossLikelihoodNMTScoreMetric(NMTScoreMetric): name = "nmtscore-cross" def __init__(self, + summaries_lang: Optional[str] = None, + references_lang: Optional[str] = None, tgt_lang: str = "en", scorer: Union[NMTScorer, str] = "m2m100_418M", normalize: bool = True, @@ -207,7 +209,7 @@ def __init__(self, translate_kwargs: dict = None, score_kwargs: dict = None, ): - super().__init__(scorer, normalize, both_directions) + super().__init__(summaries_lang, references_lang, scorer, normalize, both_directions) self.tgt_lang = tgt_lang self.translate_kwargs = translate_kwargs self.score_kwargs = score_kwargs @@ -221,6 +223,8 @@ def _score_summaries_given_references(self, summaries: List[str], references: Li return self.scorer.score_cross_likelihood( summaries, references, + self.summaries_lang, + self.references_lang, tgt_lang=self.tgt_lang, normalize=self.normalize, both_directions=False, @@ -232,6 +236,8 @@ def _score_references_given_summaries(self, summaries: List[str], references: Li return self.scorer.score_cross_likelihood( references, summaries, + self.references_lang, + self.summaries_lang, tgt_lang=self.tgt_lang, normalize=self.normalize, both_directions=False, diff --git a/src/nmtscore/models/__init__.py b/src/nmtscore/models/__init__.py index 383c0b5..7ae37ec 100644 --- a/src/nmtscore/models/__init__.py +++ b/src/nmtscore/models/__init__.py @@ -1,5 +1,7 @@ import json import os +import warnings + from sqlitedict import SqliteDict from pathlib import Path from typing import List, Union, Tuple @@ -13,6 +15,7 @@ def __str__(self): def translate(self, tgt_lang: str, source_sentences: Union[str, List[str]], + src_lang: str = None, return_score: bool = False, batch_size: int = 8, use_cache: bool = False, @@ -21,6 +24,7 @@ def translate(self, """ :param tgt_lang: Language code of the target language :param source_sentences: A sentence or list of sentences + :param src_lang: Language code of the source language (not needed for some multilingual models) :param return score: If true, return a tuple where the second element is sequence-level score of the translation :param batch_size :param use_cache @@ -40,7 +44,8 @@ def translate(self, cached_translations_list = [] with self.load_cache() as cache: for source_sentence in source_sentences_list: - translation = cache.get(f"{tgt_lang}_translate{'_score' if return_score else ''}_{source_sentence}", None) + translation = cache.get(f"{(src_lang + '_') if src_lang is not None else ''}{tgt_lang}_" + f"translate{'_score' if return_score else ''}_{source_sentence}", None) cached_translations_list.append(translation) full_source_sentences_list = source_sentences_list source_sentences_list = [ @@ -50,6 +55,11 @@ def translate(self, ] self._set_tgt_lang(tgt_lang) + if self.requires_src_lang: + if src_lang is None: + warnings.warn(f"NMT model {self} requires the src language. Assuming 'en'; override with `src_lang`") + src_lang = "en" + self._set_src_lang(src_lang) translations_list = self._translate(source_sentences_list, return_score, batch_size, **kwargs) assert len(translations_list) == len(source_sentences_list) @@ -59,7 +69,9 @@ def translate(self, if cached_translation is not None: translations_list.insert(i, cached_translation) else: - cache_update[f"{tgt_lang}_translate{'_score' if return_score else ''}_{full_source_sentences_list[i]}"] = translations_list[i] + cache_update[f"{(src_lang + '_') if src_lang is not None else ''}{tgt_lang}_" \ + f"translate{'_score' if return_score else ''}_" \ + f"{full_source_sentences_list[i]}"] = translations_list[i] if cache_update: with self.load_cache() as cache: cache.update(cache_update) @@ -75,6 +87,7 @@ def score(self, tgt_lang: str, source_sentences: Union[str, List[str]], hypothesis_sentences: Union[str, List[str]], + src_lang: str = None, batch_size: int = 8, use_cache: bool = False, **kwargs, @@ -83,6 +96,7 @@ def score(self, :param tgt_lang: Language code of the target language :param source_sentences: A sentence or list of sentences :param hypothesis_sentences: A sentence or list of sentences + :param src_lang: Language code of the source language (not needed for some multilingual models) :param batch_size :param use_cache :param kwargs @@ -105,7 +119,8 @@ def score(self, cached_scores_list = [] with self.load_cache() as cache: for source_sentence, hypothesis_sentence in zip(source_sentences_list, hypothesis_sentences_list): - score = cache.get(f"{tgt_lang}_score_{source_sentence}_{hypothesis_sentence}", None) + score = cache.get(f"{(src_lang + '_') if src_lang is not None else ''}{tgt_lang}_" + f"score_{source_sentence}_{hypothesis_sentence}", None) cached_scores_list.append(score) full_source_sentences_list = source_sentences_list source_sentences_list = [ @@ -121,6 +136,8 @@ def score(self, ] self._set_tgt_lang(tgt_lang) + if self.requires_src_lang: + self._set_src_lang(src_lang) scores_list = self._score(source_sentences_list, hypothesis_sentences_list, batch_size, **kwargs) assert len(scores_list) == len(source_sentences_list) @@ -130,8 +147,9 @@ def score(self, if cached_score is not None: scores_list.insert(i, cached_score) else: - cache_update[f"{tgt_lang}_score_{full_source_sentences_list[i]}_{full_hypothesis_sentences_list[i]}"] = \ - scores_list[i] + cache_update[f"{(src_lang + '_') if src_lang is not None else ''}{tgt_lang}_" \ + f"score_{full_source_sentences_list[i]}_" \ + f"{full_hypothesis_sentences_list[i]}"] = scores_list[i] if cache_update: with self.load_cache() as cache: cache.update(cache_update) @@ -143,6 +161,16 @@ def score(self, scores = scores_list return scores + @property + def requires_src_lang(self) -> bool: + """ + Boolean indicating whether the model requires the source language to be specified + """ + raise NotImplementedError + + def _set_src_lang(self, src_lang: str): + raise NotImplementedError + def _set_tgt_lang(self, tgt_lang: str): raise NotImplementedError diff --git a/src/nmtscore/models/m2m100.py b/src/nmtscore/models/m2m100.py index a40affa..79b805e 100644 --- a/src/nmtscore/models/m2m100.py +++ b/src/nmtscore/models/m2m100.py @@ -34,6 +34,14 @@ def __init__(self, def __str__(self): return self.model_name_or_path + @property + def requires_src_lang(self) -> bool: + return True + + def _set_src_lang(self, src_lang: str): + self.src_lang = src_lang + self.tokenizer.src_lang = src_lang + def _set_tgt_lang(self, tgt_lang: str): self.tgt_lang = tgt_lang self.tokenizer.tgt_lang = tgt_lang diff --git a/src/nmtscore/models/prism.py b/src/nmtscore/models/prism.py index eafbc9d..fe875a9 100644 --- a/src/nmtscore/models/prism.py +++ b/src/nmtscore/models/prism.py @@ -49,6 +49,10 @@ def __init__(self, def __str__(self): return "prism" + @property + def requires_src_lang(self) -> bool: + return False + def _set_tgt_lang(self, tgt_lang: str): assert tgt_lang in self.supported_languages self.tgt_lang = tgt_lang diff --git a/src/nmtscore/scorer.py b/src/nmtscore/scorer.py index 7be4438..dc3ac32 100644 --- a/src/nmtscore/scorer.py +++ b/src/nmtscore/scorer.py @@ -1,4 +1,5 @@ import logging +import warnings from typing import Union, List, Optional import numpy as np @@ -68,20 +69,26 @@ def score_direct(self, """ if both_directions: assert b_lang is not None + if self.model.requires_src_lang and b_lang is None: + warnings.warn(f"NMT model {self.model} requires the src language. Assuming {a_lang}; override with `b_lang`") + b_lang = a_lang + scores = self.model.score( + src_lang=b_lang, tgt_lang=a_lang, source_sentences=b, hypothesis_sentences=a, **(score_kwargs or {}), ) if normalize: - self_scores = self.score_direct(a, a, a_lang, b_lang=None, normalize=False, both_directions=False, score_kwargs=score_kwargs) + self_scores = self.score_direct(a, a, a_lang, b_lang=a_lang, normalize=False, both_directions=False, score_kwargs=score_kwargs) scores = np.array(scores) / np.array(self_scores) if both_directions: reverse_scores = self.score_direct(b, a, b_lang, a_lang, normalize=normalize, both_directions=False, score_kwargs=score_kwargs) scores = self._average_scores(scores, reverse_scores) if print_signature: - print(self._build_version_string("direct", normalized=normalize, both_directions=both_directions)) + print(self._build_version_string("direct", a_lang=a_lang, b_lang=b_lang, + normalized=normalize, both_directions=both_directions)) return scores def score_pivot(self, @@ -112,9 +119,14 @@ def score_pivot(self, """ if both_directions: assert b_lang is not None + if self.model.requires_src_lang and b_lang is None: + warnings.warn(f"NMT model {self.model} requires the src language. Assuming {a_lang}; override with `b_lang`") + b_lang = a_lang + if isinstance(a, list) and len(a) >= 10: logging.info(f"Translating to pivot language {pivot_lang} ...") translations = self.model.translate( + src_lang=b_lang, tgt_lang=pivot_lang, source_sentences=b, **(translate_kwargs or {}), @@ -122,13 +134,14 @@ def score_pivot(self, if isinstance(a, list) and len(a) >= 10: logging.info(f"Scoring sentences ...") scores = self.model.score( + src_lang=pivot_lang, tgt_lang=a_lang, source_sentences=translations, hypothesis_sentences=a, **(score_kwargs or {}), ) if normalize: - self_scores = self.score_pivot(a, a, a_lang, b_lang=None, pivot_lang=pivot_lang, normalize=False, both_directions=False, + self_scores = self.score_pivot(a, a, a_lang, b_lang=a_lang, pivot_lang=pivot_lang, normalize=False, both_directions=False, translate_kwargs=translate_kwargs, score_kwargs=score_kwargs) scores = np.array(scores) / np.array(self_scores) if both_directions: @@ -136,12 +149,15 @@ def score_pivot(self, translate_kwargs=translate_kwargs, score_kwargs=score_kwargs) scores = self._average_scores(scores, reverse_scores) if print_signature: - print(self._build_version_string("pivot", normalized=normalize, both_directions=both_directions, pivot_lang=pivot_lang)) + print(self._build_version_string("pivot", normalized=normalize, both_directions=both_directions, + pivot_lang=pivot_lang, a_lang=a_lang, b_lang=b_lang)) return scores def score_cross_likelihood(self, a: Union[str, List[str]], b: Union[str, List[str]], + a_lang: Optional[str] = None, + b_lang: Optional[str] = None, tgt_lang: str = "en", normalize: bool = True, both_directions: bool = True, @@ -156,6 +172,8 @@ def score_cross_likelihood(self, :param b: A sentence or list of sentences. If :param: both_directions is False this is the sentence that is translated :param tgt_lang: The language code of the target language (default: "en") + :param a_lang: The language code of A (default: None). Not needed for some multilingual models + :param b_lang: The language code of B (default: a_lang). Not needed for some multilingual models :param normalize: Apply a normalization to the similarity score (default: True) :param both_directions: Return the average of score(a, b) and score(b, a) (default: True) :param print_signature: Print a version signature for the metric (default: False) @@ -163,9 +181,15 @@ def score_cross_likelihood(self, :param score_kwargs :return: A float or list of floats """ + if self.model.requires_src_lang and (a_lang is None or b_lang is None): + warnings.warn(f"NMT model {self.model} requires the input languages. Assuming 'en' for unspecified languages; " + f"override with `a_lang` and `b_lang`") + a_lang = a_lang or "en" + b_lang = b_lang or a_lang if isinstance(a, list) and len(a) >= 10: logging.info(f"Translating to target language {tgt_lang} ...") translations_scores = self.model.translate( + src_lang=b_lang, tgt_lang=tgt_lang, source_sentences=b, return_score=True, @@ -175,6 +199,7 @@ def score_cross_likelihood(self, if isinstance(a, list) and len(a) >= 10: logging.info(f"Scoring sentences ...") scores = self.model.score( + src_lang=a_lang, tgt_lang=tgt_lang, source_sentences=a, hypothesis_sentences=translations, @@ -187,6 +212,7 @@ def score_cross_likelihood(self, no_scores_yet = translation_scores[0] is None if isinstance(translation_scores, list) else translation_scores is None if no_scores_yet: translation_scores = self.model.score( + src_lang=b_lang, tgt_lang=tgt_lang, source_sentences=b, hypothesis_sentences=translations, @@ -194,11 +220,13 @@ def score_cross_likelihood(self, ) scores = np.array(scores) / np.array(translation_scores) if both_directions: - reverse_scores = self.score_cross_likelihood(b, a, tgt_lang, normalize=normalize, both_directions=False, + reverse_scores = self.score_cross_likelihood(b, a, tgt_lang, a_lang=b_lang, b_lang=a_lang, + normalize=normalize, both_directions=False, translate_kwargs=translate_kwargs, score_kwargs=score_kwargs) scores = self._average_scores(scores, reverse_scores) if print_signature: - print(self._build_version_string("cross", normalized=normalize, both_directions=both_directions, tgt_lang=tgt_lang)) + print(self._build_version_string("cross", normalized=normalize, both_directions=both_directions, + tgt_lang=tgt_lang, a_lang=a_lang, b_lang=b_lang)) return scores def _average_scores(self, @@ -214,12 +242,16 @@ def _build_version_string(self, both_directions: bool, tgt_lang: str = None, pivot_lang: str = None, + a_lang: str = None, + b_lang: str = None, ) -> str: import nmtscore import transformers return f"NMTScore-{type}|" \ f"{f'tgt-lang:{tgt_lang}|' if tgt_lang is not None else ''}" \ f"{f'pivot-lang:{pivot_lang}|' if pivot_lang is not None else ''}" \ + f"{f'a-lang:{a_lang}|' if a_lang is not None else ''}" \ + f"{f'b-lang:{b_lang}|' if b_lang is not None else ''}" \ f"model:{self.model}|" \ f"{'normalized' if normalized else 'unnormalized'}|" \ f"{'both-directions' if both_directions else 'single-direction'}|" \ diff --git a/tests/test_cache.py b/tests/test_cache.py index 2b7388c..dc3fb3f 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -14,6 +14,10 @@ class CopyTranslationModel(TranslationModel): def __str__(self): return "copy-translation-model" + @property + def requires_src_lang(self) -> bool: + return False + def _set_tgt_lang(self, tgt_lang: str): pass diff --git a/tests/test_nmt_models.py b/tests/test_nmt_models.py index 7a16783..9d47753 100644 --- a/tests/test_nmt_models.py +++ b/tests/test_nmt_models.py @@ -12,13 +12,13 @@ def setUpClass(cls) -> None: raise NotImplementedError def test_translate(self): - self.assertIn(self.model.translate("de", "This is a test."), { + self.assertIn(self.model.translate("de", "This is a test.", src_lang="en"), { "Dies ist ein Test.", "Das ist ein Test.", }) def test_translate_score(self): - translation, score = self.model.translate("de", "This is a test.", return_score=True) + translation, score = self.model.translate("de", "This is a test.", return_score=True, src_lang="en") self.assertIn(translation, { "Dies ist ein Test.", "Das ist ein Test.", @@ -27,10 +27,10 @@ def test_translate_score(self): return self.assertGreaterEqual(score, 0) self.assertLessEqual(score, 1) - self.assertAlmostEqual(score, self.model.score("de", "This is a test.", translation), places=5) + self.assertAlmostEqual(score, self.model.score("de", "This is a test.", translation, src_lang="en"), places=5) def test_translate_batched(self): - translations = self.model.translate("de", 8 * ["This is a test."]) + translations = self.model.translate("de", 8 * ["This is a test."], src_lang="en") self.assertEqual(8, len(translations)) self.assertEqual(1, len(set(translations))) self.assertIn(translations[0], { @@ -41,6 +41,7 @@ def test_translate_batched(self): def test_score(self): scores = self.model.score( "de", + src_lang = "en", source_sentences=(2 * ["This is a test."]), hypothesis_sentences=(["Dies ist ein Test.", "Diese Übersetzung ist komplett falsch."]), ) @@ -48,12 +49,14 @@ def test_score(self): self.assertIsInstance(scores[1], float) scores = self.model.score( "de", + src_lang = "en", source_sentences=(2 * ["This is a test."]), hypothesis_sentences=(["Diese Übersetzung ist komplett falsch.", "Dies ist ein Test."]), ) self.assertLess(scores[0], scores[1]) scores = self.model.score( "de", + src_lang = "en", source_sentences=(2 * ["This is a test."]), hypothesis_sentences=(2 * ["Dies ist ein Test."]), ) @@ -62,6 +65,7 @@ def test_score(self): def test_score_batched(self): scores = self.model.score( "de", + src_lang = "en", source_sentences=(4 * ["This is a test."]), hypothesis_sentences=(["Diese Übersetzung ist komplett falsch", "Dies ist ein Test.", "Dies ist ein Test.", "Dies ist ein Test."]), batch_size=2, @@ -72,6 +76,7 @@ def test_score_batched(self): scores = self.model.score( "de", + src_lang = "en", source_sentences=(["This is a test.", "A translation that is completely wrong.", "This is a test.", "This is a test."]), hypothesis_sentences=(4 * ["Dies ist ein Test."]), batch_size=2, @@ -82,6 +87,7 @@ def test_score_batched(self): scores = self.model.score( "de", + src_lang = "en", source_sentences=(4 * ["This is a test."]), hypothesis_sentences=(["Dies ist ein Test.", "Dies ist ein Test.", ".", "Dies ist ein Test."]), batch_size=2, @@ -92,6 +98,7 @@ def test_score_batched(self): scores = self.model.score( "de", + src_lang = "en", source_sentences=(["This is a test.", "This is a test.", "This is a test.", "A translation that is completely wrong."]), hypothesis_sentences=(4 * ["Dies ist ein Test."]), batch_size=2, @@ -101,10 +108,10 @@ def test_score_batched(self): self.assertLess(scores[3], scores[0]) def test_translate_long_input(self): - self.model.translate("de", 100 * "This is a test. ") + self.model.translate("de", 100 * "This is a test. ", src_lang="en") def test_score_long_input(self): - self.model.score("de", 100 * "This is a test. ", 100 * "Dies ist ein Test. ") + self.model.score("de", 100 * "This is a test. ", 100 * "Dies ist ein Test. ", src_lang="en") class SmallM2M100TestCase(NMTModelTestCase): diff --git a/tests/test_readme.py b/tests/test_readme.py index cfcba1e..6441a6e 100644 --- a/tests/test_readme.py +++ b/tests/test_readme.py @@ -78,11 +78,11 @@ def test_version_signature(self, mock_stdout): a = "This is a sentence." b = "This is another sentence." score = scorer.score(a, b, print_signature=True) - self.assertIn("NMTScore-cross|tgt-lang:en|model:facebook/m2m100_418M|normalized|both-directions", mock_stdout.getvalue()) + self.assertIn("NMTScore-cross|tgt-lang:en|a-lang:en|b-lang:en|model:facebook/m2m100_418M|normalized|both-directions", mock_stdout.getvalue()) def test_nmt_models(self): model = load_translation_model("m2m100_418M") - translations = model.translate("de", ["This is a test."]) + translations = model.translate("de", ["This is a test."], src_lang="en") self.assertEqual(["Das ist ein Test."], translations) - scores = model.score("de", ["This is a test."], ["Das ist ein Test."]) + scores = model.score("de", ["This is a test."], ["Das ist ein Test."], src_lang="en") self.assertAlmostEqual(0.5148844122886658, scores[0], places=4) From 692d7bd0ed4fdab6ffcce334dc08d9c43fb51806 Mon Sep 17 00:00:00 2001 From: Jannis Vamvas Date: Fri, 14 Oct 2022 20:36:26 +0200 Subject: [PATCH 2/3] Pin fairseq version --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 1cf3800..cee296e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,7 +27,7 @@ install_requires = [options.extras_require] prism = - fairseq + fairseq<=0.10.0 [options.packages.find] where = src From 3df06b8d43972ce4f4b0943a7cea669f949f3226 Mon Sep 17 00:00:00 2001 From: Jannis Vamvas Date: Tue, 18 Oct 2022 06:13:29 +0200 Subject: [PATCH 3/3] Fix argument order --- src/nmtscore/scorer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/nmtscore/scorer.py b/src/nmtscore/scorer.py index dc3ac32..8939aae 100644 --- a/src/nmtscore/scorer.py +++ b/src/nmtscore/scorer.py @@ -81,7 +81,7 @@ def score_direct(self, **(score_kwargs or {}), ) if normalize: - self_scores = self.score_direct(a, a, a_lang, b_lang=a_lang, normalize=False, both_directions=False, score_kwargs=score_kwargs) + self_scores = self.score_direct(a, a, a_lang, a_lang, normalize=False, both_directions=False, score_kwargs=score_kwargs) scores = np.array(scores) / np.array(self_scores) if both_directions: reverse_scores = self.score_direct(b, a, b_lang, a_lang, normalize=normalize, both_directions=False, score_kwargs=score_kwargs) @@ -141,7 +141,7 @@ def score_pivot(self, **(score_kwargs or {}), ) if normalize: - self_scores = self.score_pivot(a, a, a_lang, b_lang=a_lang, pivot_lang=pivot_lang, normalize=False, both_directions=False, + self_scores = self.score_pivot(a, a, a_lang, a_lang, pivot_lang=pivot_lang, normalize=False, both_directions=False, translate_kwargs=translate_kwargs, score_kwargs=score_kwargs) scores = np.array(scores) / np.array(self_scores) if both_directions: @@ -220,7 +220,7 @@ def score_cross_likelihood(self, ) scores = np.array(scores) / np.array(translation_scores) if both_directions: - reverse_scores = self.score_cross_likelihood(b, a, tgt_lang, a_lang=b_lang, b_lang=a_lang, + reverse_scores = self.score_cross_likelihood(b, a, b_lang, a_lang, tgt_lang=tgt_lang, normalize=normalize, both_directions=False, translate_kwargs=translate_kwargs, score_kwargs=score_kwargs) scores = self._average_scores(scores, reverse_scores)