Merge remote-tracking branch 'origin_public/master'
jvamvas committed Oct 18, 2022
2 parents 0d5ee9a + bc9da1e commit bd0ea31
Showing 10 changed files with 130 additions and 29 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -78,7 +78,7 @@ This library currently supports three NMT models:
- [m2m100_418M](https://huggingface.co/facebook/m2m100_418M) and [m2m100_1.2B](https://huggingface.co/facebook/m2m100_1.2B) by [Fan et al. (2021)](https://www.jmlr.org/papers/volume22/20-1307/)
- [Prism](https://github.com/thompsonb/prism) by [Thompson and Post (2020)](https://aclanthology.org/2020.emnlp-main.8/)

By default, the leanest model (m2m100_418M) is loaded. The main results in the paper are based on the Prism model.
By default, the leanest model (m2m100_418M) is loaded. The main results in the paper are based on the Prism model, which has some extra requirements (see "Installation"), but is recommended due to its higher accuracy.

```python
scorer = NMTScorer("m2m100_418M", device=None) # default
@@ -118,7 +118,7 @@ The NMT models also provide a direct interface for translating and scoring.
```python
from nmtscore.models import load_translation_model

model = load_translation_model("m2m100_418M")
model = load_translation_model("prism")

model.translate("de", ["This is a test."])
# ["Das ist ein Test."]
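# A hedged continuation of the example above (not part of the README diff):
# with the m2m100 checkpoints, the same call can pass the `src_lang` keyword
# that this commit adds to the models' translate() method (see
# src/nmtscore/models/__init__.py below). Passing it explicitly avoids the
# new "assuming 'en'" warning; the output string is illustrative.
model = load_translation_model("m2m100_418M")
model.translate("de", ["This is a test."], src_lang="en")
# e.g. ["Das ist ein Test."]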
12 changes: 12 additions & 0 deletions experiments/metrics/benchmark_metrics.py
@@ -101,6 +101,8 @@ def get_paraphrase_metrics(device=None) -> List[BenchmarkMetric]:
title="Translation_Cross-Likelihood",
metric_names=["nmtscore-cross"],
load_func=lambda a_lang, b_lang, device=device: CrossLikelihoodNMTScoreMetric(
a_lang,
b_lang,
scorer=NMTScorer("prism", device=device),
both_directions=True,
translate_kwargs=NMT_TRANSLATE_KWARGS,
@@ -166,6 +168,8 @@ def get_nlg_evaluation_metrics(device=None) -> List[BenchmarkMetric]:
title="Translation_Cross-Likelihood",
metric_names=["nmtscore-cross", "nmtscore-cross-hyp|ref", "nmtscore-cross-ref|hyp"],
load_func=lambda a_lang, b_lang, device=device: CrossLikelihoodNMTScoreMetric(
a_lang,
b_lang,
scorer=NMTScorer("prism", device=device),
both_directions=True,
translate_kwargs=NMT_TRANSLATE_KWARGS,
@@ -208,6 +212,8 @@ def get_paraphrase_metrics_m2m100(device=None) -> List[BenchmarkMetric]:
title="Translation_Cross-Likelihood (m2m100_418M)",
metric_names=["nmtscore-cross"],
load_func=lambda a_lang, b_lang, device=device: CrossLikelihoodNMTScoreMetric(
a_lang,
b_lang,
scorer=NMTScorer("m2m100_418M", device=device),
both_directions=True,
translate_kwargs={**NMT_TRANSLATE_KWARGS, **{"batch_size": batch_size}},
@@ -242,6 +248,8 @@ def get_paraphrase_metrics_m2m100(device=None) -> List[BenchmarkMetric]:
title="Translation_Cross-Likelihood (m2m100_1.2B)",
metric_names=["nmtscore-cross"],
load_func=lambda a_lang, b_lang, device=device: CrossLikelihoodNMTScoreMetric(
a_lang,
b_lang,
scorer=NMTScorer("m2m100_1.2B", device=device),
both_directions=True,
translate_kwargs={**NMT_TRANSLATE_KWARGS, **{"batch_size": batch_size}},
@@ -285,6 +293,8 @@ def get_normalization_ablation_metrics(device=None) -> List[BenchmarkMetric]:
title="Translation_Cross-Likelihood (normalized)",
metric_names=["nmtscore-cross"],
load_func=lambda a_lang, b_lang, device=device: CrossLikelihoodNMTScoreMetric(
a_lang,
b_lang,
scorer=NMTScorer("prism", device=device),
normalize=True,
both_directions=True,
@@ -322,6 +332,8 @@ def get_normalization_ablation_metrics(device=None) -> List[BenchmarkMetric]:
title="Translation_Cross-Likelihood (unnormalized)",
metric_names=["nmtscore-cross"],
load_func=lambda a_lang, b_lang, device=device: CrossLikelihoodNMTScoreMetric(
a_lang,
b_lang,
scorer=NMTScorer("prism", device=device),
normalize=False,
both_directions=True,
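Each hunk above makes the same change: the benchmark's `load_func` now passes the two text languages (`a_lang`, `b_lang`) positionally into `CrossLikelihoodNMTScoreMetric`, ahead of the keyword arguments. A hedged sketch of what one such call resolves to at runtime; the language codes are placeholders and the import paths are assumptions based on the repository layout:

```python
from nmtscore import NMTScorer
from nmtscore_metrics import CrossLikelihoodNMTScoreMetric  # experiments/metrics/nmtscore_metrics.py

# a_lang / b_lang become summaries_lang / references_lang on the metric
# (see the nmtscore_metrics.py changes below).
metric = CrossLikelihoodNMTScoreMetric(
    "en",  # a_lang
    "de",  # b_lang
    scorer=NMTScorer("prism", device=None),
    both_directions=True,
)
```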
20 changes: 13 additions & 7 deletions experiments/metrics/nmtscore_metrics.py
@@ -12,11 +12,15 @@
class NMTScoreMetric(ReferenceBasedMetric):

def __init__(self,
summaries_lang: Optional[str] = None,
references_lang: Optional[str] = None,
scorer: Union[NMTScorer, str] = "m2m100_418M",
normalize: bool = True,
both_directions: bool = True,
):
super().__init__()
self.summaries_lang = summaries_lang
self.references_lang = references_lang
if isinstance(scorer, str):
self.scorer = NMTScorer(scorer)
else:
@@ -108,9 +112,7 @@ def __init__(self,
):
if both_directions:
assert references_lang is not None
super().__init__(scorer, normalize, both_directions)
self.summaries_lang = summaries_lang
self.references_lang = references_lang
super().__init__(summaries_lang, references_lang, scorer, normalize, both_directions)
self.score_kwargs = score_kwargs

@property
@@ -156,9 +158,7 @@ def __init__(self,
):
if both_directions:
assert references_lang is not None
super().__init__(scorer, normalize, both_directions)
self.summaries_lang = summaries_lang
self.references_lang = references_lang
super().__init__(summaries_lang, references_lang, scorer, normalize, both_directions)
self.pivot_lang = pivot_lang
self.translate_kwargs = translate_kwargs
self.score_kwargs = score_kwargs
@@ -200,14 +200,16 @@ class CrossLikelihoodNMTScoreMetric(NMTScoreMetric):
name = "nmtscore-cross"

def __init__(self,
summaries_lang: Optional[str] = None,
references_lang: Optional[str] = None,
tgt_lang: str = "en",
scorer: Union[NMTScorer, str] = "m2m100_418M",
normalize: bool = True,
both_directions: bool = True,
translate_kwargs: dict = None,
score_kwargs: dict = None,
):
super().__init__(scorer, normalize, both_directions)
super().__init__(summaries_lang, references_lang, scorer, normalize, both_directions)
self.tgt_lang = tgt_lang
self.translate_kwargs = translate_kwargs
self.score_kwargs = score_kwargs
@@ -221,6 +223,8 @@ def _score_summaries_given_references(self, summaries: List[str], references: Li
return self.scorer.score_cross_likelihood(
summaries,
references,
self.summaries_lang,
self.references_lang,
tgt_lang=self.tgt_lang,
normalize=self.normalize,
both_directions=False,
@@ -232,6 +236,8 @@ def _score_references_given_summaries(self, summaries: List[str], references: Li
return self.scorer.score_cross_likelihood(
references,
summaries,
self.references_lang,
self.summaries_lang,
tgt_lang=self.tgt_lang,
normalize=self.normalize,
both_directions=False,
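The net effect of the refactor in this file: the language codes are stored once on the `NMTScoreMetric` base class and forwarded into the scorer calls. A sketch of the call that `_score_summaries_given_references` now issues, written directly against `NMTScorer` with placeholder sentences; the positional placement of the two language codes follows the diff, everything else is illustrative:

```python
from nmtscore import NMTScorer

scorer = NMTScorer("m2m100_418M", device=None)

# summaries, references, summaries_lang, references_lang, then the keyword
# arguments, mirroring the updated metric internals.
scores = scorer.score_cross_likelihood(
    ["A man is eating food."],      # summaries
    ["Someone is having a meal."],  # references
    "en",                           # summaries_lang
    "en",                           # references_lang
    tgt_lang="en",
    normalize=True,
    both_directions=False,
)
```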
38 changes: 33 additions & 5 deletions src/nmtscore/models/__init__.py
@@ -1,5 +1,7 @@
import json
import os
import warnings

from sqlitedict import SqliteDict
from pathlib import Path
from typing import List, Union, Tuple
@@ -13,6 +15,7 @@ def __str__(self):
def translate(self,
tgt_lang: str,
source_sentences: Union[str, List[str]],
src_lang: str = None,
return_score: bool = False,
batch_size: int = 8,
use_cache: bool = False,
@@ -21,6 +24,7 @@ def translate(self,
"""
:param tgt_lang: Language code of the target language
:param source_sentences: A sentence or list of sentences
:param src_lang: Language code of the source language (not needed for some multilingual models)
:param return_score: If True, return a tuple where the second element is the sequence-level score of the translation
:param batch_size
:param use_cache
@@ -40,7 +44,8 @@ def translate(self,
cached_translations_list = []
with self.load_cache() as cache:
for source_sentence in source_sentences_list:
translation = cache.get(f"{tgt_lang}_translate{'_score' if return_score else ''}_{source_sentence}", None)
translation = cache.get(f"{(src_lang + '_') if src_lang is not None else ''}{tgt_lang}_"
f"translate{'_score' if return_score else ''}_{source_sentence}", None)
cached_translations_list.append(translation)
full_source_sentences_list = source_sentences_list
source_sentences_list = [
@@ -50,6 +55,11 @@ def translate(self,
]

self._set_tgt_lang(tgt_lang)
if self.requires_src_lang:
if src_lang is None:
warnings.warn(f"NMT model {self} requires the src language. Assuming 'en'; override with `src_lang`")
src_lang = "en"
self._set_src_lang(src_lang)
translations_list = self._translate(source_sentences_list, return_score, batch_size, **kwargs)
assert len(translations_list) == len(source_sentences_list)

@@ -59,7 +69,9 @@ def translate(self,
if cached_translation is not None:
translations_list.insert(i, cached_translation)
else:
cache_update[f"{tgt_lang}_translate{'_score' if return_score else ''}_{full_source_sentences_list[i]}"] = translations_list[i]
cache_update[f"{(src_lang + '_') if src_lang is not None else ''}{tgt_lang}_" \
f"translate{'_score' if return_score else ''}_" \
f"{full_source_sentences_list[i]}"] = translations_list[i]
if cache_update:
with self.load_cache() as cache:
cache.update(cache_update)
@@ -75,6 +87,7 @@ def score(self,
tgt_lang: str,
source_sentences: Union[str, List[str]],
hypothesis_sentences: Union[str, List[str]],
src_lang: str = None,
batch_size: int = 8,
use_cache: bool = False,
**kwargs,
@@ -83,6 +96,7 @@ def score(self,
:param tgt_lang: Language code of the target language
:param source_sentences: A sentence or list of sentences
:param hypothesis_sentences: A sentence or list of sentences
:param src_lang: Language code of the source language (not needed for some multilingual models)
:param batch_size
:param use_cache
:param kwargs
@@ -105,7 +119,8 @@ def score(self,
cached_scores_list = []
with self.load_cache() as cache:
for source_sentence, hypothesis_sentence in zip(source_sentences_list, hypothesis_sentences_list):
score = cache.get(f"{tgt_lang}_score_{source_sentence}_{hypothesis_sentence}", None)
score = cache.get(f"{(src_lang + '_') if src_lang is not None else ''}{tgt_lang}_"
f"score_{source_sentence}_{hypothesis_sentence}", None)
cached_scores_list.append(score)
full_source_sentences_list = source_sentences_list
source_sentences_list = [
@@ -121,6 +136,8 @@ def score(self,
]

self._set_tgt_lang(tgt_lang)
if self.requires_src_lang:
self._set_src_lang(src_lang)
scores_list = self._score(source_sentences_list, hypothesis_sentences_list, batch_size, **kwargs)
assert len(scores_list) == len(source_sentences_list)

@@ -130,8 +147,9 @@ def score(self,
if cached_score is not None:
scores_list.insert(i, cached_score)
else:
cache_update[f"{tgt_lang}_score_{full_source_sentences_list[i]}_{full_hypothesis_sentences_list[i]}"] = \
scores_list[i]
cache_update[f"{(src_lang + '_') if src_lang is not None else ''}{tgt_lang}_" \
f"score_{full_source_sentences_list[i]}_" \
f"{full_hypothesis_sentences_list[i]}"] = scores_list[i]
if cache_update:
with self.load_cache() as cache:
cache.update(cache_update)
@@ -143,6 +161,16 @@ def score(self,
scores = scores_list
return scores

@property
def requires_src_lang(self) -> bool:
"""
Boolean indicating whether the model requires the source language to be specified
"""
raise NotImplementedError

def _set_src_lang(self, src_lang: str):
raise NotImplementedError

def _set_tgt_lang(self, tgt_lang: str):
raise NotImplementedError

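Two behavioral details follow from the hunks above: a model whose `requires_src_lang` is True falls back to "en" (with a warning) when `translate` is called without `src_lang`, and the cache key gains a source-language prefix only when `src_lang` is set. A small sketch of the key logic, with the f-string pulled out of the method for readability; the helper name is hypothetical:

```python
def _translate_cache_key(tgt_lang: str, source_sentence: str,
                         src_lang: str = None, return_score: bool = False) -> str:
    # Same expression as the cache lookup/update in translate() above.
    prefix = f"{src_lang}_" if src_lang is not None else ""
    return f"{prefix}{tgt_lang}_translate{'_score' if return_score else ''}_{source_sentence}"

assert _translate_cache_key("de", "This is a test.", src_lang="en") == "en_de_translate_This is a test."
assert _translate_cache_key("de", "This is a test.") == "de_translate_This is a test."
```

Entries written by earlier versions carry no source prefix, so they still match as long as the caller does not pass `src_lang`.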
8 changes: 8 additions & 0 deletions src/nmtscore/models/m2m100.py
@@ -34,6 +34,14 @@ def __init__(self,
def __str__(self):
return self.model_name_or_path

@property
def requires_src_lang(self) -> bool:
return True

def _set_src_lang(self, src_lang: str):
self.src_lang = src_lang
self.tokenizer.src_lang = src_lang

def _set_tgt_lang(self, tgt_lang: str):
self.tgt_lang = tgt_lang
self.tokenizer.tgt_lang = tgt_lang
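A hedged check of what the two new members give the m2m100 wrapper; `load_translation_model` comes from the README above and the attribute names (`src_lang`, `tokenizer.src_lang`) from the diff:

```python
from nmtscore.models import load_translation_model

model = load_translation_model("m2m100_418M")
assert model.requires_src_lang is True

# _set_src_lang keeps the wrapper and the Hugging Face tokenizer in sync,
# so the source-language setting used for encoding matches the caller's choice.
model._set_src_lang("fr")
assert model.src_lang == "fr"
assert model.tokenizer.src_lang == "fr"
```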
4 changes: 4 additions & 0 deletions src/nmtscore/models/prism.py
@@ -49,6 +49,10 @@ def __init__(self,
def __str__(self):
return "prism"

@property
def requires_src_lang(self) -> bool:
return False

def _set_tgt_lang(self, tgt_lang: str):
assert tgt_lang in self.supported_languages
self.tgt_lang = tgt_lang
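The Prism counterpart, for contrast: Prism does not need a source-language hint, so `requires_src_lang` returns False and the base class skips the `_set_src_lang` step entirely. A short sketch, assuming the Prism requirements from the README's installation notes are in place:

```python
from nmtscore.models import load_translation_model

prism = load_translation_model("prism")
assert prism.requires_src_lang is False

# No src_lang and no warning: the fallback in translate() only triggers
# for models that require a source language.
prism.translate("de", ["This is a test."])
```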