Add NLLB models (#9)
jvamvas authored Dec 11, 2023
1 parent 212d526 commit 4bfba48
Showing 5 changed files with 100 additions and 22 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -87,6 +87,7 @@ scorer = NMTScorer("small100", device="cuda:0")  # Enable faster inference on GPU
scorer = NMTScorer("m2m100_418M", device="cuda:0")
scorer = NMTScorer("m2m100_1.2B", device="cuda:0")
scorer = NMTScorer("prism", device="cuda:0")
scorer = NMTScorer("nllb-200-distilled-600M", device="cuda:0") # This model uses BCP-47 language codes
```

**Which model should I choose?**
@@ -160,6 +161,9 @@ See [experiments/README.md](experiments/README.md)

## Changelog

- v0.3.4 (to be released)
- Include NLLB models ([Costa-jussà et al., 2022](https://arxiv.org/abs/2207.04672)): [`nllb-200-1.3B`](https://huggingface.co/facebook/nllb-200-1.3B), [`nllb-200-3.3B`](https://huggingface.co/facebook/nllb-200-3.3B), [`nllb-200-distilled-600M`](https://huggingface.co/facebook/nllb-200-distilled-600M), [`nllb-200-distilled-1.3B`](https://huggingface.co/facebook/nllb-200-distilled-1.3B). Note that the models use [BCP-47 language codes](https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200).

- v0.3.3
- Update minimum required Python version to 3.8
- Require transformers<4.34 to ensure compatibility for `small100` model
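For readers skimming the diff, here is a minimal usage sketch of the new README entry. The constructor line is taken from the README snippet above; the `score_direct` call and its `a_lang`/`b_lang` keywords follow the scoring API documented elsewhere in the README and are an assumption here, not part of this commit:

```python
from nmtscore import NMTScorer

scorer = NMTScorer("nllb-200-distilled-600M", device="cuda:0")

# NLLB expects BCP-47 codes such as "eng_Latn"/"deu_Latn" rather than "en"/"de".
score = scorer.score_direct(
    "This is a test.",
    "Dies ist ein Test.",
    a_lang="eng_Latn",
    b_lang="deu_Latn",
)
print(score)
```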
15 changes: 15 additions & 0 deletions src/nmtscore/models/__init__.py
@@ -183,6 +183,9 @@ def _set_src_lang(self, src_lang: str):
    def _set_tgt_lang(self, tgt_lang: str):
        raise NotImplementedError

    def _validate_lang_code(self, lang_code: str):
        pass

    def _translate(self,
                   source_sentences: List[str],
                   return_score: bool = False,
@@ -237,6 +240,18 @@ def load_translation_model(name: str, **kwargs) -> TranslationModel:
                              "`pip install nmtscore[prism]`")
        from nmtscore.models.prism import PrismModel
        translation_model = PrismModel(**kwargs)
    elif name == "nllb-200-1.3B":
        from nmtscore.models.nllb import NLLBModel
        translation_model = NLLBModel(model_name_or_path="facebook/nllb-200-1.3B", **kwargs)
    elif name == "nllb-200-3.3B":
        from nmtscore.models.nllb import NLLBModel
        translation_model = NLLBModel(model_name_or_path="facebook/nllb-200-3.3B", **kwargs)
    elif name == "nllb-200-distilled-600M":
        from nmtscore.models.nllb import NLLBModel
        translation_model = NLLBModel(model_name_or_path="facebook/nllb-200-distilled-600M", **kwargs)
    elif name == "nllb-200-distilled-1.3B":
        from nmtscore.models.nllb import NLLBModel
        translation_model = NLLBModel(model_name_or_path="facebook/nllb-200-distilled-1.3B", **kwargs)
    else:
        raise NotImplementedError
    return translation_model
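A quick sketch (not part of the commit) of how the new branches are exercised. The `translate` signature — target language first, `src_lang` as a keyword — matches the test cases further down in this diff; the weights are fetched from the Hugging Face checkpoints referenced above on first use:

```python
from nmtscore.models import load_translation_model

model = load_translation_model("nllb-200-distilled-600M")

# NLLB requires BCP-47 codes, e.g. "deu_Latn"/"eng_Latn" instead of "de"/"en".
translation = model.translate("deu_Latn", "This is a test.", src_lang="eng_Latn")
print(translation)
```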
12 changes: 10 additions & 2 deletions src/nmtscore/models/m2m100.py
@@ -3,7 +3,7 @@

import torch
from tqdm import tqdm
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, TranslationPipeline
from transformers import M2M100ForConditionalGeneration, AutoTokenizer, TranslationPipeline
from transformers.file_utils import PaddingStrategy
from transformers.models.m2m_100.modeling_m2m_100 import shift_tokens_right

@@ -40,7 +40,7 @@ def __str__(self):
        return self.model_name_or_path

    def _load_tokenizer(self):
        return M2M100Tokenizer.from_pretrained(self.model_name_or_path)
        return AutoTokenizer.from_pretrained(self.model_name_or_path)

    def _load_model(self):
        return M2M100ForConditionalGeneration.from_pretrained(self.model_name_or_path)
@@ -50,13 +50,21 @@ def requires_src_lang(self) -> bool:
        return True

    def _set_src_lang(self, src_lang: str):
        self._validate_lang_code(src_lang)
        self.src_lang = src_lang
        self.tokenizer.src_lang = src_lang

    def _set_tgt_lang(self, tgt_lang: str):
        self._validate_lang_code(tgt_lang)
        self.tgt_lang = tgt_lang
        self.tokenizer.tgt_lang = tgt_lang

    def _validate_lang_code(self, lang_code: str):
        from transformers.models.m2m_100.tokenization_m2m_100 import FAIRSEQ_LANGUAGE_CODES
        if lang_code not in FAIRSEQ_LANGUAGE_CODES["m2m100"]:
            raise ValueError(f"{lang_code} is not a valid language code for {self}. "
                             f"Valid language codes are: {FAIRSEQ_LANGUAGE_CODES['m2m100']}")

    @torch.no_grad()
    def _translate(self,
                   source_sentences: List[str],
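The switch from `M2M100Tokenizer` to `AutoTokenizer` is what lets `NLLBModel` inherit `_load_tokenizer` unchanged: `AutoTokenizer` resolves each checkpoint to its own tokenizer class. A small sanity-check sketch, not part of the diff; the printed class names are indicative only:

```python
from transformers import AutoTokenizer

m2m_tok = AutoTokenizer.from_pretrained("facebook/m2m100_418M")
nllb_tok = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

print(type(m2m_tok).__name__)   # e.g. M2M100Tokenizer
print(type(nllb_tok).__name__)  # e.g. NllbTokenizerFast
```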
11 changes: 11 additions & 0 deletions src/nmtscore/models/nllb.py
@@ -0,0 +1,11 @@

from nmtscore.models.m2m100 import M2M100Model


class NLLBModel(M2M100Model):

    def _validate_lang_code(self, lang_code: str):
        from transformers.models.nllb.tokenization_nllb import FAIRSEQ_LANGUAGE_CODES
        if lang_code not in FAIRSEQ_LANGUAGE_CODES:
            raise ValueError(f"{lang_code} is not a valid language code for {self}. "
                             f"Valid language codes are: {FAIRSEQ_LANGUAGE_CODES}")
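Because `M2M100Model._set_src_lang`/`_set_tgt_lang` call `_validate_lang_code` (see the m2m100.py hunk above), this override makes NLLB models reject plain ISO 639-1 codes early instead of failing inside the tokenizer. A sketch of the expected behaviour, assuming the public `translate` method routes through those setters as the tests below do:

```python
from nmtscore.models import load_translation_model

model = load_translation_model("nllb-200-distilled-600M")

# BCP-47 codes pass validation.
model.translate("deu_Latn", "This is a test.", src_lang="eng_Latn")

# Plain ISO 639-1 codes are rejected with a descriptive error.
try:
    model.translate("de", "This is a test.", src_lang="en")
except ValueError as e:
    print(e)  # e.g. "... is not a valid language code for ..."
```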
80 changes: 60 additions & 20 deletions tests/test_nmt_models.py
@@ -10,15 +10,23 @@ class NMTModelTestCase(TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        raise NotImplementedError

    @property
    def lang_code_de(self):
        return "de"

    @property
    def lang_code_en(self):
        return "en"

    def test_translate(self):
        self.assertIn(self.model.translate("de", "This is a test.", src_lang="en"), {
        self.assertIn(self.model.translate(self.lang_code_de, "This is a test.", src_lang=self.lang_code_en), {
            "Dies ist ein Test.",
            "Das ist ein Test.",
        })

    def test_translate_score(self):
        translation, score = self.model.translate("de", "This is a test.", return_score=True, src_lang="en")
        translation, score = self.model.translate(self.lang_code_de, "This is a test.", return_score=True, src_lang=self.lang_code_en)
        self.assertIn(translation, {
            "Dies ist ein Test.",
            "Das ist ein Test.",
@@ -27,10 +35,10 @@ def test_translate_score(self):
            return
        self.assertGreaterEqual(score, 0)
        self.assertLessEqual(score, 1)
        self.assertAlmostEqual(score, self.model.score("de", "This is a test.", translation, src_lang="en"), places=5)
        self.assertAlmostEqual(score, self.model.score(self.lang_code_de, "This is a test.", translation, src_lang=self.lang_code_en), places=5)

    def test_translate_batched(self):
        translations = self.model.translate("de", 8 * ["This is a test."], src_lang="en")
        translations = self.model.translate(self.lang_code_de, 8 * ["This is a test."], src_lang=self.lang_code_en)
        self.assertEqual(8, len(translations))
        self.assertEqual(1, len(set(translations)))
        self.assertIn(translations[0], {
@@ -40,32 +48,32 @@ def test_translate_batched(self):

    def test_score(self):
        scores = self.model.score(
            "de",
            src_lang = "en",
            self.lang_code_de,
            src_lang = self.lang_code_en,
            source_sentences=(2 * ["This is a test."]),
            hypothesis_sentences=(["Dies ist ein Test.", "Diese Übersetzung ist komplett falsch."]),
        )
        self.assertIsInstance(scores[0], float)
        self.assertIsInstance(scores[1], float)
        scores = self.model.score(
            "de",
            src_lang = "en",
            self.lang_code_de,
            src_lang = self.lang_code_en,
            source_sentences=(2 * ["This is a test."]),
            hypothesis_sentences=(["Diese Übersetzung ist komplett falsch.", "Dies ist ein Test."]),
        )
        self.assertLess(scores[0], scores[1])
        scores = self.model.score(
            "de",
            src_lang = "en",
            self.lang_code_de,
            src_lang = self.lang_code_en,
            source_sentences=(2 * ["This is a test."]),
            hypothesis_sentences=(2 * ["Dies ist ein Test."]),
        )
        self.assertAlmostEqual(scores[0], scores[1], places=4)

    def test_score_batched(self):
        scores = self.model.score(
            "de",
            src_lang = "en",
            self.lang_code_de,
            src_lang = self.lang_code_en,
            source_sentences=(4 * ["This is a test."]),
            hypothesis_sentences=(["Diese Übersetzung ist komplett falsch", "Dies ist ein Test.", "Dies ist ein Test.", "Dies ist ein Test."]),
            batch_size=2,
@@ -75,8 +83,8 @@ def test_score_batched(self):
        self.assertAlmostEqual(scores[3], scores[1], places=4)

        scores = self.model.score(
            "de",
            src_lang = "en",
            self.lang_code_de,
            src_lang = self.lang_code_en,
            source_sentences=(["This is a test.", "A translation that is completely wrong.", "This is a test.", "This is a test."]),
            hypothesis_sentences=(4 * ["Dies ist ein Test."]),
            batch_size=2,
@@ -86,8 +94,8 @@ def test_score_batched(self):
        self.assertAlmostEqual(scores[3], scores[0], places=4)

        scores = self.model.score(
            "de",
            src_lang = "en",
            self.lang_code_de,
            src_lang = self.lang_code_en,
            source_sentences=(4 * ["This is a test."]),
            hypothesis_sentences=(["Dies ist ein Test.", "Dies ist ein Test.", ".", "Dies ist ein Test."]),
            batch_size=2,
@@ -97,8 +105,8 @@ def test_score_batched(self):
        self.assertAlmostEqual(scores[3], scores[0], places=4)

        scores = self.model.score(
            "de",
            src_lang = "en",
            self.lang_code_de,
            src_lang = self.lang_code_en,
            source_sentences=(["This is a test.", "This is a test.", "This is a test.", "A translation that is completely wrong."]),
            hypothesis_sentences=(4 * ["Dies ist ein Test."]),
            batch_size=2,
@@ -108,10 +116,10 @@ def test_score_batched(self):
        self.assertLess(scores[3], scores[0])

    def test_translate_long_input(self):
        self.model.translate("de", 100 * "This is a test. ", src_lang="en")
        self.model.translate(self.lang_code_de, 100 * "This is a test. ", src_lang=self.lang_code_en)

    def test_score_long_input(self):
        self.model.score("de", 100 * "This is a test. ", 100 * "Dies ist ein Test. ", src_lang="en")
        self.model.score(self.lang_code_de, 100 * "This is a test. ", 100 * "Dies ist ein Test. ", src_lang=self.lang_code_en)


@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Slow")
@@ -145,5 +153,37 @@ def setUpClass(cls) -> None:
        cls.model = load_translation_model("prism")


@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Slow")
class SmallDistilledNLLB200TestCase(NMTModelTestCase):

    @classmethod
    def setUpClass(cls) -> None:
        cls.model = load_translation_model("nllb-200-distilled-600M")

    @property
    def lang_code_de(self):
        return "deu_Latn"

    @property
    def lang_code_en(self):
        return "eng_Latn"


@unittest.skipIf(os.getenv("SKIP_SLOW_TESTS", False), "Slow")
class SmallNLLB200TestCase(NMTModelTestCase):

    @classmethod
    def setUpClass(cls) -> None:
        cls.model = load_translation_model("nllb-200-1.3B")

    @property
    def lang_code_de(self):
        return "deu_Latn"

    @property
    def lang_code_en(self):
        return "eng_Latn"


# https://stackoverflow.com/a/43353680/3902795
del NMTModelTestCase
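The NLLB test cases reuse the shared suite above via the new `lang_code_de`/`lang_code_en` properties. A hypothetical runner sketch for just these cases; the dotted module path assumes the tests are importable from the repository root, and `SKIP_SLOW_TESTS` must be unset for them to execute:

```python
import unittest

suite = unittest.defaultTestLoader.loadTestsFromNames([
    "tests.test_nmt_models.SmallDistilledNLLB200TestCase",
    "tests.test_nmt_models.SmallNLLB200TestCase",
])
unittest.TextTestRunner(verbosity=2).run(suite)
```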
