Refactored run_train_translate function for integration and system testing (#601)

Refactored run_train_translate function for integration and system testing.
- removed quiet arg
- expanded scoring tests to cover tests with prepared data
- reduced I/O in test code
fhieber authored Jan 27, 2019
1 parent f09d42a commit a6afcaa
Showing 9 changed files with 577 additions and 545 deletions.
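The refactor changes how the shared test driver is called: instead of passing a dozen file paths plus separate batch/beam integers, the tests now hand run_train_translate the dataset dictionary produced by the tmp_* fixtures and get an enriched dictionary back. A minimal sketch of the new call pattern, inferred from the constraints test further down; the exact signature and returned keys in test/common.py are assumptions, not the verbatim implementation:

    # Sketch only: call pattern inferred from test_constraints_int.py below;
    # the real signature lives in the (unrendered) test/common.py diff.
    import sockeye.constants as C
    from test.common import run_train_translate, tmp_digits_dataset

    train_params = "--encoder rnn --decoder rnn ..."   # placeholder; any config from the tests below
    translate_params = "--batch-size 2 --beam-size 10"

    with tmp_digits_dataset(prefix="sketch", train_line_count=20, train_max_length=9,
                            dev_line_count=5, dev_max_length=9,
                            test_line_count=5, test_line_count_empty=2,
                            test_max_length=20, sort_target=False) as data:
        data = run_train_translate(train_params=train_params,
                                   translate_params=translate_params,
                                   data=data,
                                   max_seq_len=9 + C.SPACE_FOR_XOS)
        # The returned dict now also carries the items the tests use afterwards, e.g.
        # data['model'], data['work_dir'], data['test_inputs'], data['test_outputs'].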
9 changes: 5 additions & 4 deletions sockeye/arguments.py
@@ -18,9 +18,10 @@
import os
import sys
import types
import yaml
from typing import Any, Callable, Dict, List, Tuple, Optional

import yaml

from . import constants as C
from . import data_io
from .lr_scheduler import LearningRateSchedulerFixedStep
@@ -259,8 +260,8 @@ def add_average_args(params):
"--output", "-o", required=True, type=str, help="File to write averaged parameters to.")
average_params.add_argument(
"--strategy",
choices=["best", "last", "lifespan"],
default="best",
choices=C.AVERAGE_CHOICES,
default=C.AVERAGE_BEST,
help="selection method. Default: %(default)s.")


@@ -799,7 +800,7 @@ def add_training_args(params):
train_params.add_argument('--decoder-only',
action='store_true',
help='Pre-train a decoder. This is currently for RNN decoders only. '
'Default: %(default)s.')
'Default: %(default)s.')
train_params.add_argument('--fill-up',
type=str,
default=C.FILL_UP_DEFAULT,
7 changes: 7 additions & 0 deletions sockeye/constants.py
@@ -437,3 +437,10 @@
SCORING_TYPE_LOGPROB = 'logprob'
SCORING_TYPE_DEFAULT = SCORING_TYPE_NEGLOGPROB
SCORING_TYPE_CHOICES = [SCORING_TYPE_NEGLOGPROB, SCORING_TYPE_LOGPROB]


# parameter averaging
AVERAGE_BEST = 'best'
AVERAGE_LAST = 'last'
AVERAGE_LIFESPAN = 'lifespan'
AVERAGE_CHOICES = [AVERAGE_BEST, AVERAGE_LAST, AVERAGE_LIFESPAN]
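
These constants replace the strings that were previously hard-coded in add_average_args (see the arguments.py hunk above). A minimal standalone sketch of the resulting argparse wiring; the surrounding parser setup is an assumption for illustration, not the actual add_average_args code:

    import argparse
    from sockeye import constants as C

    # Sketch only: mirrors the --strategy option from the diff above.
    parser = argparse.ArgumentParser()
    parser.add_argument("--strategy",
                        choices=C.AVERAGE_CHOICES,      # ['best', 'last', 'lifespan']
                        default=C.AVERAGE_BEST,
                        help="selection method. Default: %(default)s.")
    args = parser.parse_args(["--strategy", "lifespan"])
    assert args.strategy == C.AVERAGE_LIFESPAN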
6 changes: 3 additions & 3 deletions sockeye/inference.py
@@ -1547,7 +1547,6 @@ def _make_result(self,
pass_through_dict=trans_input.pass_through_dict,
beam_histories=translation.beam_histories)
else:

nbest_target_ids = translation.nbest_translations.target_ids_list
target_tokens_list = [[self.vocab_target_inv[id] for id in ids] for ids in nbest_target_ids]
target_strings = [C.TOKEN_SEPARATOR.join(
@@ -2160,7 +2159,7 @@ def __init__(self, k: int, n: int, batch_size: int, context: mx.context.Context)
:param k: The size of the beam.
:param n: Sample from the top-N words in the vocab at each timestep.
:param batch_size: Number of sentences being decoded at once.
:param vocab_size: Vocabulary size.
:param context: Context for block constants.
"""
super().__init__()
self.beam_size = k
@@ -2181,7 +2180,8 @@ def hybrid_forward(self, F, scores, target_dists, finished, best_hyp_indices, ze
:param scores: Vocabulary scores for the next beam step. (batch_size * beam_size, target_vocabulary_size)
:param target_dists: The non-cumulative target distributions (ignored).
:param finished: The list of finished hypotheses.
:param offset: Array to add to the hypothesis indices for offsetting in batch decoding.
:param best_hyp_indices: Best hypothesis indices constant.
:param zeros_array: Zeros array constant.
:return: The row indices, column indices, and values of the sampled words.
"""
# Map the negative logprobs to probabilities so as to have a distribution
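The corrected docstrings describe a block that samples from the top-n words of the output distribution instead of taking the arg-max. An illustrative sketch of that idea in plain NumPy (not the hybridized MXNet operator in the diff; treating n <= 0 as "no restriction" is a choice made here for the example):

    import numpy as np

    def sample_top_n(neg_logprobs: np.ndarray, n: int, rng: np.random.Generator) -> int:
        # Map negative log-probabilities back to probabilities so we have a distribution.
        probs = np.exp(-neg_logprobs)
        if n > 0:
            # Keep only the n most probable words; zero out everything else.
            top_n = np.argsort(probs)[::-1][:n]
            restricted = np.zeros_like(probs)
            restricted[top_n] = probs[top_n]
            probs = restricted
        probs = probs / probs.sum()                 # renormalize over the restricted support
        return int(rng.choice(len(probs), p=probs))

    # Usage: sample_top_n(scores_row, n=5, rng=np.random.default_rng(0))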
695 changes: 295 additions & 400 deletions test/common.py

Large diffs are not rendered by default.
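
Since the test/common.py diff is not rendered, here is a hypothetical sketch of the newly imported collect_translate_output_and_scores helper, assuming the --output-type translation_with_score file holds one tab-separated score and translation per line; the real implementation in this commit may differ:

    from typing import List, Tuple

    def collect_translate_output_and_scores(out_path: str) -> Tuple[List[str], List[float]]:
        # Hypothetical sketch: parse "score<TAB>translation" lines into parallel lists.
        outputs, scores = [], []
        with open(out_path) as fin:
            for line in fin:
                score_str, _, translation = line.rstrip("\n").partition("\t")
                scores.append(float(score_str))
                outputs.append(translation)
        return outputs, scores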

2 changes: 1 addition & 1 deletion test/integration/image_captioning/test_image_captioning.py
@@ -56,7 +56,7 @@
def test_caption_random_features(train_params: str, translate_params: str):
# generate random names
source_list = [''.join(random.choice(string.ascii_uppercase) for _ in range(4)) for i in range(15)]
prefix = "tmp_caption_ramdom"
prefix = "tmp_caption_random"
use_features = True
with tmp_img_captioning_dataset(source_list,
prefix,
102 changes: 72 additions & 30 deletions test/integration/test_constraints_int.py
@@ -12,14 +12,20 @@
# permissions and limitations under the License.

"""
Tests constraints in many forms.
Integration tests for lexical constraints.
"""
import json
import os
import sys
from typing import Dict, List, Any
from unittest.mock import patch

import pytest
import random

import sockeye.constants as C
from test.common import run_train_translate, tmp_digits_dataset
import sockeye.translate
from test.common import run_train_translate, tmp_digits_dataset, collect_translate_output_and_scores, \
_TRANSLATE_PARAMS_COMMON

_TRAIN_LINE_COUNT = 20
_DEV_LINE_COUNT = 5
@@ -28,14 +28,14 @@
_LINE_MAX_LENGTH = 9
_TEST_MAX_LENGTH = 20

ENCODER_DECODER_SETTINGS = [
TEST_CONFIGS = [
# "Vanilla" LSTM encoder-decoder with attention
("--encoder rnn --decoder rnn --num-layers 1 --rnn-cell-type lstm --rnn-num-hidden 8 --num-embed 4 "
" --rnn-attention-type mlp"
" --rnn-attention-num-hidden 8 --loss cross-entropy --optimized-metric perplexity --max-updates 2"
" --checkpoint-frequency 2 --optimizer adam --initial-learning-rate 0.01 --batch-type sentence "
" --decode-and-evaluate 0",
2, 10),
"--batch-size 2 --beam-size 10"),
# Full transformer
("--encoder transformer --decoder transformer"
" --num-layers 2 --transformer-attention-heads 2 --transformer-model-size 8 --num-embed 8"
@@ -45,17 +45,11 @@
" --weight-init-scale=3.0 --weight-init-xavier-factor-type=avg --embed-weight-init=normal"
" --batch-size 2 --max-updates 2 --batch-type sentence --decode-and-evaluate 0"
" --checkpoint-frequency 2 --optimizer adam --initial-learning-rate 0.01",
1, 10)]
"--batch-size 1 --beam-size 10")]


@pytest.mark.parametrize(
"train_params, batch_size, beam_size",
ENCODER_DECODER_SETTINGS)
def test_constraints(train_params: str,
beam_size: int,
batch_size: int):
"""Task: copy short sequences of digits"""

@pytest.mark.parametrize("train_params, translate_params", TEST_CONFIGS)
def test_constraints(train_params: str, translate_params: str):
with tmp_digits_dataset(prefix="test_constraints",
train_line_count=_TRAIN_LINE_COUNT,
train_max_length=_LINE_MAX_LENGTH,
@@ -65,21 +65,63 @@ def test_constraints(train_params: str,
test_line_count_empty=_TEST_LINE_COUNT_EMPTY,
test_max_length=_TEST_MAX_LENGTH,
sort_target=False) as data:
# train a minimal default model
data = run_train_translate(train_params=train_params, translate_params=translate_params, data=data,
max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS)

# 'constraint' = positive constraints (must appear), 'avoid' = negative constraints (must not appear)
for constraint_type in ["constraints", "avoid"]:
_test_constrained_type(constraint_type=constraint_type, data=data, translate_params=translate_params)


def _test_constrained_type(constraint_type: str, data: Dict[str, Any], translate_params: str):
constrained_inputs = _create_constrained_inputs(constraint_type, data['test_inputs'], data['test_outputs'])
new_test_source_path = os.path.join(data['work_dir'], "test_constrained.txt")
with open(new_test_source_path, 'w') as out:
for json_line in constrained_inputs:
print(json_line, file=out)
out_path_constrained = os.path.join(data['work_dir'], "out_constrained.txt")
params = "{} {} {} --json-input --output-type translation_with_score".format(
sockeye.translate.__file__,
_TRANSLATE_PARAMS_COMMON.format(model=data['model'],
input=new_test_source_path,
output=out_path_constrained),
translate_params)
with patch.object(sys, "argv", params.split()):
sockeye.translate.main()
constrained_outputs, constrained_scores = collect_translate_output_and_scores(out_path_constrained)
assert len(constrained_outputs) == len(data['test_outputs']) == len(constrained_inputs)
for json_source, constrained_out, unconstrained_out in zip(constrained_inputs,
constrained_outputs,
data['test_outputs']):
jobj = json.loads(json_source)
if jobj.get(constraint_type) is None:
# if there were no constraints, make sure the output is the same as the unconstrained output
assert constrained_out == unconstrained_out
else:
restriction = jobj[constraint_type][0]
if constraint_type == 'constraints':
# for positive constraints, ensure the constraint is in the constrained output
assert restriction in constrained_out
else:
# for negative constraints, ensure the constraint is *not* in the constrained output
assert restriction not in constrained_out

translate_params = " --batch-size {} --beam-size {}".format(batch_size, beam_size)

# Ignore return values (perplexity and BLEU) for integration test
run_train_translate(train_params=train_params,
translate_params=translate_params,
translate_params_equiv=None,
train_source_path=data['source'],
train_target_path=data['target'],
dev_source_path=data['validation_source'],
dev_target_path=data['validation_target'],
test_source_path=data['test_source'],
test_target_path=data['test_target'],
max_seq_len=_LINE_MAX_LENGTH + C.SPACE_FOR_XOS,
work_dir=data['work_dir'],
use_prepared_data=False,
restrict_lexicon=False,
use_target_constraints=True)
def _create_constrained_inputs(constraint_type: str,
translate_inputs: List[str],
translate_outputs: List[str]) -> List[str]:
constrained_inputs = [] # type: List[str]
for sentno, (source, translate_output) in enumerate(zip(translate_inputs, translate_outputs)):
target_words = translate_output.split()
target_len = len(target_words)
new_source = {'text': source}
# From the odd-numbered sentences that are not too long, create constraints. We do
# only odds to ensure we get batches with mixed constraints / lack of constraints.
if target_len > 0 and sentno % 2 == 0:
start_pos = 0
end_pos = min(target_len, 3)
constraint = ' '.join(target_words[start_pos:end_pos])
new_source[constraint_type] = [constraint]
constrained_inputs.append(json.dumps(new_source))
return constrained_inputs
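
For reference, a JSON input line produced by _create_constrained_inputs looks roughly like the following; the digit sequence and constraint span are made up for illustration:

    import json

    # Illustrative only: a source sentence with a positive constraint built from the
    # first (up to three) words of its unconstrained translation.
    print(json.dumps({"text": "4 7 2 9 1", "constraints": ["4 7 2"]}))
    # Negative constraints use the 'avoid' key instead:
    print(json.dumps({"text": "4 7 2 9 1", "avoid": ["4 7 2"]}))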
