diff --git a/README.md b/README.md
index 4e99b71..2b1166c 100644
--- a/README.md
+++ b/README.md
@@ -87,6 +87,7 @@ scorer = NMTScorer("small100", device="cuda:0")  # Enable faster inference on GP
 scorer = NMTScorer("m2m100_418M", device="cuda:0")
 scorer = NMTScorer("m2m100_1.2B", device="cuda:0")
 scorer = NMTScorer("prism", device="cuda:0")
+scorer = NMTScorer("nllb-200-distilled-600M", device="cuda:0")  # This model uses BCP-47 language codes
 ```
 
 **Which model should I choose?**
@@ -160,6 +161,9 @@ See [experiments/README.md](experiments/README.md)
 
 ## Changelog
 
+- v0.3.4 (to be released)
+  - Include NLLB models ([Costa-jussà et al., 2022](https://arxiv.org/abs/2207.04672)): [`nllb-200-1.3B`](https://huggingface.co/facebook/nllb-200-1.3B), [`nllb-200-3.3B`](https://huggingface.co/facebook/nllb-200-3.3B), [`nllb-200-distilled-600M`](https://huggingface.co/facebook/nllb-200-distilled-600M), [`nllb-200-distilled-1.3B`](https://huggingface.co/facebook/nllb-200-distilled-1.3B). Note that the models use [BCP-47 language codes](https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200).
+
 - v0.3.3
   - Update minimum required Python version to 3.8
   - Require transformers<4.34 to ensure compatibility for `small100` model
diff --git a/src/nmtscore/models/__init__.py b/src/nmtscore/models/__init__.py
index c148f99..0d2799e 100644
--- a/src/nmtscore/models/__init__.py
+++ b/src/nmtscore/models/__init__.py
@@ -183,6 +183,9 @@ def _set_src_lang(self, src_lang: str):
     def _set_tgt_lang(self, tgt_lang: str):
         raise NotImplementedError
 
+    def _validate_lang_code(self, lang_code: str):
+        pass
+
     def _translate(self,
                    source_sentences: List[str],
                    return_score: bool = False,
@@ -237,6 +240,18 @@ def load_translation_model(name: str, **kwargs) -> TranslationModel:
                               "`pip install nmtscore[prism]`")
         from nmtscore.models.prism import PrismModel
         translation_model = PrismModel(**kwargs)
+    elif name == "nllb-200-1.3B":
+        from nmtscore.models.nllb import NLLBModel
+        translation_model = NLLBModel(model_name_or_path="facebook/nllb-200-1.3B", **kwargs)
+    elif name == "nllb-200-3.3B":
+        from nmtscore.models.nllb import NLLBModel
+        translation_model = NLLBModel(model_name_or_path="facebook/nllb-200-3.3B", **kwargs)
+    elif name == "nllb-200-distilled-600M":
+        from nmtscore.models.nllb import NLLBModel
+        translation_model = NLLBModel(model_name_or_path="facebook/nllb-200-distilled-600M", **kwargs)
+    elif name == "nllb-200-distilled-1.3B":
+        from nmtscore.models.nllb import NLLBModel
+        translation_model = NLLBModel(model_name_or_path="facebook/nllb-200-distilled-1.3B", **kwargs)
     else:
         raise NotImplementedError
     return translation_model
diff --git a/src/nmtscore/models/m2m100.py b/src/nmtscore/models/m2m100.py
index 2de1faa..0b633ea 100644
--- a/src/nmtscore/models/m2m100.py
+++ b/src/nmtscore/models/m2m100.py
@@ -3,7 +3,7 @@
 
 import torch
 from tqdm import tqdm
-from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, TranslationPipeline
+from transformers import M2M100ForConditionalGeneration, AutoTokenizer, TranslationPipeline
 from transformers.file_utils import PaddingStrategy
 from transformers.models.m2m_100.modeling_m2m_100 import shift_tokens_right
 
@@ -40,7 +40,7 @@ def __str__(self):
         return self.model_name_or_path
 
     def _load_tokenizer(self):
-        return M2M100Tokenizer.from_pretrained(self.model_name_or_path)
+        return AutoTokenizer.from_pretrained(self.model_name_or_path)
 
     def _load_model(self):
         return M2M100ForConditionalGeneration.from_pretrained(self.model_name_or_path)
@@ -50,13 +50,21 @@ def requires_src_lang(self) -> bool:
         return True
 
     def _set_src_lang(self, src_lang: str):
+        self._validate_lang_code(src_lang)
         self.src_lang = src_lang
         self.tokenizer.src_lang = src_lang
 
     def _set_tgt_lang(self, tgt_lang: str):
+        self._validate_lang_code(tgt_lang)
         self.tgt_lang = tgt_lang
         self.tokenizer.tgt_lang = tgt_lang
 
+    def _validate_lang_code(self, lang_code: str):
+        from transformers.models.m2m_100.tokenization_m2m_100 import FAIRSEQ_LANGUAGE_CODES
+        if lang_code not in FAIRSEQ_LANGUAGE_CODES["m2m100"]:
+            raise ValueError(f"{lang_code} is not a valid language code for {self}. "
+                             f"Valid language codes are: {FAIRSEQ_LANGUAGE_CODES['m2m100']}")
+
     @torch.no_grad()
     def _translate(self,
                    source_sentences: List[str],
diff --git a/src/nmtscore/models/nllb.py b/src/nmtscore/models/nllb.py
new file mode 100644
index 0000000..452c946
--- /dev/null
+++ b/src/nmtscore/models/nllb.py
@@ -0,0 +1,11 @@
+
+from nmtscore.models.m2m100 import M2M100Model
+
+
+class NLLBModel(M2M100Model):
+
+    def _validate_lang_code(self, lang_code: str):
+        from transformers.models.nllb.tokenization_nllb import FAIRSEQ_LANGUAGE_CODES
+        if lang_code not in FAIRSEQ_LANGUAGE_CODES:
+            raise ValueError(f"{lang_code} is not a valid language code for {self}. "
+                             f"Valid language codes are: {FAIRSEQ_LANGUAGE_CODES}")
diff --git a/tests/test_nmt_models.py b/tests/test_nmt_models.py
index d21baf7..ceae306 100644
--- a/tests/test_nmt_models.py
+++ b/tests/test_nmt_models.py
@@ -10,15 +10,23 @@ class NMTModelTestCase(TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         raise NotImplementedError
+
+    @property
+    def lang_code_de(self):
+        return "de"
+
+    @property
+    def lang_code_en(self):
+        return "en"
 
     def test_translate(self):
-        self.assertIn(self.model.translate("de", "This is a test.", src_lang="en"), {
+        self.assertIn(self.model.translate(self.lang_code_de, "This is a test.", src_lang=self.lang_code_en), {
             "Dies ist ein Test.",
             "Das ist ein Test.",
         })
 
     def test_translate_score(self):
-        translation, score = self.model.translate("de", "This is a test.", return_score=True, src_lang="en")
+        translation, score = self.model.translate(self.lang_code_de, "This is a test.", return_score=True, src_lang=self.lang_code_en)
         self.assertIn(translation, {
             "Dies ist ein Test.",
             "Das ist ein Test.",
@@ -27,10 +35,10 @@ def test_translate_score(self):
             return
         self.assertGreaterEqual(score, 0)
         self.assertLessEqual(score, 1)
-        self.assertAlmostEqual(score, self.model.score("de", "This is a test.", translation, src_lang="en"), places=5)
+        self.assertAlmostEqual(score, self.model.score(self.lang_code_de, "This is a test.", translation, src_lang=self.lang_code_en), places=5)
 
     def test_translate_batched(self):
-        translations = self.model.translate("de", 8 * ["This is a test."], src_lang="en")
+        translations = self.model.translate(self.lang_code_de, 8 * ["This is a test."], src_lang=self.lang_code_en)
         self.assertEqual(8, len(translations))
         self.assertEqual(1, len(set(translations)))
         self.assertIn(translations[0], {
@@ -40,23 +48,23 @@ def test_score(self):
         scores = self.model.score(
-            "de",
-            src_lang = "en",
+            self.lang_code_de,
+            src_lang = self.lang_code_en,
             source_sentences=(2 * ["This is a test."]),
             hypothesis_sentences=(["Dies ist ein Test.", "Diese Übersetzung ist komplett falsch."]),
         )
         self.assertIsInstance(scores[0], float)
         self.assertIsInstance(scores[1], float)
 
         scores = self.model.score(
-            "de",
-            src_lang = "en",
+            self.lang_code_de,
+            src_lang = self.lang_code_en,
             source_sentences=(2 * ["This is a test."]),
             hypothesis_sentences=(["Diese Übersetzung ist komplett falsch.", "Dies ist ein Test."]),
         )
         self.assertLess(scores[0], scores[1])
 
         scores = self.model.score(
-            "de",
-            src_lang = "en",
+            self.lang_code_de,
+            src_lang = self.lang_code_en,
             source_sentences=(2 * ["This is a test."]),
             hypothesis_sentences=(2 * ["Dies ist ein Test."]),
         )
@@ -64,8 +72,8 @@ def test_score(self):
 
     def test_score_batched(self):
         scores = self.model.score(
-            "de",
-            src_lang = "en",
+            self.lang_code_de,
+            src_lang = self.lang_code_en,
             source_sentences=(4 * ["This is a test."]),
             hypothesis_sentences=(["Diese Übersetzung ist komplett falsch", "Dies ist ein Test.", "Dies ist ein Test.", "Dies ist ein Test."]),
             batch_size=2,
@@ -75,8 +83,8 @@
         self.assertAlmostEqual(scores[3], scores[1], places=4)
 
         scores = self.model.score(
-            "de",
-            src_lang = "en",
+            self.lang_code_de,
+            src_lang = self.lang_code_en,
             source_sentences=(["This is a test.", "A translation that is completely wrong.", "This is a test.", "This is a test."]),
             hypothesis_sentences=(4 * ["Dies ist ein Test."]),
             batch_size=2,
@@ -86,8 +94,8 @@
         self.assertAlmostEqual(scores[3], scores[0], places=4)
 
         scores = self.model.score(
-            "de",
-            src_lang = "en",
+            self.lang_code_de,
+            src_lang = self.lang_code_en,
             source_sentences=(4 * ["This is a test."]),
             hypothesis_sentences=(["Dies ist ein Test.", "Dies ist ein Test.", ".", "Dies ist ein Test."]),
             batch_size=2,
@@ -97,8 +105,8 @@
         self.assertAlmostEqual(scores[3], scores[0], places=4)
 
         scores = self.model.score(
-            "de",
-            src_lang = "en",
+            self.lang_code_de,
+            src_lang = self.lang_code_en,
             source_sentences=(["This is a test.", "This is a test.", "This is a test.", "A translation that is completely wrong."]),
             hypothesis_sentences=(4 * ["Dies ist ein Test."]),
             batch_size=2,
@@ -108,10 +116,10 @@
         self.assertAlmostEqual(scores[3], scores[0], places=4)
         self.assertLess(scores[3], scores[0])
 
     def test_translate_long_input(self):
-        self.model.translate("de", 100 * "This is a test. ", src_lang="en")
+        self.model.translate(self.lang_code_de, 100 * "This is a test. ", src_lang=self.lang_code_en)
 
     def test_score_long_input(self):
-        self.model.score("de", 100 * "This is a test. ", 100 * "Dies ist ein Test. ", src_lang="en")
+        self.model.score(self.lang_code_de, 100 * "This is a test. ", 100 * "Dies ist ein Test. ", src_lang=self.lang_code_en)
 
 @unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Slow")
@@ -145,5 +153,37 @@ def setUpClass(cls) -> None:
         cls.model = load_translation_model("prism")
 
 
+@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Slow")
+class SmallDistilledNLLB200TestCase(NMTModelTestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.model = load_translation_model("nllb-200-distilled-600M")
+
+    @property
+    def lang_code_de(self):
+        return "deu_Latn"
+
+    @property
+    def lang_code_en(self):
+        return "eng_Latn"
+
+
+@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Slow")
+class SmallNLLB200TestCase(NMTModelTestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.model = load_translation_model("nllb-200-1.3B")
+
+    @property
+    def lang_code_de(self):
+        return "deu_Latn"
+
+    @property
+    def lang_code_en(self):
+        return "eng_Latn"
+
+
 # https://stackoverflow.com/a/43353680/3902795
 del NMTModelTestCase
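
A minimal usage sketch (not part of the patch) of the new NLLB support, following the `load_translation_model` dispatcher and the `translate`/`score` signatures exercised in the tests above. It assumes the `facebook/nllb-200-distilled-600M` checkpoint can be downloaded and that FLORES-200 (BCP-47-style) codes such as `eng_Latn`/`deu_Latn` are passed instead of `en`/`de`:

```python
# Hypothetical usage sketch; mirrors the API exercised in tests/test_nmt_models.py.
from nmtscore.models import load_translation_model

# Resolves to facebook/nllb-200-distilled-600M via the new dispatcher entry.
model = load_translation_model("nllb-200-distilled-600M")

# NLLB expects FLORES-200 codes ("eng_Latn", "deu_Latn") rather than "en"/"de".
translation = model.translate("deu_Latn", "This is a test.", src_lang="eng_Latn")

# Score a hypothesis translation against the source sentence.
score = model.score("deu_Latn", "This is a test.", "Dies ist ein Test.", src_lang="eng_Latn")
print(translation, score)
```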
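
The added `_validate_lang_code` hooks differ only in where the valid codes come from: the M2M-100 variant indexes the `FAIRSEQ_LANGUAGE_CODES` dict with the `"m2m100"` key, while the NLLB variant checks membership in the flat `FAIRSEQ_LANGUAGE_CODES` list from `transformers`. A small illustration, using a hypothetical helper name (`is_valid_nllb_code` is not part of the library):

```python
# Illustration of the membership test performed by NLLBModel._validate_lang_code;
# invalid codes make the model raise a ValueError listing the valid codes.
from transformers.models.nllb.tokenization_nllb import FAIRSEQ_LANGUAGE_CODES


def is_valid_nllb_code(lang_code: str) -> bool:
    # Hypothetical helper: "deu_Latn" -> True, "de" -> False
    return lang_code in FAIRSEQ_LANGUAGE_CODES


print(is_valid_nllb_code("deu_Latn"))  # True
print(is_valid_nllb_code("de"))        # False
```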