diff --git a/CHANGELOG.md b/CHANGELOG.md
index 878f6a088..525f4b6e4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 
 Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
 
+## [3.1.14]
+
+### Added
+- Added an implementation of Neural Vocabulary Selection (NVS) to Sockeye, as presented in our NAACL 2022 paper "The Devil is in the Details: On the Pitfalls of Vocabulary Selection in Neural Machine Translation" (Tobias Domhan, Eva Hasler, Ke Tran, Sony Trenous, Bill Byrne and Felix Hieber).
+  - To use NVS, simply specify `--neural-vocab-selection` to `sockeye-train`. This will train a model with Neural Vocabulary Selection that is automatically used by `sockeye-translate`. If you want to look at translations without vocabulary selection, specify `--skip-nvs` as an argument to `sockeye-translate`.
+
 ## [3.1.13]
 
 ### Added
diff --git a/README.md b/README.md
index 99561b12f..2d310a389 100644
--- a/README.md
+++ b/README.md
@@ -84,17 +84,18 @@ For more information about Sockeye, see our papers ([BibTeX](sockeye.bib)).
 ## Research with Sockeye
 
 Sockeye has been used for both academic and industrial research. A list of known publications that use Sockeye is shown below.
-If you know more, please let us know or submit a pull request (last updated: April 2022).
+If you know more, please let us know or submit a pull request (last updated: May 2022).
 
 ### 2022
 
 * Weller-Di Marco, Marion, Matthias Huck, Alexander Fraser. "Modeling Target-Side Morphology in Neural Machine Translation: A Comparison of Strategies ". arXiv preprint arXiv:2203.13550 (2022)
+* Tobias Domhan, Eva Hasler, Ke Tran, Sony Trenous, Bill Byrne and Felix Hieber. "The Devil is in the Details: On the Pitfalls of Vocabulary Selection in Neural Machine Translation". Proceedings of NAACL-HLT (2022)
 
 ### 2021
 
 * Bergmanis, Toms, Mārcis Pinnis. "Facilitating Terminology Translation with Target Lemma Annotations". arXiv preprint arXiv:2101.10035 (2021)
 * Briakou, Eleftheria, Marine Carpuat. "Beyond Noise: Mitigating the Impact of Fine-grained Semantic Divergences on Neural Machine Translation". arXiv preprint arXiv:2105.15087 (2021)
-* Hasler, Eva, Tobias Domhan, Jonay Trenous, Ke Tran, Bill Byrne, Felix Hieber. "Improving the Quality Trade-Off for Neural Machine Translation Multi-Domain Adaptation". Proceedings of EMNLP (2021)
+* Hasler, Eva, Tobias Domhan, Sony Trenous, Ke Tran, Bill Byrne, Felix Hieber. "Improving the Quality Trade-Off for Neural Machine Translation Multi-Domain Adaptation". Proceedings of EMNLP (2021)
 * Tang, Gongbo, Philipp Rönchen, Rico Sennrich, Joakim Nivre. "Revisiting Negation in Neural Machine Translation". Transactions of the Association for Computation Linguistics 9 (2021)
 * Vu, Thuy, Alessandro Moschitti. "Machine Translation Customization via Automatic Training Data Selection from the Web". arXiv preprint arXiv:2102.1024 (2021)
 * Xu, Weijia, Marine Carpuat. "EDITOR: An Edit-Based Transformer with Repositioning for Neural Machine Translation with Soft Lexical Constraints." Transactions of the Association for Computation Linguistics 9 (2021)
diff --git a/docs/training.md b/docs/training.md
index ca02cc88c..4a4a34e72 100644
--- a/docs/training.md
+++ b/docs/training.md
@@ -175,3 +175,13 @@ that can be enabled by setting `--length-task`, respectively, to `ratio` or to `
 Specify `--length-task-layers` to set the number of layers in the prediction MLP.
The weight of the loss in the global training objective is controlled with `--length-task-weight` (standard cross-entropy loss has weight 1.0).
 During inference the predictions can be used to reward longer translations by enabling `--brevity-penalty-type`.
+
+
+## Neural Vocabulary Selection (NVS)
+
+When Neural Vocabulary Selection (NVS) is enabled, an auxiliary target bag-of-words model is trained.
+During decoding, the output vocabulary is reduced to the set of predicted target words, which speeds up decoding.
+This is similar to using `--restrict-lexicon` for `sockeye-translate`, with the advantage that no external alignment model is required and that the contextualized hidden encoder representations are used to predict the set of target words.
+To use NVS, simply specify `--neural-vocab-selection` to `sockeye-train`.
+This will train a model with NVS that is automatically used by `sockeye-translate`.
+If you want to look at translations without vocabulary selection, specify `--skip-nvs` as an argument to `sockeye-translate`.
diff --git a/sockeye/__init__.py b/sockeye/__init__.py
index dd0d31afd..cb3dc330a 100644
--- a/sockeye/__init__.py
+++ b/sockeye/__init__.py
@@ -11,4 +11,4 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-__version__ = '3.1.13'
+__version__ = '3.1.14'
diff --git a/sockeye/arguments.py b/sockeye/arguments.py
index 358779ed1..a36b457f8 100644
--- a/sockeye/arguments.py
+++ b/sockeye/arguments.py
@@ -326,18 +326,23 @@ def add_rerank_args(params):
                                help="Returns the reranking scores as scores in output JSON objects.")
 
 
-def add_lexicon_args(params):
+def add_lexicon_args(params, is_for_block_lexicon: bool = False):
     lexicon_params = params.add_argument_group("Model & Top-k")
     lexicon_params.add_argument("--model", "-m", required=True,
                                 help="Model directory containing source and target vocabularies.")
-    lexicon_params.add_argument("-k", type=int, default=200,
-                                help="Number of target translations to keep per source. Default: %(default)s.")
+    if not is_for_block_lexicon:
+        lexicon_params.add_argument("-k", type=int, default=200,
+                                    help="Number of target translations to keep per source. Default: %(default)s.")
 
 
-def add_lexicon_create_args(params):
+def add_lexicon_create_args(params, is_for_block_lexicon: bool = False):
     lexicon_params = params.add_argument_group("I/O")
+    if is_for_block_lexicon:
+        input_help = "A text file with tokens that shall be blocked. All tokens must be in the model vocabulary."
+    else:
+        input_help = "Probabilistic lexicon (fast_align format) to build top-k lexicon from."
     lexicon_params.add_argument("--input", "-i", required=True,
-                                help="Probabilistic lexicon (fast_align format) to build top-k lexicon from.")
+                                help=input_help)
     lexicon_params.add_argument("--output", "-o", required=True,
                                 help="File name to write top-k lexicon to.")
@@ -743,6 +748,21 @@ def add_model_parameters(params):
                               'PyTorch AMP with some additional risk and requires installing Apex: '
                               'https://github.com/NVIDIA/apex')
 
+    model_params.add_argument('--neural-vocab-selection',
+                              type=str,
+                              default=None,
+                              choices=C.NVS_TYPES,
+                              help='When enabled, the model contains a neural vocabulary selection model that restricts '
+                                   'the target output vocabulary to speed up inference. '
+                                   'logit_max: predictions are made per source token and combined by max pooling. '
+                                   'eos: the prediction is based on the hidden representation of the <eos> token.')
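
To make the two `--neural-vocab-selection` modes above concrete, here is a minimal, self-contained sketch (illustrative tensor sizes only; the actual Sockeye module appears later in this patch as `sockeye/nvs.py`) of how each mode pools encoder states into a bag-of-words prediction:

```python
import torch as pt

batch, src_len, model_size, vocab_size = 2, 5, 8, 100
source_encoded = pt.randn(batch, src_len, model_size)  # contextualized encoder states
source_length = pt.tensor([5, 3])                      # valid source lengths
att_mask = pt.arange(src_len).unsqueeze(0) >= source_length.unsqueeze(1)  # True marks padding
project_vocab = pt.nn.Linear(model_size, vocab_size)

# logit_max: score the vocabulary at every source position, then max-pool over positions.
logits = project_vocab(source_encoded)                 # (batch, src_len, vocab)
logits = logits.masked_fill(att_mask.unsqueeze(2), -pt.inf)
bow_pred_logit_max, _ = pt.max(logits, dim=1)          # (batch, vocab)

# eos: use only the hidden state at the last valid (<eos>) position of each sentence.
last_hidden = source_encoded[pt.arange(batch), (source_length - 1).long()]
bow_pred_eos = project_vocab(last_hidden)              # (batch, vocab)
```
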
+
+    model_params.add_argument('--neural-vocab-selection-block-loss',
+                              action='store_true',
+                              help='When enabled, gradients for NVS are blocked from propagating back to the encoder. '
+                                   'This means that NVS learns to work with the main model\'s representations but '
+                                   'does not influence its training.')
+
 
 def add_batch_args(params, default_batch_size=4096, default_batch_type=C.BATCH_TYPE_WORD):
     params.add_argument('--batch-size', '-b',
@@ -773,6 +793,25 @@ def add_batch_args(params, default_batch_size=4096, default_batch_type=C.BATCH_T
                              'size 10240). Default: %(default)s.')
 
 
+def add_nvs_train_parameters(params):
+    params.add_argument(
+        '--bow-task-weight',
+        type=float_greater_or_equal(0.0),
+        default=1.0,
+        help='The weight of the auxiliary bag-of-words (BOW) loss when --neural-vocab-selection is enabled. '
+             'Default %(default)s.')
+
+    params.add_argument(
+        '--bow-task-pos-weight',
+        type=float_greater_or_equal(0.0),
+        default=10,
+        help='The weight of the positive class (the set of words present on the target side) for the BOW loss '
+             'when --neural-vocab-selection is set, computed as x * num_negative_class / num_positive_class, where '
+             'x is the --bow-task-pos-weight. Higher values bias more towards recall, resulting in larger '
+             'vocabularies at test time, trading decoding speed for translation quality. Default %(default)s.')
+
+
 def add_training_args(params):
     train_params = params.add_argument_group("Training parameters")
 
@@ -803,6 +842,8 @@ def add_training_args(params):
                               default=1,
                               help='Number of fully-connected layers for predicting the length ratio. Default %(default)s.')
 
+    add_nvs_train_parameters(train_params)
+
     train_params.add_argument('--target-factors-weight',
                               type=float,
                               nargs='+',
@@ -1203,18 +1244,38 @@ def add_inference_args(params):
                                nargs='+',
                                type=multiple_values(num_values=2, data_type=str),
                                default=None,
-                               help="Specify top-k lexicon to restrict output vocabulary to the k most likely context-"
-                                    "free translations of the source words in each sentence (Devlin, 2017). See the "
-                                    "lexicon module for creating top-k lexicons. To use multiple lexicons, provide "
+                               help="Specify a block or top-k lexicon. A top-k lexicon poses a positive constraint "
+                                    "by providing the set of allowed target words, while a blocking lexicon poses a "
+                                    "negative constraint by providing a set of target words to be avoided. "
+                                    "Specifically, a top-k lexicon will restrict the output vocabulary to the k most "
+                                    "likely context-free translations of the source words in each sentence "
+                                    "(Devlin, 2017). See the lexicon module for creating lexicons, e.g. by running "
+                                    "sockeye-lexicon. To use multiple lexicons, provide "
                                     "'--restrict-lexicon key1:path1 key2:path2 ...' and use JSON input to specify the "
                                     "lexicon for each sentence: "
                                     "{\"text\": \"some input string\", \"restrict_lexicon\": \"key\"}. "
+                                    "If a single lexicon is specified, it will be applied to all inputs. "
+                                    "If multiple lexicons are specified, one can be selected via the JSON input, or "
+                                    "lexicon usage can be skipped by not providing a lexicon name in the JSON input. "
                                     "Default: %(default)s.")
     decode_params.add_argument('--restrict-lexicon-topk',
                                type=int,
                                default=None,
                                help="Specify the number of translations to load for each source word from the lexicon "
-                                    "given with --restrict-lexicon. Default: Load all entries from the lexicon.")
+                                    "given with --restrict-lexicon. Only applies to top-k lexicons. 
" + "Default: Load all entries from the lexicon.") + + decode_params.add_argument('--skip-nvs', + action='store_true', + help='Manually turn off Neural Vocabulary Selection (NVS) to do a softmax over the full target vocabulary.', + default=False) + + decode_params.add_argument('--nvs-thresh', + type=float, + help='The probability threshold for a word to be added to the set of target words. ' + 'Default: 0.5.', + default=0.5) + decode_params.add_argument('--strip-unknown-words', action='store_true', default=False, diff --git a/sockeye/beam_search.py b/sockeye/beam_search.py index 33fefc816..d06d15d3e 100644 --- a/sockeye/beam_search.py +++ b/sockeye/beam_search.py @@ -74,8 +74,8 @@ def state_structure(self) -> List: return [self._model.state_structure()] def encode_and_initialize(self, inputs: pt.Tensor, valid_length: Optional[pt.Tensor] = None): - states, predicted_output_length = self._model.encode_and_initialize(inputs, valid_length, self._const_lr) - return states, predicted_output_length + states, predicted_output_length, nvs_prediction = self._model.encode_and_initialize(inputs, valid_length, self._const_lr) + return states, predicted_output_length, nvs_prediction def decode_step(self, step_input: pt.Tensor, @@ -136,13 +136,19 @@ def state_structure(self) -> List: def encode_and_initialize(self, inputs: pt.Tensor, valid_length: Optional[pt.Tensor] = None): model_states = [] # type: List[pt.Tensor] predicted_output_lengths = [] # type: List[pt.Tensor] + nvs_predictions = [] for model in self._models: - states, predicted_output_length = model.encode_and_initialize(inputs, valid_length, self._const_lr) + states, predicted_output_length, nvs_prediction = model.encode_and_initialize(inputs, valid_length, self._const_lr) + if nvs_prediction is not None: + nvs_predictions.append(nvs_prediction) + predicted_output_lengths.append(predicted_output_length) model_states += states # average predicted output lengths, (batch, 1) predicted_output_lengths = pt.stack(predicted_output_lengths, dim=1).float().mean(dim=1) # type: ignore - return model_states, predicted_output_lengths + nvs_prediction = pt.stack(nvs_predictions, dim=1).mean(dim=1) if nvs_predictions else None + + return model_states, predicted_output_lengths, nvs_prediction def decode_step(self, step_input: pt.Tensor, @@ -523,17 +529,26 @@ def forward(self, best_hyp_indices, *states): return sorted_states -def _get_vocab_slice_ids(restrict_lexicon: Optional[lexicon.TopKLexicon], +def _get_vocab_slice_ids(restrict_lexicon: lexicon.RestrictLexicon, source_words: pt.Tensor, eos_id: int, beam_size: int, - target_prefix: Optional[pt.Tensor] = None) -> Tuple[pt.Tensor, int]: + target_prefix: Optional[pt.Tensor] = None, + output_vocab_size: Optional[int] = None) -> Tuple[pt.Tensor, int]: device = source_words.device - vocab_slice_ids_np = restrict_lexicon.get_trg_ids(source_words.cpu().int().numpy()) # type: ignore + if not restrict_lexicon.is_blocking(): + vocab_slice_ids_np = restrict_lexicon.get_allowed_trg_ids(source_words.cpu().int().numpy()) # type: ignore + else: + utils.check_condition(output_vocab_size is not None, + "output_vocab_size required for blocking restrict lexicon.") + full_vocab = np.arange(0, output_vocab_size, dtype='int32') + source_ids = source_words.cpu().int().numpy() if restrict_lexicon.requires_src_ids() else None + vocab_slice_ids_np = np.setdiff1d(full_vocab, restrict_lexicon.get_blocked_trg_ids(source_ids), assume_unique=True) + vocab_slice_ids = pt.tensor(vocab_slice_ids_np, device=device, dtype=pt.int64) if 
target_prefix is not None:
         # Ensuring that target prefix ids are part of vocab_slice_ids
-        vocab_slice_ids = pt.concat([vocab_slice_ids, target_prefix.flatten().type(pt.int64)], -1).unique()
+        vocab_slice_ids = pt.concat([vocab_slice_ids, target_prefix.flatten().type(pt.int64)], -1).unique()
 
     # Pad to a multiple of 8.
     vocab_slice_ids = pt.nn.functional.pad(vocab_slice_ids, pad=(0, 7 - ((vocab_slice_ids.size(-1) - 1) % 8)),
@@ -554,6 +569,51 @@ def _get_vocab_slice_ids(restrict_lexicon: Optional[lexicon.TopKLexicon],
     return vocab_slice_ids, vocab_slice_ids_shape
 
 
+def _get_nvs_vocab_slice_ids(
+        nvs_thresh: float,
+        nvs_prediction: pt.Tensor,
+        restrict_lexicon: Optional[lexicon.RestrictLexicon] = None,
+        target_prefix: Optional[pt.Tensor] = None
+) -> Tuple[pt.Tensor, int]:
+    """
+    Return the vocab slice ids based on the Neural Vocabulary Selection model's predictions.
+
+    :param nvs_thresh: The threshold for selecting a word (between 0.0 and 1.0).
+    :param nvs_prediction: Shape: (batch size, vocab_size).
+    :param restrict_lexicon: An optional blocking lexicon to forcefully turn specific words off.
+    :param target_prefix: Shape: (batch size, prefix length).
+    :return: The vocab slice ids and the resulting output vocabulary size.
+    """
+    nvs_prediction_above_thresh = (nvs_prediction > nvs_thresh)
+    # merge batch dimension (batch size, vocab_size) -> (1, vocab_size)
+    if nvs_prediction_above_thresh.shape[0] > 1:
+        nvs_prediction_above_thresh = pt.any(nvs_prediction_above_thresh, dim=0, keepdim=True)
+
+    if restrict_lexicon is not None:
+        utils.check_condition(
+            restrict_lexicon.is_blocking() and not restrict_lexicon.requires_src_ids(),
+            "Only a blocking, static lexicon is supported when Neural Vocabulary Selection (NVS) is used."
+        )
+        blocked_tokens = pt.from_numpy(restrict_lexicon.get_blocked_trg_ids()).long().to(nvs_prediction_above_thresh.device)
+        nvs_prediction_above_thresh[0, blocked_tokens] = False
+
+    # Add special symbols:
+    pt_symbols = pt.tensor([C.PAD_ID, C.UNK_ID, C.BOS_ID, C.EOS_ID], device=nvs_prediction_above_thresh.device)
+    nvs_prediction_above_thresh[0, pt_symbols] = True
+
+    if target_prefix is not None:
+        nvs_prediction_above_thresh[0, target_prefix.flatten().long()] = True
+
+    bow = nvs_prediction_above_thresh.nonzero(as_tuple=True)[1].unique()
+
+    # pad to a multiple of 8.
+    if len(bow) % 8 != 0:
+        bow = pt.nn.functional.pad(bow, (0, 7 - ((len(bow) - 1) % 8)), mode='constant', value=C.EOS_ID)
+
+    output_vocab_size = bow.shape[0]
+    logger.debug(f'decoder softmax size: {output_vocab_size}')
+
+    return bow, output_vocab_size
+
+
 class GreedySearch(pt.nn.Module):
     """
     Implements greedy search, not supporting various features from the BeamSearch class
@@ -567,7 +627,9 @@ def __init__(self,
                  device: pt.device,
                  num_source_factors: int,
                  num_target_factors: int,
-                 inference: _SingleModelInference):
+                 inference: _SingleModelInference,
+                 skip_nvs: bool = False,
+                 nvs_thresh: float = 0.5):
         super().__init__()
         self.dtype = dtype
         self.bos_id = bos_id
@@ -580,13 +642,15 @@ def __init__(self,
         self.num_target_factors = num_target_factors
         self.global_avoid_trie = None
         assert inference._skip_softmax, "skipping softmax must be enabled for GreedySearch"
+        self.skip_nvs = skip_nvs
+        self.nvs_thresh = nvs_thresh
 
         self.work_block = GreedyTop1()
 
     def forward(self,
                 source: pt.Tensor,
                 source_length: pt.Tensor,
-                restrict_lexicon: Optional[lexicon.TopKLexicon] = None,
+                restrict_lexicon: Optional[lexicon.RestrictLexicon] = None,
                 max_output_lengths: pt.Tensor = None,
                 target_prefix: Optional[pt.Tensor] = None,
                 target_prefix_factors: Optional[pt.Tensor] = None) -> SearchResult:
@@ -615,17 +679,23 @@ def forward(self,
                              fill_value=self.bos_id, device=self.device, dtype=pt.int32)
         outputs = []  # type: List[pt.Tensor]
 
+        # (0) encode source sentence, returns a list
+        model_states, _, nvs_prediction = self._inference.encode_and_initialize(source, source_length)
+        # TODO: check for disabled predicted output length
+
         vocab_slice_ids = None  # type: Optional[pt.Tensor]
-        # If using a top-k lexicon, select param rows for logit computation that correspond to the
+        # If using a top-k lexicon or NVS, select param rows for logit computation that correspond to the
         # target vocab for this sentence.
- if restrict_lexicon: + if nvs_prediction is not None and not self.skip_nvs: + vocab_slice_ids, _ = _get_nvs_vocab_slice_ids(self.nvs_thresh, nvs_prediction, + restrict_lexicon=restrict_lexicon, + target_prefix=target_prefix) + elif restrict_lexicon: source_words = source[:, :, 0] - vocab_slice_ids, _ = _get_vocab_slice_ids(restrict_lexicon, source_words, self.eos_id, - beam_size=1, target_prefix=target_prefix) + vocab_slice_ids, _ = _get_vocab_slice_ids(restrict_lexicon, source_words, self.eos_id, beam_size=1, + target_prefix=target_prefix, + output_vocab_size=self.output_vocab_size) - # (0) encode source sentence, returns a list - model_states, _ = self._inference.encode_and_initialize(source, source_length) - # TODO: check for disabled predicted output length # Prefix masks, where scores are infinity for all other vocabulary items except target_prefix ids prefix_masks, prefix_masks_length = None, 0 @@ -724,7 +794,9 @@ def __init__(self, inference: _Inference, beam_search_stop: str = C.BEAM_SEARCH_STOP_ALL, sample: Optional[int] = None, - prevent_unk: bool = False) -> None: + prevent_unk: bool = False, + skip_nvs: bool = False, + nvs_thresh: float = 0.5) -> None: super().__init__() self.beam_size = beam_size self.dtype = dtype @@ -738,6 +810,8 @@ def __init__(self, self.num_source_factors = num_source_factors self.num_target_factors = num_target_factors self.prevent_unk = prevent_unk + self.skip_nvs = skip_nvs + self.nvs_thresh = nvs_thresh self._repeat_states = RepeatStates(beam_size=beam_size, state_structure=self._inference.state_structure()) self._traced_repeat_states = None # type: Optional[pt.jit.ScriptModule] @@ -762,7 +836,7 @@ def __init__(self, def forward(self, source: pt.Tensor, source_length: pt.Tensor, - restrict_lexicon: Optional[lexicon.TopKLexicon], + restrict_lexicon: Optional[lexicon.RestrictLexicon], max_output_lengths: pt.Tensor, target_prefix: Optional[pt.Tensor] = None, target_prefix_factors: Optional[pt.Tensor] = None) -> SearchResult: @@ -786,9 +860,6 @@ def forward(self, max_iterations = int(max_output_lengths.max().item()) logger.debug("max beam search iterations: %d", max_iterations) - if self._sample is not None: - utils.check_condition(restrict_lexicon is None, "restricted lexicon not available when sampling.") - # General data structure: batch_size * beam_size blocks in total; # a full beam for each sentence, followed by the next beam-block for the next sentence and so on @@ -825,15 +896,40 @@ def forward(self, factor_scores_accumulated = [pt.zeros(batch_size * self.beam_size, self.num_target_factors - 1, device=self.device, dtype=self.dtype)] + # (0) encode source sentence, returns a list + model_states, estimated_reference_lengths, nvs_prediction = self._inference.encode_and_initialize(source, source_length) + # repeat states to beam_size + if self._traced_repeat_states is None: + logger.debug("Tracing repeat_states") + self._traced_repeat_states = pt.jit.trace(self._repeat_states, model_states, strict=False) + model_states = self._traced_repeat_states(*model_states) + # repeat estimated_reference_lengths to shape (batch_size * beam_size) + estimated_reference_lengths = estimated_reference_lengths.repeat_interleave(self.beam_size, dim=0) + output_vocab_size = self.output_vocab_size - # If using a top-k lexicon, select param rows for logit computation that correspond to the + # If using a lexicon or NVS, select param rows for logit computation that correspond to the # target vocab for this sentence. 
+        # NVS can additionally take a blocking lexicon that further restricts the output vocabulary.
         vocab_slice_ids = None  # type: Optional[pt.Tensor]
-        if restrict_lexicon:
+        if nvs_prediction is not None and not self.skip_nvs:
+            vocab_slice_ids, output_vocab_size = _get_nvs_vocab_slice_ids(self.nvs_thresh, nvs_prediction,
+                                                                          restrict_lexicon=restrict_lexicon,
+                                                                          target_prefix=target_prefix)
+        elif restrict_lexicon:
             source_words = source[:, :, 0]
-            vocab_slice_ids, output_vocab_size = _get_vocab_slice_ids(restrict_lexicon, source_words, self.eos_id,
-                                                                      beam_size=1, target_prefix=target_prefix)
+            vocab_slice_ids, output_vocab_size = _get_vocab_slice_ids(restrict_lexicon,
+                                                                      source_words,
+                                                                      self.eos_id,
+                                                                      beam_size=self.beam_size,
+                                                                      target_prefix=target_prefix,
+                                                                      output_vocab_size=self.output_vocab_size)
+
+        if self._sample is not None:
+            utils.check_condition(
+                vocab_slice_ids is None,
+                "Vocabulary restriction (via lexicon or NVS) not available when sampling."
+            )
 
         pad_dist = pt.full((1, output_vocab_size), fill_value=np.inf, device=self.device, dtype=self.dtype)
         pad_dist[0, 0] = 0  # [0, inf, inf, ...]
@@ -841,16 +937,6 @@ def forward(self,
                            fill_value=np.inf, device=self.device, dtype=self.dtype)
         eos_dist[:, C.EOS_ID] = 0
 
-        # (0) encode source sentence, returns a list
-        model_states, estimated_reference_lengths = self._inference.encode_and_initialize(source, source_length)
-        # repeat states to beam_size
-        if self._traced_repeat_states is None:
-            logger.debug("Tracing repeat_states")
-            self._traced_repeat_states = pt.jit.trace(self._repeat_states, model_states, strict=False)
-        model_states = self._traced_repeat_states(*model_states)
-        # repeat estimated_reference_lengths to shape (batch_size * beam_size)
-        estimated_reference_lengths = estimated_reference_lengths.repeat_interleave(self.beam_size, dim=0)
-
         # Prefix token masks, where scores are infinity for all other vocabulary items except target_prefix ids
         prefix_masks, prefix_masks_length = None, 0
         if target_prefix is not None:
@@ -912,7 +998,7 @@ def forward(self,
                 best_hyp_indices = best_hyp_indices + offset
 
             # Map from restricted to full vocab ids if needed
-            if restrict_lexicon:
+            if vocab_slice_ids is not None:
                 best_word_indices = vocab_slice_ids.index_select(0, best_word_indices)
 
             # (4) Normalize the scores of newly finished hypotheses. Note that after this until the
@@ -983,7 +1069,9 @@ def get_search_algorithm(models: List[SockeyeModel],
                          constant_length_ratio: float = 0.0,
                          sample: Optional[int] = None,
                          prevent_unk: bool = False,
-                         greedy: bool = False) -> Union[BeamSearch, GreedySearch]:
+                         greedy: bool = False,
+                         skip_nvs: bool = False,
+                         nvs_thresh: Optional[float] = None) -> Union[BeamSearch, GreedySearch]:
     """
     Returns an instance of BeamSearch or GreedySearch, depending on the given arguments.
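
The interplay of thresholding, batch merging, special symbols, and padding in `_get_nvs_vocab_slice_ids` above can be subtle, so here is a hedged standalone sketch (toy probabilities and stand-in special-symbol ids) of the same steps:

```python
import torch as pt

PAD_ID, UNK_ID, BOS_ID, EOS_ID = 0, 1, 2, 3   # stand-ins for the constants in sockeye.constants

nvs_prediction = pt.tensor([[0.1, 0.0, 0.2, 0.9, 0.8, 0.3, 0.7, 0.2],
                            [0.0, 0.1, 0.1, 0.6, 0.2, 0.9, 0.1, 0.1]])
above = nvs_prediction > 0.5                   # per-sentence word selection
above = pt.any(above, dim=0, keepdim=True)     # union over the batch -> one shared vocab slice
above[0, pt.tensor([PAD_ID, UNK_ID, BOS_ID, EOS_ID])] = True  # always keep special symbols

bow = above.nonzero(as_tuple=True)[1].unique() # sorted unique vocab ids
if len(bow) % 8 != 0:                          # pad to a multiple of 8 for efficient kernels
    bow = pt.nn.functional.pad(bow, (0, 7 - ((len(bow) - 1) % 8)), value=EOS_ID)
print(bow)  # tensor([0, 1, 2, 3, 4, 5, 6, 3]) -- ids 0-6 selected, padded with EOS_ID
```
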
@@ -1007,7 +1095,9 @@ def get_search_algorithm(models: List[SockeyeModel], num_target_factors=models[0].num_target_factors, inference=_SingleModelInference(model=models[0], skip_softmax=True, - constant_length_ratio=0.0)) + constant_length_ratio=0.0), + skip_nvs=skip_nvs, + nvs_thresh=nvs_thresh) else: inference = None # type: Optional[_Inference] if len(models) == 1: @@ -1035,7 +1125,9 @@ def get_search_algorithm(models: List[SockeyeModel], num_source_factors=models[0].num_source_factors, num_target_factors=models[0].num_target_factors, prevent_unk=prevent_unk, - inference=inference + inference=inference, + skip_nvs=skip_nvs, + nvs_thresh=nvs_thresh ) return search diff --git a/sockeye/constants.py b/sockeye/constants.py index 6412f8b5e..97d51ed5c 100644 --- a/sockeye/constants.py +++ b/sockeye/constants.py @@ -40,6 +40,7 @@ ENCODER_PREFIX = "encoder" DECODER_PREFIX = "decoder" DEFAULT_OUTPUT_LAYER_PREFIX = "output_layer" +NVS_LAYER_PREFIX = "nvs" # SSRU SSRU_PREFIX = "ssru_" @@ -86,6 +87,10 @@ WEIGHT_TYING_SRC_TRG_SOFTMAX = 'src_trg_softmax' WEIGHT_TYING_TYPES = [WEIGHT_TYING_NONE, WEIGHT_TYING_SRC_TRG_SOFTMAX, WEIGHT_TYING_SRC_TRG, WEIGHT_TYING_TRG_SOFTMAX] +NVS_TYPE_LOGIT_MAX = "logit_max" +NVS_TYPE_EOS = "eos" +NVS_TYPES = [NVS_TYPE_LOGIT_MAX, NVS_TYPE_EOS] + # Activation types RELU = "relu" # Swish-1/SiLU (https://arxiv.org/pdf/1710.05941.pdf, https://arxiv.org/pdf/1702.03118.pdf) @@ -103,6 +108,8 @@ LOGITS_NAME = "logits" FACTOR_LOGITS_NAME = "factor%d_logits" +NVS_PRED_NAME = "nvs_pred" + MEASURE_SPEED_EVERY = 50 # measure speed and metrics every X batches # Inference constants @@ -252,19 +259,21 @@ ROUGE1 = 'rouge1' ROUGE2 = 'rouge2' ROUGEL = 'rougel' +BOW_PERPLEXITY = 'bow-perplexity' TER = 'ter' LENRATIO = 'length-ratio-mse' AVG_TIME = "avg-sec-per-sent" DECODING_TIME = "decode-walltime" -METRICS = [PERPLEXITY, ACCURACY, LENRATIO_MSE, BLEU, CHRF, ROUGE1, TER] +METRICS = [PERPLEXITY, ACCURACY, LENRATIO_MSE, BLEU, CHRF, ROUGE1, BOW_PERPLEXITY, TER] METRIC_MAXIMIZE = {ACCURACY: True, BLEU: True, CHRF: True, ROUGE1: True, PERPLEXITY: False, LENRATIO_MSE: False, - TER: False} -METRIC_WORST = {ACCURACY: 0.0, BLEU: 0.0, CHRF: 0.0, ROUGE1: 0.0, PERPLEXITY: np.inf, TER: np.inf} + TER: False, BOW_PERPLEXITY: False} +METRIC_WORST = {ACCURACY: 0.0, BLEU: 0.0, CHRF: 0.0, ROUGE1: 0.0, PERPLEXITY: np.inf, BOW_PERPLEXITY: np.inf, TER: np.inf} METRICS_REQUIRING_DECODER = [BLEU, CHRF, ROUGE1, ROUGE2, ROUGEL, TER] EVALUATE_METRICS = [BLEU, CHRF, ROUGE1, ROUGE2, ROUGEL, TER] # loss CROSS_ENTROPY = 'cross-entropy' +BINARY_CROSS_ENTROPY = 'binary-cross-entropy' LINK_NORMAL = 'normal' LINK_POISSON = 'poisson' LENGTH_TASK_RATIO = 'ratio' diff --git a/sockeye/encoder.py b/sockeye/encoder.py index 61181d1d1..cdf6fd2f0 100644 --- a/sockeye/encoder.py +++ b/sockeye/encoder.py @@ -182,7 +182,7 @@ def __init__(self, config: transformer.TransformerConfig, inference_only: bool = dropout=config.dropout_prepost, num_hidden=self.config.model_size) - def forward(self, data: pt.Tensor, valid_length: pt.Tensor) -> Tuple[pt.Tensor, pt.Tensor]: + def forward(self, data: pt.Tensor, valid_length: pt.Tensor) -> Tuple[pt.Tensor, pt.Tensor, pt.Tensor]: # positional embedding data = self.pos_embedding(data) @@ -190,8 +190,10 @@ def forward(self, data: pt.Tensor, valid_length: pt.Tensor) -> Tuple[pt.Tensor, data = self.dropout(data) _, max_len, __ = data.size() - # length_mask for source attention masking. 
Shape: (batch_size * heads, 1, max_len) - att_mask = layers.prepare_source_length_mask(valid_length, self.config.attention_heads, max_length=max_len) + # length_mask for source attention masking. Shape: (batch_size, max_len) + single_head_att_mask = layers.prepare_source_length_mask(valid_length, self.config.attention_heads, max_length=max_len, expand=False) + # Shape: (batch_size, max_len) -> (batch_size * heads, 1, max_len) + att_mask = single_head_att_mask.unsqueeze(1).expand(-1, self.config.attention_heads, -1).reshape((-1, max_len)).unsqueeze(1) att_mask = att_mask.expand(-1, max_len, -1) data = data.transpose(1, 0) # batch to time major @@ -200,7 +202,7 @@ def forward(self, data: pt.Tensor, valid_length: pt.Tensor) -> Tuple[pt.Tensor, data = self.final_process(data) data = data.transpose(1, 0) # time to batch major - return data, valid_length + return data, valid_length, single_head_att_mask def get_num_hidden(self) -> int: """ diff --git a/sockeye/inference.py b/sockeye/inference.py index ff6b2bf63..e03829eab 100644 --- a/sockeye/inference.py +++ b/sockeye/inference.py @@ -136,7 +136,7 @@ class TranslatorInput: target_prefix_factors: Optional[List[Tokens]] = None use_target_prefix_all_chunks: Optional[bool] = True keep_target_prefix_key: Optional[bool] = True - restrict_lexicon: Optional[lexicon.TopKLexicon] = None + restrict_lexicon: Optional[lexicon.RestrictLexicon] = None constraints: Optional[List[Tokens]] = None avoid_list: Optional[List[Tokens]] = None pass_through_dict: Optional[Dict] = None @@ -153,7 +153,7 @@ def num_factors(self) -> int: Returns the number of factors of this instance. """ return 1 + (0 if not self.factors else len(self.factors)) - + def get_source_prefix_tokens(self) -> Tokens: """ Returns the source prefix tokens of this instance. @@ -369,16 +369,11 @@ def make_input_from_dict(sentence_id: SentenceId, use_target_prefix_all_chunks = input_dict.get(C.JSON_USE_TARGET_PREFIX_ALL_CHUNKS_KEY, True) keep_target_prefix_key = input_dict.get(C.JSON_KEEP_TARGET_PREFIX_KEY, True) # Lexicon for vocabulary selection/restriction: - # This is only populated when using multiple lexicons, in which case the - # restrict_lexicon key must exist and the value (name) must map to one - # of the translator's known lexicons. + # This is only populated when using multiple lexicons and the lexicon name is given, in which case the + # restrict_lexicon key must exist and the value (name) must map to one of the translator's known lexicons. restrict_lexicon = None - restrict_lexicon_name = input_dict.get(C.JSON_RESTRICT_LEXICON_KEY) - if isinstance(translator.restrict_lexicon, dict): - if restrict_lexicon_name is None: - logger.error("Must specify restrict_lexicon when using multiple lexicons. Choices: %s" - % ' '.join(sorted(translator.restrict_lexicon))) - return _bad_input(sentence_id, reason=str(input_dict)) + restrict_lexicon_name = input_dict.get(C.JSON_RESTRICT_LEXICON_KEY, None) + if isinstance(translator.restrict_lexicon, dict) and restrict_lexicon_name is not None: restrict_lexicon = translator.restrict_lexicon.get(restrict_lexicon_name, None) if restrict_lexicon is None: logger.error("Unknown restrict_lexicon '%s'. Choices: %s" @@ -731,8 +726,9 @@ class Translator: :param source_vocabs: Source vocabularies. :param target_vocabs: Target vocabularies. :param nbest_size: Size of nbest list of translations. - :param restrict_lexicon: Top-k lexicon to use for target vocabulary selection. Can be a dict of - of named lexicons. 
+    :param restrict_lexicon: Lexicon to use for target vocabulary selection. Can be a dict of named lexicons. When
+        it is a single lexicon, it will be applied to all inputs. If it is a dict, the lexicon with the given name
+        will be used, or no lexicon will be used if no name is given.
     :param strip_unknown_words: If True, removes any <unk> symbols from outputs.
     :param sample: If True, sample from softmax multinomial instead of using topk.
     :param output_scores: Whether the scores will be needed as outputs. If True, scores will be normalized, negative
@@ -747,6 +743,8 @@ class Translator:
     :param max_output_length: Maximum output length this Translator is allowed to decode. If None, value will be taken
            from the model(s). Decodings that do not finish within this limit, will be force-stopped.
            If model(s) do not support given input length it will fall back to what the model(s) support.
+    :param skip_nvs: Manually turn off Neural Vocabulary Selection (NVS) to do a softmax over the full target vocabulary.
+    :param nvs_thresh: The probability threshold for a word to be added to the set of target words. Default: 0.5.
     """
 
     def __init__(self,
@@ -760,16 +758,18 @@ def __init__(self,
                  target_vocabs: List[vocab.Vocab],
                  beam_size: int = 5,
                  nbest_size: int = 1,
-                 restrict_lexicon: Optional[Union[lexicon.TopKLexicon, Dict[str, lexicon.TopKLexicon]]] = None,
+                 restrict_lexicon: Optional[Union[lexicon.RestrictLexicon, Dict[str, lexicon.RestrictLexicon]]] = None,
                  strip_unknown_words: bool = False,
-                 sample: int = None,
+                 sample: Optional[int] = None,
                  output_scores: bool = False,
                  constant_length_ratio: float = 0.0,
                  max_output_length_num_stds: int = C.DEFAULT_NUM_STD_MAX_OUTPUT_LENGTH,
                  max_input_length: Optional[int] = None,
                  max_output_length: Optional[int] = None,
                  prevent_unk: bool = False,
-                 greedy: bool = False) -> None:
+                 greedy: bool = False,
+                 skip_nvs: bool = False,
+                 nvs_thresh: float = 0.5) -> None:
         self.device = device
         self.dtype = models[0].dtype
         self._scorer = scorer
@@ -813,14 +813,16 @@ def __init__(self,
                 scorer=self._scorer,
                 constant_length_ratio=constant_length_ratio,
                 prevent_unk=prevent_unk,
-                greedy=greedy)
+                greedy=greedy,
+                skip_nvs=skip_nvs,
+                nvs_thresh=nvs_thresh)
         self._concat_translations = partial(_concat_nbest_translations if self.nbest_size > 1 else _concat_translations,
                                             stop_ids=self.stop_ids,
                                             scorer=self._scorer)  # type: Callable
 
         logger.info("Translator (%d model(s) beam_size=%d algorithm=%s, beam_search_stop=%s max_input_length=%s "
-                    "nbest_size=%s ensemble_mode=%s max_batch_size=%d dtype=%s)",
+                    "nbest_size=%s ensemble_mode=%s max_batch_size=%d dtype=%s skip_nvs=%s nvs_thresh=%s)",
                     len(self.models),
                     self.beam_size,
                     "GreedySearch" if isinstance(self._search, GreedySearch) else "BeamSearch",
@@ -829,7 +831,9 @@ def __init__(self,
                     self.nbest_size,
                     "None" if len(self.models) == 1 else ensemble_mode,
                     self.max_batch_size,
-                    self.dtype)
+                    self.dtype,
+                    skip_nvs,
+                    nvs_thresh)
 
     @property
     def max_input_length(self) -> int:
@@ -981,7 +985,7 @@ def translate(self, trans_inputs: List[TranslatorInput], fill_up_batches: bool =
     def _get_inference_input(self,
                              trans_inputs: List[TranslatorInput]) -> Tuple[pt.Tensor,
                                                                            pt.Tensor,
-                                                                           Optional[lexicon.TopKLexicon],
+                                                                           Optional[lexicon.RestrictLexicon],
                                                                            pt.Tensor,
                                                                            Optional[pt.Tensor],
                                                                            Optional[pt.Tensor]]:
@@ -1008,7 +1012,7 @@ def _get_inference_input(self,
         target_prefix_factors_np = np.zeros((batch_size, max_target_prefix_factors_length, self.num_target_factors - 1),
                                             dtype='int32') \
             if self.num_target_factors > 1 and max_target_prefix_factors_length > 0 else None
-        restrict_lexicon = None  # type: Optional[lexicon.TopKLexicon]
+        restrict_lexicon = None  # type: Optional[lexicon.RestrictLexicon]
         max_output_lengths = []  # type: List[int]
 
         for j, trans_input in enumerate(trans_inputs):
@@ -1053,15 +1057,14 @@ def _get_inference_input(self,
                 restrict_lexicon = trans_input.restrict_lexicon
             elif self.restrict_lexicon is not None:
                 if isinstance(self.restrict_lexicon, dict):
-                    # This code should not be reachable since the case is checked when creating
-                    # translator inputs. It is included here to guarantee that the translator can
-                    # handle any valid input regardless of whether it was checked at creation time.
-                    logger.warning("Sentence %s: no restrict_lexicon specified for input when using multiple lexicons, "
-                                   "defaulting to first lexicon for entire batch." % trans_input.sentence_id)
-                    restrict_lexicon = list(self.restrict_lexicon.values())[0]
+                    restrict_lexicon = None
                 else:
                     restrict_lexicon = self.restrict_lexicon
 
+        if restrict_lexicon is None and isinstance(self.restrict_lexicon, dict):
+            logger.info("No restrict_lexicon specified for input when using multiple lexicons, "
+                        "will default to not using a restrict lexicon.")
+
         source = pt.tensor(source_np, device=self.device, dtype=pt.int32)
         source_length = pt.tensor(lengths, device=self.device, dtype=pt.int32)  # shape: (batch_size,)
         max_out_lengths = pt.tensor(max_output_lengths, device=self.device, dtype=pt.int32)
@@ -1159,7 +1162,7 @@ def _make_result(self,
     def _translate_np(self,
                       source: pt.Tensor,
                       source_length: pt.Tensor,
-                      restrict_lexicon: Optional[lexicon.TopKLexicon],
+                      restrict_lexicon: Optional[lexicon.RestrictLexicon],
                       max_output_lengths: pt.Tensor,
                       target_prefix: Optional[pt.Tensor] = None,
                       target_prefix_factors: Optional[pt.Tensor] = None) -> List[Translation]:
diff --git a/sockeye/layers.py b/sockeye/layers.py
index 6702b7eba..54dfc7155 100644
--- a/sockeye/layers.py
+++ b/sockeye/layers.py
@@ -257,10 +257,17 @@ def forward(self,
 
         return interleaved_matmul_encdec_valatt(key_values, probs, heads=self.heads)
 
 
-def prepare_source_length_mask(lengths: pt.Tensor, heads: int, max_length: int) -> pt.Tensor:
-    lengths = lengths.view(-1, 1).expand(-1, heads).reshape(-1, 1)  # (batch_size * heads, 1)
-    # (batch_size * heads, 1, max_len)
-    return ~(pt.arange(max_length, device=lengths.device).unsqueeze(0) < lengths).view(-1, 1, max_length)
+def prepare_source_length_mask(lengths: pt.Tensor, heads: int, max_length: int, expand=True) -> pt.Tensor:
+    """
+    Prepare a boolean source length mask where True marks padding positions.
+
+    :param lengths: Valid sequence lengths. Shape: (batch_size,).
+    :param heads: Number of attention heads.
+    :param max_length: Maximum sequence length.
+    :param expand: Whether to expand the mask to all attention heads.
+    :return: Shape: (batch_size * heads, 1, max_len) if expand is True, else (batch_size, max_len).
+    """
+    # (batch_size, max_len)
+    mask = ~(pt.arange(max_length, device=lengths.device).unsqueeze(0) < lengths.reshape((-1, 1)))
+    if expand:
+        # (batch_size*heads, 1, max_len)
+        mask = mask.unsqueeze(1).expand(-1, heads, -1).reshape((-1, max_length)).unsqueeze(1)
+    return mask
 
 
 class MultiHeadAttentionBase(pt.nn.Module):
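
Before moving on to the lexicon changes, a quick illustration of the refactored mask helper: the sketch below (hypothetical sizes) shows the two shapes `prepare_source_length_mask` can now return; NVS consumes the unexpanded `(batch, max_len)` mask, while multi-head attention uses the head-expanded one.

```python
import torch as pt

def prepare_source_length_mask(lengths, heads, max_length, expand=True):
    # True marks padding positions beyond each sequence's valid length.
    mask = ~(pt.arange(max_length).unsqueeze(0) < lengths.reshape((-1, 1)))
    if expand:
        # Repeat the per-sentence mask once per attention head.
        mask = mask.unsqueeze(1).expand(-1, heads, -1).reshape((-1, max_length)).unsqueeze(1)
    return mask

lengths = pt.tensor([2, 3])  # two sentences with valid lengths 2 and 3
print(prepare_source_length_mask(lengths, heads=2, max_length=4, expand=False).shape)  # (2, 4)
print(prepare_source_length_mask(lengths, heads=2, max_length=4).shape)                # (4, 1, 4)
```
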
diff --git a/sockeye/lexicon.py b/sockeye/lexicon.py
index 2e89eee74..3a5666fe6 100644
--- a/sockeye/lexicon.py
+++ b/sockeye/lexicon.py
@@ -1,4 +1,4 @@
-# Copyright 2017--2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"). You may not
 # use this file except in compliance with the License. A copy of the License
@@ -12,16 +12,20 @@
 # permissions and limitations under the License.
 
 import argparse
+import collections
 import os
 import sys
 import time
 import logging
 from itertools import groupby
 from operator import itemgetter
-from typing import Dict, Generator, Tuple, Optional
+from typing import Dict, Generator, List, Tuple, Optional
+from abc import abstractmethod, ABC
 
 import numpy as np
 
+from sockeye.data_io import SequenceReader
+
 from . import arguments
 from . import constants as C
 from . import vocab
@@ -84,7 +88,76 @@ def read_lexicon(path: str, vocab_source: Dict[str, int], vocab_target: Dict[str
     return lexicon
 
 
-class TopKLexicon:
+class RestrictLexicon(ABC):
+    """
+    Lexicon component that potentially restricts the set of output words.
+
+    If `is_blocking()` is True, the set of target ids poses a negative constraint, i.e. token ids that must not be
+    used on the target side. Conversely, if `is_blocking()` is False, the lexicon poses a positive constraint by
+    returning the set of allowed target words.
+    """
+
+    lex: Optional[np.ndarray] = None
+
+    def save(self, path: str):
+        """
+        Save lexicon in Numpy array format. Lexicon will be specific to Sockeye model.
+
+        :param path: Path to Numpy array output file.
+        """
+        assert self.lex is not None, "Lexicon uninitialized, can't be saved."
+        with open(path, 'wb') as out:
+            np.save(out, self.lex)
+        logger.info("Saved lexicon to \"%s\"", path)
+
+    @abstractmethod
+    def load_np(self, lex: np.ndarray, k: Optional[int] = None):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def requires_src_ids(self) -> bool:
+        """ If True, src_ids are required as an argument to the get_*_trg_ids methods. Otherwise the set of target
+        ids is source-independent and `None` may be passed instead. """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def is_blocking(self) -> bool:
+        """ If True, use get_blocked_trg_ids to obtain blocked ids; otherwise use get_allowed_trg_ids to get allowed
+        target ids (inverts the meaning of the target ids). """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_allowed_trg_ids(self, src_ids: Optional[np.ndarray] = None) -> np.ndarray:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_blocked_trg_ids(self, src_ids: Optional[np.ndarray] = None) -> np.ndarray:
+        raise NotImplementedError()
+
+
+def load_restrict_lexicon(
+        path: str,
+        vocab_source: Optional[Dict[str, int]] = None,
+        vocab_target: Optional[Dict[str, int]] = None,
+        k: Optional[int] = None) -> RestrictLexicon:
+    load_time_start = time.time()
+    with open(path, 'rb') as inp:
+        lex = np.load(inp)
+    load_time = time.time() - load_time_start
+    # Both lexicon types are serialized as numpy arrays and we distinguish them by their shape
+    logger.info("Loaded lexicon from \"%s\" in %.4fs.", path, load_time)
+    if len(lex.shape) == 1:
+        lexicon = StaticBlockLexicon()  # type: RestrictLexicon
+        lexicon.load_np(lex)
+    elif len(lex.shape) == 2:
+        lexicon = TopKLexicon(vocab_source, vocab_target)
+        lexicon.load_np(lex, k=k)
+    else:
+        raise ValueError("Expected a 1d or 2d array.")
+    return lexicon
+
+
+class TopKLexicon(RestrictLexicon):
     """
     Lexicon component that stores the k most likely target words for each source word.  Used during decoding
     to restrict target vocabulary for each source sequence.
@@ -131,27 +204,9 @@ def create(self, path: str, k: int = 20):
         logger.info("Created top-k lexicon from \"%s\", k=%d. %d source tokens with fewer than %d translations",
                     path, k, num_insufficient, k)
 
-    def save(self, path: str):
-        """
-        Save lexicon in Numpy array format. Lexicon will be specific to Sockeye model.
-
-        :param path: Path to Numpy array output file.
- """ - with open(path, 'wb') as out: - np.save(out, self.lex) - logger.info("Saved top-k lexicon to \"%s\"", path) - - def load(self, path: str, k: Optional[int] = None): - """ - Load lexicon from Numpy array file. The top-k target ids will be sorted by increasing target id. - - :param path: Path to Numpy array file. - :param k: Optionally load less items than stored in path. - """ + def load_np(self, lex: np.ndarray, k: Optional[int] = None): load_time_start = time.time() - with open(path, 'rb') as inp: - _lex = np.load(inp) - loaded_k = _lex.shape[1] + loaded_k = lex.shape[1] if k is not None: top_k = min(k, loaded_k) if k > loaded_k: @@ -159,13 +214,38 @@ def load(self, path: str, k: Optional[int] = None): "contains at most %d entries per source.", k, loaded_k) else: top_k = loaded_k - self.lex = np.zeros((len(self.vocab_source), top_k), dtype=_lex.dtype) - for src_id, trg_ids in enumerate(_lex): + self.lex = np.zeros((len(self.vocab_source), top_k), dtype=lex.dtype) + for src_id, trg_ids in enumerate(lex): self.lex[src_id, :] = np.sort(trg_ids[:top_k]) load_time = time.time() - load_time_start - logger.info("Loaded top-%d lexicon from \"%s\" in %.4fs.", top_k, path, load_time) + logger.info("Created top-%d lexicon in %.4fs.", top_k, load_time) + + def load(self, path: str, k: Optional[int] = None): + """ + Load lexicon from Numpy array file. The top-k target ids will be sorted by increasing target id. + + :param path: Path to Numpy array file. + :param k: Optionally load less items than stored in path. + """ + load_time_start = time.time() + with open(path, 'rb') as inp: + lex = np.load(inp) + load_time = time.time() - load_time_start + logger.info("Loaded lexicon from \"%s\" in %.4fs.", path, load_time) + return self.load_np(lex, k) + + def requires_src_ids(self): + return True + + def is_blocking(self) -> bool: + return False def get_trg_ids(self, src_ids: np.ndarray) -> np.ndarray: + # Note: we have this function for backwards compatibility when `get_trg_ids` was the only function that returned + # allowed target ids + return self.get_allowed_trg_ids(src_ids) + + def get_allowed_trg_ids(self, src_ids: Optional[np.ndarray] = None) -> np.ndarray: """ Lookup possible target ids for input sequence of source ids. @@ -177,6 +257,56 @@ def get_trg_ids(self, src_ids: np.ndarray) -> np.ndarray: logger.debug(f"lookup: {trg_ids.shape[0]} unique targets for {unique_src_ids.shape[0]} unique sources") return trg_ids + def get_blocked_trg_ids(self, src_ids): + raise NotImplementedError() + + +class StaticBlockLexicon(RestrictLexicon): + """ + A lexicon that blocks a fixed set of target ids independent of the src_ids. 
+ """ + + def __init__(self, lex: Optional[np.ndarray] = None): + if lex is not None: + self.lex = lex + + def create(self, block_tokens: List[str], vocab_target: Dict[str, List[int]]): + # We do not default to UNK because we want to only block on real tokens + # We also exclude any other special symbols + block_tokens_set = set(block_tokens) + logger.info(f"Creating static block lexicon with tokens: {block_tokens_set}") + num_not_in_vocab = 0 + block_token_ids = [] + for token in block_tokens: + if token in C.VOCAB_SYMBOLS: + continue + if token not in vocab_target: + num_not_in_vocab += 1 + continue + block_token_ids.extend(vocab_target[token]) + block_token_ids = list(set(block_token_ids)) + + self.lex = np.array(block_token_ids, dtype='int32') + logger.info("Created static block lexicon with %d tokens, %d skipped because they were not in the vocabulary", + len(block_token_ids), + num_not_in_vocab) + + def load_np(self, lex: np.ndarray, k: Optional[int] = None): + self.lex = lex + + def requires_src_ids(self): + return False + + def is_blocking(self): + return True + + def get_blocked_trg_ids(self, src_ids: Optional[np.ndarray] = None) -> np.ndarray: + assert self.lex is not None, "Lexicon not loaded yet." + return self.lex + + def get_allowed_trg_ids(self, src_ids): + raise NotImplementedError() + def create(args): setup_main_logger(console=not args.quiet, file_logging=not args.no_logfile, path=args.output + ".log") @@ -193,6 +323,43 @@ def create(args): lexicon.save(args.output) + +def create_block_lexicon_from_file(args): + setup_main_logger(console=not args.quiet, file_logging=not args.no_logfile, path=args.output + ".log") + global logger + logger = logging.getLogger('create-block') + log_sockeye_version(logger) + + fname = args.input + model_path = args.model + output_path = args.output + with open(fname) as data: + block_tokens = list(set(token for line in data for token in line.rstrip().split())) + return create_block_lexicon_for_model(block_tokens, model_path, output_path) + + +def create_block_lexicon_for_model(block_tokens: List[str], model_path: str, output_path: str, lowercase: bool = False): + vocab_target = vocab.load_target_vocabs(model_path)[0] + return create_block_lexicon(block_tokens, vocab_target, output_path, lowercase) + + +def create_block_lexicon(block_tokens: List[str], vocab_target: vocab.Vocab, output_path: str, lowercase: bool = False): + if lowercase: + # Lowercase vocabulary entries + block words: + # lowercased entries map to multiple word ids + vocab_target_lower = collections.defaultdict(list) + for k, v in vocab_target.items(): + vocab_target_lower[k.lower()].append(v) + block_tokens = [token.lower() for token in block_tokens] + vocab_target_for_lexicon = dict(vocab_target_lower) + else: + vocab_target_for_lexicon = {k: [v] for k, v in vocab_target.items()} + + lexicon = StaticBlockLexicon() + lexicon.create(block_tokens, vocab_target_for_lexicon) + lexicon.save(output_path) + + def inspect(args): from .data_io import tokens2ids setup_main_logger(console=True, file_logging=False) @@ -212,7 +379,7 @@ def inspect(args): continue ids = tokens2ids(tokens, vocab_source) print("Input: n=%d" % len(tokens), " ".join("%s(%d)" % (tok, i) for tok, i in zip(tokens, ids))) - trg_ids = lexicon.get_trg_ids(np.array(ids)) + trg_ids = lexicon.get_allowed_trg_ids(np.array(ids)) tokens_trg = [vocab_target_inv.get(trg_id, C.UNK_SYMBOL) for trg_id in trg_ids] print("Output: n=%d" % len(tokens_trg), " ".join("%s(%d)" % (tok, i) for tok, i in zip(tokens_trg, trg_ids))) 
print()
@@ -233,6 +400,12 @@ def main():
     arguments.add_logging_args(params_create)
     params_create.set_defaults(func=create)
 
+    params_block = subparams.add_parser('create-block', description="Create block lexicon for use during decoding.")
+    arguments.add_lexicon_args(params_block, is_for_block_lexicon=True)
+    arguments.add_lexicon_create_args(params_block, is_for_block_lexicon=True)
+    arguments.add_logging_args(params_block)
+    params_block.set_defaults(func=create_block_lexicon_from_file)
+
     params_inspect = subparams.add_parser('inspect', description="Inspect top-k lexicon for use during decoding.")
     arguments.add_lexicon_inspect_args(params_inspect)
     arguments.add_lexicon_args(params_inspect)
diff --git a/sockeye/loss.py b/sockeye/loss.py
index 5a097e7de..acf2131a5 100644
--- a/sockeye/loss.py
+++ b/sockeye/loss.py
@@ -17,6 +17,7 @@
 from typing import Any, Dict, Optional, Tuple
 
 import torch as pt
+import numpy as np
 
 from . import constants as C
 from . import utils
@@ -225,6 +226,98 @@ def create_metric(self) -> 'LossMetric':
         return PerplexityMetric(prefix=self._metric_prefix)
 
 
+class DynamicBCEWithLogitsLoss(pt.nn.BCEWithLogitsLoss):
+    """ A version of BCEWithLogitsLoss where the pos_weight can be supplied dynamically in the `forward` call. """
+
+    def __init__(self, weight: Optional[pt.Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean',
+                 pos_weight: Optional[pt.Tensor] = None) -> None:
+        super().__init__(reduction=reduction)
+        self.register_buffer('weight', weight)
+        self.register_buffer('pos_weight', pos_weight)
+        self.weight: Optional[pt.Tensor]
+        self.pos_weight: Optional[pt.Tensor]
+
+    def forward(self, input: pt.Tensor, target: pt.Tensor, pos_weight: Optional[pt.Tensor] = None) -> pt.Tensor:
+        if pos_weight is None:
+            pos_weight = self.pos_weight
+
+        return pt.nn.functional.binary_cross_entropy_with_logits(
+            input,
+            target,
+            self.weight,
+            pos_weight=pos_weight,
+            reduction=self.reduction)
+
+
+@pt.jit.script
+def _label_to_bow(label: pt.Tensor, num_labels: int):
+    bow = pt.zeros(label.shape[0], num_labels, device=label.device)
+    bow[pt.arange(0, label.shape[0], dtype=pt.int64)[:, np.newaxis], label.long()] = 1.
+    return bow
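
To make the bag-of-words training target concrete, the following hedged sketch (illustrative ids only) reimplements the gist of `_label_to_bow` without TorchScript: each row becomes a multi-hot vector over the vocabulary with a 1 wherever a token id occurs in the corresponding target sentence.

```python
import torch as pt

def label_to_bow(label: pt.Tensor, num_labels: int) -> pt.Tensor:
    # (batch_size, num_labels) multi-hot bag-of-words matrix
    bow = pt.zeros(label.shape[0], num_labels)
    bow[pt.arange(label.shape[0]).unsqueeze(1), label.long()] = 1.
    return bow

labels = pt.tensor([[4, 7, 7, 0],    # token ids of sentence 1 (duplicates collapse)
                    [1, 2, 3, 0]])   # token ids of sentence 2
print(label_to_bow(labels, num_labels=8))
# row 0 has ones at positions 0, 4, 7; row 1 at positions 0, 1, 2, 3
```
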
+ """ + nvs_pred = output + + bow = _label_to_bow(label, self._num_labels) + + # Set automatically using positive and negative counts + num_positive = pt.sum(bow).float() + num_total = bow.shape[0] * bow.shape[1] + num_negative = num_total - num_positive + pos_weight = self.pos_weight * num_negative / num_positive + + # instead of normalizing 1/num_labels, as done by the ce block, we want to also + # normalize by the virtual positive counts implied by the pos_weight + # Everything is one per sentence, so we get the average positive cases + # convert it to the additional (therefore pos_weight-1) implied counts + # and renormalize + avg_pos_count = pt.mean(pt.sum(bow, dim=1).float()) + implied_pos_count = avg_pos_count * (pos_weight-1) + scale = 1. / (self._num_labels + implied_pos_count) + + # shape: (batch_size, vocab_size) + loss = self.ce_loss(nvs_pred, bow, pos_weight) + + # shape: (batch_size,) + loss = pt.sum(loss, 1) * scale + + # Remove the batch dimension + # (1,) + ce = pt.mean(loss) * self.weight + + return ce, pt.ones(1, device=ce.device) + + def create_metric(self) -> 'LossMetric': + return PerplexityMetric(prefix=self._metric_prefix) + + class PerplexityMetric(LossMetric): def __init__(self, prefix: str = '', name: str = C.PERPLEXITY, short_name: str = C.PERPLEXITY_SHORT_NAME) -> None: diff --git a/sockeye/model.py b/sockeye/model.py index 8cb7d2bed..433c0fae8 100644 --- a/sockeye/model.py +++ b/sockeye/model.py @@ -33,6 +33,7 @@ from .config import Config from .encoder import FactorConfig from .layers import LengthRatioConfig +from . import nvs logger = logging.getLogger(__name__) @@ -55,6 +56,9 @@ class ModelConfig(Config): :param weight_tying_type: Determines which weights get tied. :param lhuc: LHUC (Vilar 2018) is applied at some part of the model. :param dtype: Data type of model parameters. Default: float32. + :param neural_vocab_selection: When True the model contains a neural vocab selection model that restricts + the target output vocabulary to speed up inference. + :param neural_vocab_selection_block_loss: When true the gradients of the NVS models are blocked before the encoder. """ config_data: data_io.DataConfig vocab_source_size: int @@ -67,6 +71,8 @@ class ModelConfig(Config): weight_tying_type: str = C.WEIGHT_TYING_SRC_TRG_SOFTMAX lhuc: bool = False dtype: str = C.DTYPE_FP32 + neural_vocab_selection: Optional[str] = None + neural_vocab_selection_block_loss: bool = False class SockeyeModel(pt.nn.Module): @@ -111,6 +117,13 @@ def __init__(self, self.encoder = encoder.get_transformer_encoder(self.config.config_encoder, inference_only=inference_only) self.decoder = decoder.get_decoder(self.config.config_decoder, inference_only=inference_only) + self.nvs = None + if self.config.neural_vocab_selection: + self.nvs = nvs.NeuralVocabSelection( + model_size=self.config.config_encoder.model_size, + vocab_target_size=self.config.vocab_target_size, + model_type=self.config.neural_vocab_selection + ) self.output_layer = layers.OutputLayer(hidden_size=self.decoder.get_num_hidden(), vocab_size=self.config.vocab_target_size, @@ -161,13 +174,13 @@ def cast(self, dtype: str): def state_structure(self): return self.decoder.state_structure() - def encode(self, inputs: pt.Tensor, valid_length: Optional[pt.Tensor] = None) -> Tuple[pt.Tensor, pt.Tensor]: + def encode(self, inputs: pt.Tensor, valid_length: Optional[pt.Tensor] = None) -> Tuple[pt.Tensor, pt.Tensor, pt.Tensor]: """ Encodes the input sequence. :param inputs: Source input data. 
         :param inputs: Source input data. Shape: (batch_size, length, num_source_factors).
         :param valid_length: Optional Tensor of sequence lengths within this batch. Shape: (batch_size,)
-        :return: Encoder outputs, encoded output lengths
+        :return: Encoder outputs, encoded output lengths, attention mask
         """
         if self.traced_embedding_source is None:
             logger.debug("Tracing embedding_source")
@@ -176,11 +189,12 @@ def encode(self, inputs: pt.Tensor, valid_length: Optional[pt.Tensor] = None) ->
         if self.traced_encoder is None:
             logger.debug("Tracing encoder")
             self.traced_encoder = pt.jit.trace(self.encoder, (source_embed, valid_length))
-        source_encoded, source_encoded_length = self.traced_encoder(source_embed, valid_length)
-        return source_encoded, source_encoded_length
+        source_encoded, source_encoded_length, att_mask = self.traced_encoder(source_embed, valid_length)
+        return source_encoded, source_encoded_length, att_mask
 
     def encode_and_initialize(self, inputs: pt.Tensor, valid_length: Optional[pt.Tensor] = None,
-                              constant_length_ratio: float = 0.0) -> Tuple[List[pt.Tensor], pt.Tensor]:
+                              constant_length_ratio: float = 0.0) -> Tuple[List[pt.Tensor], pt.Tensor,
+                                                                           Optional[pt.Tensor]]:
         """
         Encodes the input sequence and initializes decoder states (and predicted output lengths if available).
         Used for inference/decoding.
@@ -189,22 +203,27 @@ def encode_and_initialize(self, inputs: pt.Tensor, valid_length: Optional[pt.Ten
         :param valid_length: Optional Tensor of sequence lengths within this batch. Shape: (batch_size,)
         :param constant_length_ratio: Constant length ratio
         :return: Initial states for the decoder, predicted output length of shape (batch_size,), 0 if not available.
+            Returns the neural vocabulary selection model prediction if enabled, None otherwise.
         """
 
         # Encode input. Shape: (batch, length, num_hidden), (batch,)
-        source_encoded, source_encoded_lengths = self.encode(inputs, valid_length=valid_length)
+        source_encoded, source_encoded_lengths, att_mask = self.encode(inputs, valid_length=valid_length)
 
         predicted_output_length = self.predict_output_length(source_encoded,
                                                              source_encoded_lengths,
                                                              constant_length_ratio)
         # Decoder init states
         states = self.decoder.init_state_from_encoder(source_encoded, source_encoded_lengths)
+        nvs_pred = None
+        if self.nvs is not None:
+            nvs_pred = pt.sigmoid(self.nvs(source_encoded, source_encoded_lengths, att_mask))
 
-        return states, predicted_output_length
+        return states, predicted_output_length, nvs_pred
 
     def _embed_and_encode(self,
                           source: pt.Tensor,
                           source_length: pt.Tensor,
-                          target: pt.Tensor) -> Tuple[pt.Tensor, pt.Tensor, pt.Tensor, List[pt.Tensor]]:
+                          target: pt.Tensor) -> Tuple[pt.Tensor, pt.Tensor, pt.Tensor, List[pt.Tensor],
+                                                      Optional[pt.Tensor]]:
         """
         Encode the input sequence, embed the target sequence, and initialize the decoder.
         Used for training.
@@ -212,13 +231,20 @@ def _embed_and_encode(self,
         :param source: Source input data.
         :param source_length: Length of source inputs.
         :param target: Target input data.
-        :return: encoder outputs and lengths, target embeddings, and decoder initial states
+        :return: encoder outputs and lengths, target embeddings, decoder initial states, and the neural
+            vocabulary selection prediction (if enabled, otherwise None).
""" source_embed = self.embedding_source(source) target_embed = self.embedding_target(target) - source_encoded, source_encoded_length = self.encoder(source_embed, source_length) + source_encoded, source_encoded_length, att_mask = self.encoder(source_embed, source_length) states = self.decoder.init_state_from_encoder(source_encoded, source_encoded_length, target_embed) - return source_encoded, source_encoded_length, target_embed, states + nvs = None + if self.nvs is not None: + source_encoded_for_nvs = source_encoded + if self.config.neural_vocab_selection_block_loss: + source_encoded_for_nvs = source_encoded.detach() + nvs = self.nvs(source_encoded_for_nvs, source_length, att_mask) + return source_encoded, source_encoded_length, target_embed, states, nvs def decode_step(self, step_input: pt.Tensor, @@ -256,9 +282,10 @@ def forward(self, source, source_length, target, target_length): # pylint: disa # caching the encoder and embedding forward passes), turn off autograd # for the encoder and embeddings to save memory. with pt.no_grad() if self.train_decoder_only or self.forward_pass_cache_size > 0 else utils.no_context(): - source_encoded, source_encoded_length, target_embed, states = self.embed_and_encode(source, - source_length, - target) + source_encoded, source_encoded_length, target_embed, states, nvs_prediction = self.embed_and_encode( + source, + source_length, + target) target = self.decoder.decode_seq(target_embed, states=states) @@ -273,6 +300,9 @@ def forward(self, source, source_length, target, target_length): # pylint: disa # predicted_length_ratios: (batch_size,) forward_output[C.LENRATIO_NAME] = self.length_ratio(source_encoded, source_encoded_length) + if nvs_prediction is not None: + forward_output[C.NVS_PRED_NAME] = nvs_prediction + return forward_output def predict_output_length(self, diff --git a/sockeye/nvs.py b/sockeye/nvs.py new file mode 100644 index 000000000..7f00b6175 --- /dev/null +++ b/sockeye/nvs.py @@ -0,0 +1,47 @@ +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import torch as pt +from . 
+
+
+class NeuralVocabSelection(pt.nn.Module):
+    def __init__(
+        self, model_size: int, vocab_target_size: int,
+        model_type: str = C.NVS_TYPE_LOGIT_MAX,
+    ):
+        super().__init__()
+        self.vocab_target_size = vocab_target_size
+        self.model_type = model_type
+
+        self.project_vocab = pt.nn.Linear(model_size, vocab_target_size, bias=True)
+
+    def forward(self, source_encoded: pt.Tensor, source_length: pt.Tensor, att_mask: pt.Tensor):
+        if self.model_type == C.NVS_TYPE_LOGIT_MAX:
+            # ============
+            # logit max:
+            # ============
+            bow_pred = self.project_vocab(source_encoded)
+            bow_pred = bow_pred.masked_fill(att_mask.unsqueeze(2), -pt.inf)
+            bow_pred, _ = pt.max(bow_pred, dim=1)
+        elif self.model_type == C.NVS_TYPE_EOS:
+            # ============
+            # EOS based:
+            # ============
+            batch_size, max_len, _ = source_encoded.size()
+            source_encoded = source_encoded[pt.arange(0, batch_size, dtype=pt.long), (source_length - 1).long()]
+            bow_pred = self.project_vocab(source_encoded)
+        else:
+            raise ValueError("Unknown neural vocabulary selection type.")
+
+        return bow_pred
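# Editor's note: a minimal sketch of the logit_max pooling above (illustration only, not part of
# the applied patch; tensor values are made up).
import torch as pt

per_position = pt.tensor([[[0.5, -1.0, 2.0],
                           [1.5, 0.0, -3.0]]])  # shape (batch=1, len=2, vocab=3)
att_mask = pt.tensor([[False, True]])           # True marks padded source positions
masked = per_position.masked_fill(att_mask.unsqueeze(2), -pt.inf)
bow_pred, _ = pt.max(masked, dim=1)             # tensor([[0.5, -1.0, 2.0]])
# Each vocabulary entry receives the maximum of its per-position logits over the non-padded
# source positions; a sigmoid over bow_pred (as in SockeyeModel.encode_and_initialize) then
# gives per-token probabilities that the token appears in the translation.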
diff --git a/sockeye/train.py b/sockeye/train.py
index 7d2ab28ef..9b5f27522 100644
--- a/sockeye/train.py
+++ b/sockeye/train.py
@@ -678,6 +678,8 @@ def create_model_config(args: argparse.Namespace,
                                      config_decoder=config_decoder,
                                      config_length_task=config_length_task,
                                      weight_tying_type=args.weight_tying_type,
+                                     neural_vocab_selection=args.neural_vocab_selection,
+                                     neural_vocab_selection_block_loss=args.neural_vocab_selection_block_loss,
                                      lhuc=args.lhuc is not None,
                                      dtype=C.DTYPE_FP32)
     return model_config
@@ -724,6 +726,17 @@ def create_losses(args: argparse.Namespace, all_num_classes: List[int]) -> List[
                                          weight=weight,
                                          output_name=C.LENRATIO_NAME,
                                          label_name=C.LENRATIO_LABEL_NAME))
+
+    if args.neural_vocab_selection:
+        bow_loss = loss.BinaryCrossEntropyBowLoss(name="bow_ce",
+                                                  output_name=C.NVS_PRED_NAME,
+                                                  weight=args.bow_task_weight,
+                                                  pos_weight=args.bow_task_pos_weight,
+                                                  num_labels=all_num_classes[0],
+                                                  label_name=C.TARGET_LABEL_NAME,
+                                                  metric_prefix="bow")
+        losses.append(bow_loss)
+
     return losses
 
@@ -807,7 +820,7 @@ def is_fixed(name: str) -> bool:
             # Any decoder layer.
             return not name.startswith(C.DECODER_PREFIX)
         if strategy == C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_OUTER_LAYERS:
-            # First and last encoder and decoder layers.
+            # First and last encoder and decoder layers (this excludes output layer and NVS).
             first_encoder_prefix = f'{C.ENCODER_PREFIX}.layers.{0}'
             last_encoder_prefix = f'{C.ENCODER_PREFIX}.layers.{num_encoder_layers - 1}'
             first_decoder_prefix = f'{C.DECODER_PREFIX}.layers.{0}'
@@ -821,7 +834,7 @@ def is_fixed(name: str) -> bool:
             return not (name.startswith(C.SOURCE_EMBEDDING_PREFIX)
                         or name.startswith(C.TARGET_EMBEDDING_PREFIX))
         if strategy == C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_OUTPUT_PROJ:
             # Target output projection.
-            return not name.startswith(C.DEFAULT_OUTPUT_LAYER_PREFIX)
+            return not name.startswith(C.DEFAULT_OUTPUT_LAYER_PREFIX) and not name.startswith(C.NVS_LAYER_PREFIX)
         if strategy == C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_FEED_FORWARD:
             return not (name.endswith("ff.ff1.bias") or name.endswith("ff.ff1.weight")
                         or name.endswith("ff.ff2.bias") or name.endswith("ff.ff2.weight"))
@@ -1013,7 +1026,7 @@ def train(args: argparse.Namespace, custom_metrics_logger: Optional[Callable] =
         # https://nvidia.github.io/apex/amp.html#o2-almost-fp16-mixed-precision
         training_model, optimizer = apex.amp.initialize(training_model, optimizer, opt_level='O2')
 
-    logger.info('Tracing model on validation batch')
+    logger.info('Tracing model on a validation batch')
     batch = eval_iter.next().load(device=device)  # pylint: disable=not-callable
     # When using AMP, turn on autocasting when tracing the model so that
     # dtypes will match during AMP training. Disable the weight cache for
diff --git a/sockeye/translate.py b/sockeye/translate.py
index 839b33aa7..facb3038d 100644
--- a/sockeye/translate.py
+++ b/sockeye/translate.py
@@ -24,7 +24,7 @@
 
 import torch as pt
 
-from sockeye.lexicon import TopKLexicon
+from sockeye.lexicon import load_restrict_lexicon, RestrictLexicon
 from sockeye.log import setup_main_logger
 from sockeye.model import load_models
 from sockeye.output_handler import get_output_handler, OutputHandler
@@ -78,22 +78,22 @@ def run_translate(args: argparse.Namespace):
                              dtype=args.dtype,
                              inference_only=True)
 
-    restrict_lexicon = None  # type: Optional[Union[TopKLexicon, Dict[str, TopKLexicon]]]
+    restrict_lexicon = None  # type: Optional[Union[RestrictLexicon, Dict[str, RestrictLexicon]]]
     if args.restrict_lexicon is not None:
         logger.info(str(args.restrict_lexicon))
         if len(args.restrict_lexicon) == 1:
             # Single lexicon used for all inputs.
-            restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocabs[0])
             # Handle a single arg of key:path or path (parsed as path:path)
-            restrict_lexicon.load(args.restrict_lexicon[0][1], k=args.restrict_lexicon_topk)
+            restrict_lexicon = load_restrict_lexicon(args.restrict_lexicon[0][1], source_vocabs[0], target_vocabs[0],
+                                                     k=args.restrict_lexicon_topk)
+            logger.info(f"Loaded a single lexicon ({args.restrict_lexicon[0][0]}) that will be applied to all inputs.")
         else:
             check_condition(args.json_input,
                             "JSON input is required when using multiple lexicons for vocabulary restriction")
             # Multiple lexicons with specified names
             restrict_lexicon = dict()
             for key, path in args.restrict_lexicon:
-                lexicon = TopKLexicon(source_vocabs[0], target_vocabs[0])
-                lexicon.load(path, k=args.restrict_lexicon_topk)
+                lexicon = load_restrict_lexicon(path, source_vocabs[0], target_vocabs[0],
+                                                k=args.restrict_lexicon_topk)
                 restrict_lexicon[key] = lexicon
 
     brevity_penalty_weight = args.brevity_penalty_weight
diff --git a/test/common.py b/test/common.py
index fada0d226..42c15fb26 100644
--- a/test/common.py
+++ b/test/common.py
@@ -49,12 +49,15 @@ def check_train_translate(train_params: str,
                           seed=seed)
 
     # Test equivalence of batch decoding
-    if 'greedy' not in translate_params:
+    # With neural-vocab-selection the vocabulary is determined at the batch level, so batch and
+    # non-batch outputs may differ.
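+    # (The NVS vocabulary for a batch is effectively the union of the per-sentence predictions,
+    # cf. test_get_nvs_vocab_slice_ids in test_beam_search.py, so a sentence can be decoded with
+    # a larger vocabulary when batched with others.)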
+    if 'greedy' not in translate_params and 'neural-vocab-selection' not in train_params:
         translate_params_batch = translate_params + " --batch-size 2"
         test_translate_equivalence(data, translate_params_batch, compare_output=True)
 
     # Run translate with restrict-lexicon
-    data = run_translate_restrict(data, translate_params)
+    if 'neural-vocab-selection' not in train_params:
+        data = run_translate_restrict(data, translate_params)
 
     test_translate_equivalence(data, translate_params, compare_output=True)
 
@@ -66,7 +69,8 @@ def check_train_translate(train_params: str,
     # - translate splits up too-long sentences and translates them in sequence, invalidating the score, so skip that
     # - scoring requires valid translation output to compare against
     if '--max-input-length' not in translate_params and _translate_output_is_valid(data['test_outputs']) \
-            and _translate_output_is_valid(data['test_with_target_prefix_outputs']) and 'greedy' not in translate_params:
+            and 'greedy' not in translate_params and 'neural-vocab-selection' not in train_params \
+            and _translate_output_is_valid(data['test_with_target_prefix_outputs']):
         test_scoring(data, translate_params, compare_output)
 
     # Test correct prediction of target factors if enabled
diff --git a/test/integration/test_seq_copy_int.py b/test/integration/test_seq_copy_int.py
index c698510c4..9b588f4e6 100644
--- a/test/integration/test_seq_copy_int.py
+++ b/test/integration/test_seq_copy_int.py
@@ -53,6 +53,17 @@
      " --checkpoint-interval 20 --optimizer adam --initial-learning-rate 0.01 --learning-rate-scheduler none",
      "--beam-size 2 --nbest-size 2",
      False, 0, 0),
+    # Basic transformer w/ Neural Vocabulary Selection
+    ("--encoder transformer --decoder {decoder}"
+     " --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8"
+     " --transformer-feed-forward-num-hidden 16"
+     " --transformer-dropout-prepost 0.1 --transformer-preprocess n --transformer-postprocess dr"
+     " --weight-tying-type src_trg"
+     " --batch-size 2 --max-updates 2 --batch-type sentence --decode-and-evaluate 0"
+     " --checkpoint-interval 2 --optimizer adam --initial-learning-rate 0.01"
+     " --neural-vocab-selection logit_max --bow-task-weight 2",
+     "--beam-size 2 --nbest-size 2",
+     False, 0, 0),
     # Basic transformer w/ prepared data & greedy decoding
     ("--encoder transformer --decoder {decoder}"
      " --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8"
diff --git a/test/system/test_seq_copy_sys.py b/test/system/test_seq_copy_sys.py
index 0d7a70f7b..d2c96ac4f 100644
--- a/test/system/test_seq_copy_sys.py
+++ b/test/system/test_seq_copy_sys.py
@@ -58,6 +58,17 @@
      False,
      1.02,
      0.98),
+    ("Copy:transformer:transformer",
+     "--encoder transformer --decoder transformer"
+     " --max-updates 4000"
+     " --num-layers 2 --transformer-attention-heads 4 --transformer-model-size 32"
+     " --transformer-feed-forward-num-hidden 64 --num-embed 32"
+     " --batch-size 16 --batch-type sentence"
+     " --neural-vocab-selection logit_max --bow-task-weight 2" + COMMON_TRAINING_PARAMS,
+     "--beam-size 1 --prevent-unk",
+     False,
+     1.02,
+     0.98),
     ("greedy",
      "--encoder transformer --decoder transformer"
      " --max-updates 4000"
@@ -124,19 +135,24 @@ def test_seq_copy(name, train_params, translate_params, use_prepared_data, perpl
 
     # compute metrics
     hypotheses = [json['translation'] for json in data['test_outputs']]
-    hypotheses_restricted = [json['translation'] for json in data['test_outputs_restricted']]
     bleu = sockeye.evaluate.raw_corpus_bleu(hypotheses=hypotheses,
                                            references=data['test_targets'])
     chrf = sockeye.evaluate.raw_corpus_chrf(hypotheses=hypotheses,
                                             references=data['test_targets'])
-    bleu_restrict = sockeye.evaluate.raw_corpus_bleu(hypotheses=hypotheses_restricted,
-                                                     references=data['test_targets'])
+    if 'test_outputs_restricted' in data:
+        hypotheses_restricted = [json['translation'] for json in data['test_outputs_restricted']]
+        bleu_restrict = sockeye.evaluate.raw_corpus_bleu(hypotheses=hypotheses_restricted,
+                                                         references=data['test_targets'])
+    else:
+        bleu_restrict = None
 
     logger.info("================")
     logger.info("test results: %s", name)
-    logger.info("perplexity=%f, bleu=%f, bleu_restrict=%f chrf=%f", perplexity, bleu, bleu_restrict, chrf)
+    logger.info("perplexity=%f, bleu=%f, bleu_restrict=%s, chrf=%f", perplexity, bleu, bleu_restrict, chrf)
     logger.info("================\n")
+
     assert perplexity <= perplexity_thresh
     assert bleu >= bleu_thresh
-    assert bleu_restrict >= bleu_thresh
+    if bleu_restrict is not None:
+        assert bleu_restrict >= bleu_thresh
 
 
 SORT_CASES = [
@@ -204,11 +220,14 @@ def test_seq_sort(name, train_params, translate_params, use_prepared_data,
 
     # compute metrics
     hypotheses = [json['translation'] for json in data['test_outputs']]
-    hypotheses_restricted = [json['translation'] for json in data['test_outputs_restricted']]
     bleu = sockeye.evaluate.raw_corpus_bleu(hypotheses=hypotheses,
                                             references=data['test_targets'])
     chrf = sockeye.evaluate.raw_corpus_chrf(hypotheses=hypotheses,
                                             references=data['test_targets'])
-    bleu_restrict = sockeye.evaluate.raw_corpus_bleu(hypotheses=hypotheses_restricted,
-                                                     references=data['test_targets'])
+    if 'test_outputs_restricted' in data:
+        hypotheses_restricted = [json['translation'] for json in data['test_outputs_restricted']]
+        bleu_restrict = sockeye.evaluate.raw_corpus_bleu(hypotheses=hypotheses_restricted,
+                                                         references=data['test_targets'])
+    else:
+        bleu_restrict = None
 
     logger.info("================")
     logger.info("test results: %s", name)
@@ -216,4 +235,5 @@ def test_seq_sort(name, train_params, translate_params, use_prepared_data,
     logger.info("================\n")
     assert perplexity <= perplexity_thresh
     assert bleu >= bleu_thresh
-    assert bleu_restrict >= bleu_thresh
+    if bleu_restrict is not None:
+        assert bleu_restrict >= bleu_thresh
diff --git a/test/unit/test_arguments.py b/test/unit/test_arguments.py
index 67915354b..80aa98a5b 100644
--- a/test/unit/test_arguments.py
+++ b/test/unit/test_arguments.py
@@ -124,7 +124,9 @@ def test_device_args(test_params, expected_params):
                          decoder=C.TRANSFORMER_TYPE,
                          dtype='float32',
                          amp=False,
-                         apex_amp=False))
+                         apex_amp=False,
+                         neural_vocab_selection=None,
+                         neural_vocab_selection_block_loss=False))
 ])
 def test_model_parameters(test_params, expected_params):
     _test_args(test_params, expected_params, arguments.add_model_parameters)
@@ -160,7 +162,9 @@ def test_model_parameters(test_params, expected_params):
                          dtype=None,
                          prevent_unk=False,
                          sample=None,
-                         seed=None)),
+                         seed=None,
+                         nvs_thresh=0.5,
+                         skip_nvs=False)),
 ])
 def test_inference_args(test_params, expected_params):
     _test_args(test_params, expected_params, arguments.add_inference_args)
@@ -218,7 +222,9 @@ def test_inference_args(test_params, expected_params):
                          cache_last_best_params=0,
                          cache_strategy=C.AVERAGE_BEST,
                          cache_metric=C.PERPLEXITY,
-                         dry_run=False)),
+                         dry_run=False,
+                         bow_task_pos_weight=10,
+                         bow_task_weight=1.0)),
 ])
 def test_training_arg(test_params, expected_params):
     _test_args(test_params, expected_params, arguments.add_training_args)
diff --git a/test/unit/test_beam_search.py b/test/unit/test_beam_search.py
index 36a7bfc70..63e3b3c2a 100644
--- a/test/unit/test_beam_search.py
+++ b/test/unit/test_beam_search.py
@@ -17,6 +17,7 @@
 import numpy as onp
 import pytest
 import torch as pt
+import numpy as np
 
 import sockeye.beam_search
 import sockeye.constants as C
@@ -256,7 +257,8 @@ def encode_and_initialize(self,
         num_decode_step_calls = pt.zeros(1, dtype=pt.int)
         self.states = [internal_lengths, num_decode_step_calls]  # TODO add nested states
         predicted_output_length = pt.ones(batch_size, 1)  # does that work?
-        return self.states, predicted_output_length
+        nvs_prediction = None
+        return self.states, predicted_output_length, nvs_prediction
 
     def decode_step(self,
                     step_input: pt.Tensor,
@@ -341,3 +343,80 @@ def test_beam_search():
     print('internal lengths', inference.states[0])
     pt.testing.assert_allclose(r.lengths, inference.states[0].squeeze(1))
     assert inference.states[1] == max_length
+
+
+def test_get_nvs_vocab_slice_ids():
+    # Batch size 2
+    # Note: the first 4 tokens are special tokens (PAD, UNK etc.)
+    #                              0    1    2    3    4     5    6     7     8    9
+    nvs_prediction = pt.tensor([[0.1, 0.1, 0.1, 0.1, 0.7,  0.0, 0.8,  0.0,  0.0, 0.0],
+                                [0.1, 0.1, 0.1, 0.1, 0.55, 0.0, 0.49, 0.05, 0.0, 0.0]])
+    expected_bow = pt.tensor([0, 1, 2, 3, 4, 6, C.EOS_ID, C.EOS_ID])
+    bow, output_vocab_size = sockeye.beam_search._get_nvs_vocab_slice_ids(nvs_thresh=0.5,
+                                                                          nvs_prediction=nvs_prediction)
+    assert output_vocab_size == expected_bow.shape[0]
+    pt.testing.assert_allclose(bow, expected_bow)
+
+    # Batch size 1
+    #                              0    1    2    3    4    5    6    7    8    9
+    nvs_prediction = pt.tensor([[0.1, 0.1, 0.1, 0.1, 0.7, 0.0, 0.0, 0.8, 0.0, 0.0]])
+    expected_bow = pt.tensor([0, 1, 2, 3, 4, 7, C.EOS_ID, C.EOS_ID])
+    bow, output_vocab_size = sockeye.beam_search._get_nvs_vocab_slice_ids(nvs_thresh=0.5,
+                                                                          nvs_prediction=nvs_prediction)
+    assert output_vocab_size == expected_bow.shape[0]
+    pt.testing.assert_allclose(bow, expected_bow)
+
+    # Batch size 1 + higher thresh
+    #                              0    1    2    3    4    5    6    7    8    9
+    nvs_prediction = pt.tensor([[0.1, 0.1, 0.1, 0.1, 0.7, 0.0, 0.0, 0.8, 0.0, 0.0]])
+    expected_bow = pt.tensor([0, 1, 2, 3, C.EOS_ID, C.EOS_ID, C.EOS_ID, C.EOS_ID])
+    bow, output_vocab_size = sockeye.beam_search._get_nvs_vocab_slice_ids(nvs_thresh=0.9,
+                                                                          nvs_prediction=nvs_prediction)
+    assert output_vocab_size == expected_bow.shape[0]
+    pt.testing.assert_allclose(bow, expected_bow)
+
+    # Batch size 2 + target prefix
+    # Note: the first 4 tokens are special tokens (PAD, UNK etc.)
+    #                              0    1    2    3    4     5    6     7     8    9
+    nvs_prediction = pt.tensor([[0.1, 0.1, 0.1, 0.1, 0.7,  0.0, 0.8,  0.0,  0.0, 0.0],
+                                [0.1, 0.1, 0.1, 0.1, 0.55, 0.0, 0.49, 0.05, 0.0, 0.0]])
+    target_prefix = pt.tensor([[8, 8], [8, 8]])
+    expected_bow = pt.tensor([0, 1, 2, 3, 4, 6, 8, C.EOS_ID])
+    bow, output_vocab_size = sockeye.beam_search._get_nvs_vocab_slice_ids(nvs_thresh=0.5,
+                                                                          nvs_prediction=nvs_prediction,
+                                                                          target_prefix=target_prefix)
+    assert output_vocab_size == expected_bow.shape[0]
+    pt.testing.assert_allclose(bow, expected_bow)
+
+    # Batch size 2 + blocking lexicon
+    # Note: the first 4 tokens are special tokens (PAD, UNK etc.)
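+    # A blocked token (here token 6) is removed even though its score is above the threshold;
+    # the result is again padded with C.EOS_ID.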
+    #                              0    1    2    3    4     5    6     7     8    9
+    nvs_prediction = pt.tensor([[0.1, 0.1, 0.1, 0.1, 0.7,  0.0, 0.8,  0.0,  0.0, 0.0],
+                                [0.1, 0.1, 0.1, 0.1, 0.55, 0.0, 0.49, 0.05, 0.0, 0.0]])
+    expected_bow = pt.tensor([0, 1, 2, 3, 4, C.EOS_ID, C.EOS_ID, C.EOS_ID])
+    restrict_lexicon = sockeye.lexicon.StaticBlockLexicon(
+        np.array([6])
+    )
+    bow, output_vocab_size = sockeye.beam_search._get_nvs_vocab_slice_ids(nvs_thresh=0.5,
+                                                                          nvs_prediction=nvs_prediction,
+                                                                          restrict_lexicon=restrict_lexicon)
+    assert output_vocab_size == expected_bow.shape[0]
+    pt.testing.assert_allclose(bow, expected_bow)
+
+
+def test_get_vocab_slice_ids_blocking():
+    # test _get_vocab_slice_ids when using a blocking lexicon.
+    restrict_lexicon = sockeye.lexicon.StaticBlockLexicon(
+        np.array([3])
+    )
+    source_words = pt.tensor([1, 2, 3])
+    vocab_slice_ids, _ = sockeye.beam_search._get_vocab_slice_ids(
+        restrict_lexicon=restrict_lexicon,
+        source_words=source_words,
+        eos_id=C.EOS_ID,
+        beam_size=5,
+        target_prefix=None,
+        output_vocab_size=6
+    )
+    expected_vocab_slice_ids = pt.tensor([0, 1, 2, 4, 5, C.EOS_ID, C.EOS_ID, C.EOS_ID])
+    pt.testing.assert_allclose(vocab_slice_ids, expected_vocab_slice_ids)
diff --git a/test/unit/test_inference.py b/test/unit/test_inference.py
index 3d633dcf1..087f6de66 100644
--- a/test/unit/test_inference.py
+++ b/test/unit/test_inference.py
@@ -314,8 +314,8 @@ def test_make_input_from_valid_json_string_restrict_lexicon():
     text = 'this is a test'
     translator = mock_translator()
 
-    lexicon1 = Mock(sockeye.lexicon.TopKLexicon)
-    lexicon2 = Mock(sockeye.lexicon.TopKLexicon)
+    lexicon1 = Mock(sockeye.lexicon.RestrictLexicon)
+    lexicon2 = Mock(sockeye.lexicon.RestrictLexicon)
     translator.restrict_lexicon = {'lexicon1': lexicon1, 'lexicon2': lexicon2}
     assert translator.restrict_lexicon['lexicon1'] is not translator.restrict_lexicon['lexicon2']
 
diff --git a/test/unit/test_lexicon.py b/test/unit/test_lexicon.py
index 9e89b1802..263759837 100644
--- a/test/unit/test_lexicon.py
+++ b/test/unit/test_lexicon.py
@@ -56,15 +56,15 @@ def test_topk_lexicon():
     assert np.all(lex.lex == expected_sorted)
 
     # Test lookup
-    trg_ids = lex.get_trg_ids(np.array([[vocab["a"], vocab["c"]]], dtype=np.int32))
+    trg_ids = lex.get_allowed_trg_ids(np.array([[vocab["a"], vocab["c"]]], dtype=np.int32))
     expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS + ["a", "b"]], dtype=np.int32)
     assert np.all(trg_ids == expected)
 
-    trg_ids = lex.get_trg_ids(np.array([[vocab["b"]]], dtype=np.int32))
+    trg_ids = lex.get_allowed_trg_ids(np.array([[vocab["b"]]], dtype=np.int32))
     expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS + ["b"]], dtype=np.int32)
     assert np.all(trg_ids == expected)
 
-    trg_ids = lex.get_trg_ids(np.array([[vocab["c"]]], dtype=np.int32))
+    trg_ids = lex.get_allowed_trg_ids(np.array([[vocab["c"]]], dtype=np.int32))
     expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS], dtype=np.int32)
     assert np.all(trg_ids == expected)
 
@@ -72,7 +72,7 @@ def test_topk_lexicon():
     small_k = k - 1
     lex.load(json_lex_path, k=small_k)
     assert lex.lex.shape[1] == small_k
-    trg_ids = lex.get_trg_ids(np.array([[vocab["a"]]], dtype=np.int32))
+    trg_ids = lex.get_allowed_trg_ids(np.array([[vocab["a"]]], dtype=np.int32))
     expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS + ["a"]], dtype=np.int32)
     assert np.all(trg_ids == expected)
 
@@ -80,6 +80,48 @@ def test_topk_lexicon():
     large_k = k + 1
     lex.load(json_lex_path, k=large_k)
     assert lex.lex.shape[1] == k
-    trg_ids = lex.get_trg_ids(np.array([[vocab["a"], vocab["c"]]], dtype=np.int32))
+    trg_ids = lex.get_allowed_trg_ids(np.array([[vocab["a"], vocab["c"]]], dtype=np.int32))
     expected = np.array([vocab[symbol] for symbol in C.VOCAB_SYMBOLS + ["a", "b"]], dtype=np.int32)
     assert np.all(trg_ids == expected)
+
+
+def test_create_block_lexicon():
+    vocab = {
+        "test": 0,
+        "TeSt": 1,
+        "foo": 2,
+        "bar": 3,
+    }
+    block_tokens = ["TeSt", "bar"]
+    irrelevant_src_ids = np.array([1, 2, 3, 4])
+
+    with TemporaryDirectory(prefix="test_create_block_lexicon.") as work_dir:
+        out_path = os.path.join(work_dir, "input.lex")
+        sockeye.lexicon.create_block_lexicon(
+            block_tokens,
+            vocab,
+            output_path=out_path,
+            lowercase=False
+        )
+
+        lexicon = sockeye.lexicon.load_restrict_lexicon(out_path)
+        expected = np.array([1, 3], dtype=np.int32)
+        assert np.all(lexicon.lex == expected)
+        assert np.all(lexicon.get_blocked_trg_ids() == expected)
+        assert np.all(lexicon.get_blocked_trg_ids(irrelevant_src_ids) == expected)
+
+    with TemporaryDirectory(prefix="test_create_block_lexicon.") as work_dir:
+        out_path = os.path.join(work_dir, "input.lex")
+        sockeye.lexicon.create_block_lexicon(
+            block_tokens,
+            vocab,
+            output_path=out_path,
+            lowercase=True
+        )
+
+        lexicon = sockeye.lexicon.load_restrict_lexicon(out_path)
+        expected = np.array([0, 1, 3], dtype=np.int32)
+        assert np.all(lexicon.lex == expected)
+        assert np.all(lexicon.get_blocked_trg_ids() == expected)
+        assert np.all(lexicon.get_blocked_trg_ids(irrelevant_src_ids) == expected)
\ No newline at end of file
diff --git a/test/unit/test_loss.py b/test/unit/test_loss.py
index 2f43953ec..56367dd59 100644
--- a/test/unit/test_loss.py
+++ b/test/unit/test_loss.py
@@ -98,6 +98,64 @@ def test_cross_entropy_loss():
     pt.testing.assert_allclose(logits.grad, expected_logits_grad)
 
 
+def test_label_to_bow():
+    labels = pt.tensor(
+        [
+            [1, 3],
+            [0, 0],
+        ]
+    )
+    bow = sockeye.loss._label_to_bow(labels, num_labels=4)
+    expected_bow = pt.tensor([
+        [0, 1, 0, 1],
+        [1, 0, 0, 0],
+    ])
+    pt.testing.assert_allclose(bow, expected_bow)
+
+
+def test_binary_cross_entropy_loss():
+    vocab_size = 4
+    b = sockeye.loss.BinaryCrossEntropyBowLoss(
+        pos_weight=1,
+        num_labels=vocab_size
+    )
+    assert b.name == C.BINARY_CROSS_ENTROPY
+    assert b.weight == 1.0
+    assert b._dtype == C.DTYPE_FP32
+    assert b.output_name == C.NVS_PRED_NAME
+    assert b.label_name == C.TARGET_LABEL_NAME
+
+    # Logits, shape (batch_size, vocab_size) = (2, 4).
+    # Only a single element (row 0, token 3) will contribute to the loss
+    # (all other predictions match the labels, so their loss is ~0).
+    logits = pt.tensor([[-100, 100, -100, 1],
+                        [-100, 100, -100, 100]], dtype=pt.float32, requires_grad=True)
+
+    # Labels, shape (batch_size, target_length) = (2, 2); each row lists the
+    # target tokens present in the sentence.
+    labels = pt.tensor(
+        [
+            [1, 3],
+            [1, 3],
+        ]
+    )
+    batch_size = labels.shape[0]
+
+    loss_value, loss_samples = b({C.NVS_PRED_NAME: logits, 'other_stuff': None},
+                                 {C.TARGET_LABEL_NAME: labels, 'other_stuff': None})
+    loss_value.backward()
+    assert loss_samples.item() == 1  # this loss always returns 1 for loss_samples
+    expected_loss = -pt.log(pt.sigmoid(pt.tensor(1))) / vocab_size / batch_size
+    pt.testing.assert_allclose(loss_value, expected_loss)
+    # d/dx of -log(sigmoid(x)) is sigmoid(x) - 1 = -1 / (exp(x) + 1), scaled like the loss.
+    expected_grad = -1 / (pt.exp(pt.tensor(1)) + 1) / vocab_size / batch_size
+    pt.testing.assert_allclose(logits.grad,
+                               pt.tensor([[0.0000, 0.0000, 0.0000, expected_grad],
+                                          [0.0000, 0.0000, 0.0000, 0.0000]])
+                               )
+
+
 def test_perplexity_metric():
     ppl = sockeye.loss.PerplexityMetric()
     assert ppl.name == C.PERPLEXITY