From 5144d25e6af3ca45baa33eb2c5c1ce7369364c61 Mon Sep 17 00:00:00 2001
From: Felix Hieber
Date: Thu, 20 Sep 2018 09:14:26 +0200
Subject: [PATCH] Update to MXNet 1.3.0 (#534)

Update to MXNet 1.3.0 to make use of the following new features:
- use of MXNet unravel_index & topk
- use of MXNet logical operators
- topk is now a HybridBlock.
- Inference HybridBlocks use static_alloc and static_shape
---
 CHANGELOG.md                           |   5 +-
 README.md                              |   6 +-
 requirements/requirements.gpu-cu75.txt |   2 +-
 requirements/requirements.gpu-cu80.txt |   2 +-
 requirements/requirements.gpu-cu90.txt |   2 +-
 requirements/requirements.gpu-cu91.txt |   2 +-
 requirements/requirements.gpu-cu92.txt |   4 +
 requirements/requirements.txt          |   2 +-
 sockeye/__init__.py                    |   2 +-
 sockeye/encoder.py                     |   4 +-
 sockeye/inference.py                   | 141 ++++++++++++++++++++-----
 sockeye/utils.py                       |  36 +++----
 test/unit/test_inference.py            |  22 +++-
 13 files changed, 163 insertions(+), 67 deletions(-)
 create mode 100644 requirements/requirements.gpu-cu92.txt

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 987749068..ede6b8960 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,8 +10,11 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 
 Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
 
-## [1.18.55]
+## [1.18.56]
 ### Changed
+- Update to MXNet 1.3.0.post0
+
+## [1.18.55]
 - Renamed `contrib` to less-generic `sockeye_contrib`
 
 ## [1.18.54]
diff --git a/README.md b/README.md
index d7299d1be..b58c648c4 100644
--- a/README.md
+++ b/README.md
@@ -47,7 +47,7 @@ Recent developments and changes are tracked in our [changelog](https://github.co
 
 Sockeye requires:
 - **Python3**
-- [MXNet 1.2.1](https://github.com/apache/incubator-mxnet/tree/1.2.1)
+- [MXNet 1.3.0](https://github.com/apache/incubator-mxnet/tree/1.3.0)
 - numpy
 
 ## Installation
@@ -85,7 +85,7 @@ Depending on your version of CUDA, you can do this by running the following:
 > pip install sockeye --no-deps -r requirements.gpu-cu${CUDA_VERSION}.txt
 > rm requirements.gpu-cu${CUDA_VERSION}.txt
 ```
-where `${CUDA_VERSION}` can be `75` (7.5), `80` (8.0), `90` (9.0), or `91` (9.1).
+where `${CUDA_VERSION}` can be `75` (7.5), `80` (8.0), `90` (9.0), `91` (9.1), or `92` (9.2).
 
 ### Or: From Source
@@ -108,7 +108,7 @@ running the following:
 > pip install -r requirements/requirements.gpu-cu${CUDA_VERSION}.txt
 > pip install .
 ```
-where `${CUDA_VERSION}` can be `75` (7.5), `80` (8.0), `90` (9.0), or `91` (9.1).
+where `${CUDA_VERSION}` can be `75` (7.5), `80` (8.0), `90` (9.0), `91` (9.1), or `92` (9.2).
 
 ### Optional dependencies
 In order to write training statistics to a Tensorboard event file for visualization, you can optionally install mxboard
 
diff --git a/requirements/requirements.gpu-cu75.txt b/requirements/requirements.gpu-cu75.txt
index 9eba7b488..48a9bc0e7 100644
--- a/requirements/requirements.gpu-cu75.txt
+++ b/requirements/requirements.gpu-cu75.txt
@@ -1,4 +1,4 @@
 pyyaml==3.12
-mxnet-cu75mkl==1.2.1
+mxnet-cu75mkl==1.3.0.post0
 numpy>=1.14
 typing
diff --git a/requirements/requirements.gpu-cu80.txt b/requirements/requirements.gpu-cu80.txt
index db17c2d46..3414a0aec 100644
--- a/requirements/requirements.gpu-cu80.txt
+++ b/requirements/requirements.gpu-cu80.txt
@@ -1,4 +1,4 @@
 pyyaml==3.12
-mxnet-cu80mkl==1.2.1
+mxnet-cu80mkl==1.3.0.post0
 numpy>=1.14
 typing
diff --git a/requirements/requirements.gpu-cu90.txt b/requirements/requirements.gpu-cu90.txt
index 45e7cdc48..79306829a 100644
--- a/requirements/requirements.gpu-cu90.txt
+++ b/requirements/requirements.gpu-cu90.txt
@@ -1,4 +1,4 @@
 pyyaml==3.12
-mxnet-cu90mkl==1.2.1
+mxnet-cu90mkl==1.3.0.post0
 numpy>=1.14
 typing
diff --git a/requirements/requirements.gpu-cu91.txt b/requirements/requirements.gpu-cu91.txt
index 58100e8fe..7ea2e8bd9 100644
--- a/requirements/requirements.gpu-cu91.txt
+++ b/requirements/requirements.gpu-cu91.txt
@@ -1,4 +1,4 @@
 pyyaml==3.12
-mxnet-cu91mkl==1.2.1
+mxnet-cu91mkl==1.3.0.post0
 numpy>=1.14
 typing
diff --git a/requirements/requirements.gpu-cu92.txt b/requirements/requirements.gpu-cu92.txt
new file mode 100644
index 000000000..e5688d647
--- /dev/null
+++ b/requirements/requirements.gpu-cu92.txt
@@ -0,0 +1,4 @@
+pyyaml==3.12
+mxnet-cu92mkl==1.3.0.post0
+numpy>=1.14
+typing
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 808141c42..2101f604d 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -1,4 +1,4 @@
 pyyaml==3.12
-mxnet-mkl==1.2.1
+mxnet-mkl==1.3.0.post0
 numpy>=1.14
 typing
diff --git a/sockeye/__init__.py b/sockeye/__init__.py
index 65009d094..3a4f94484 100644
--- a/sockeye/__init__.py
+++ b/sockeye/__init__.py
@@ -11,4 +11,4 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-__version__ = '1.18.55'
+__version__ = '1.18.56'
diff --git a/sockeye/encoder.py b/sockeye/encoder.py
index 627d7efd5..95c313bb6 100644
--- a/sockeye/encoder.py
+++ b/sockeye/encoder.py
@@ -1243,8 +1243,8 @@ def encode(self,
             transform = mx.sym.Dropout(data=transform, p=self.dropout)
             # Connection
             seg_embedding = gate * transform + (1 - gate) * seg_embedding
-            # (batch_size, seq_len/stride, outut_dim) aka
-            # (batch_size, encoded_seq_len, num_segment_emded)
+            # (batch_size, seq_len/stride, output_dim) aka
+            # (batch_size, encoded_seq_len, num_segment_embed)
             seg_embedding = mx.sym.Reshape(data=seg_embedding,
                                            shape=(-1, encoded_seq_len, self.output_dim))
diff --git a/sockeye/inference.py b/sockeye/inference.py
index 72bdb7d56..d3918655c 100644
--- a/sockeye/inference.py
+++ b/sockeye/inference.py
@@ -1052,32 +1052,46 @@ def __init__(self,
         self._update_scores = UpdateScores()
         self._update_scores.initialize(ctx=self.context)
-        self._update_scores.hybridize()
+        self._update_scores.hybridize(static_alloc=True, static_shape=True)
 
-        # topk function used in beam search
-        if self.skip_topk:
-            self._top = partial(utils.top1,
-                                offset=self.offset)
+        # Vocabulary selection leads to different vocabulary sizes across requests. Hence, we cannot use a
+        # statically-shaped HybridBlock for the topk operation in this case and resort to the imperative
+        # topk function instead.
+        if self.restrict_lexicon:
+            if self.skip_topk:
+                self._top = partial(utils.top1, offset=self.offset)  # type: Callable
+            else:
+                self._top = partial(utils.topk,
+                                    k=self.beam_size,
+                                    offset=self.offset,
+                                    use_mxnet_topk=True)  # type: Callable
         else:
-            self._top = partial(utils.topk,
-                                k=self.beam_size,
-                                batch_size=self.batch_size,
-                                offset=self.offset,
-                                use_mxnet_topk=self.context != mx.cpu())  # MXNet implementation is faster on GPUs
+            if self.skip_topk:
+                self._top = Top1(k=self.beam_size,
+                                 batch_size=self.batch_size)  # type: mx.gluon.HybridBlock
+                self._top.initialize(ctx=self.context)
+                self._top.hybridize(static_alloc=True, static_shape=True)
+            else:
+                self._top = TopK(k=self.beam_size,
+                                 batch_size=self.batch_size,
+                                 vocab_size=len(self.vocab_target))  # type: mx.gluon.HybridBlock
+                self._top.initialize(ctx=self.context)
+                self._top.hybridize(static_alloc=True, static_shape=True)
 
         self._sort_by_index = SortByIndex()
         self._sort_by_index.initialize(ctx=self.context)
-        self._sort_by_index.hybridize()
+        self._sort_by_index.hybridize(static_alloc=True, static_shape=True)
 
         self._update_finished = NormalizeAndUpdateFinished(pad_id=C.PAD_ID,
                                                            eos_id=self.vocab_target[C.EOS_SYMBOL],
                                                            length_penalty_alpha=self.length_penalty.alpha,
                                                            length_penalty_beta=self.length_penalty.beta)
         self._update_finished.initialize(ctx=self.context)
-        self._update_finished.hybridize()
+        self._update_finished.hybridize(static_alloc=True, static_shape=True)
+
         self._prune_hyps = PruneHypotheses(threshold=self.beam_prune, beam_size=self.beam_size)
         self._prune_hyps.initialize(ctx=self.context)
-        self._prune_hyps.hybridize()
+        self._prune_hyps.hybridize(static_alloc=True, static_shape=True)
 
         self.global_avoid_trie = None
         if avoid_list is not None:
@@ -1663,11 +1677,8 @@ def _beam_search(self,
             # (9) Sort the hypotheses within each sentence (normalization for finished hyps may have unsorted them).
             folded_accumulated_scores = scores_accumulated.reshape((self.batch_size,
                                                                     self.beam_size * scores_accumulated.shape[-1]))
-            indices = mx.nd.argsort(folded_accumulated_scores, axis=1)
-            best_hyp_indices = mx.nd.array(np.unravel_index(indices.astype(np.int32).asnumpy().ravel(),
-                                                            scores_accumulated.shape),
-                                           dtype='int32',
-                                           ctx=self.offset.context)[0] + self.offset
+            indices = mx.nd.cast(mx.nd.argsort(folded_accumulated_scores, axis=1), dtype='int32').reshape((-1,))
+            best_hyp_indices, _ = mx.nd.unravel_index(indices, scores_accumulated.shape) + self.offset
             best_hyp_indices_list.append(best_hyp_indices)
             lengths = lengths.take(best_hyp_indices)
             scores_accumulated = scores_accumulated.take(best_hyp_indices)
@@ -1844,6 +1855,82 @@ def hybrid_forward(self, F, indices, *args):
         return [F.take(arg, indices) for arg in args]
 
 
+class TopK(mx.gluon.HybridBlock):
+    """
+    A HybridBlock for a statically-shaped batch-wise topk operation.
+    """
+
+    def __init__(self, k: int, batch_size: int, vocab_size: int) -> None:
+        """
+        :param k: The number of smallest scores to return.
+        :param batch_size: Number of sentences being decoded at once.
+        :param vocab_size: Vocabulary size.
+        """
+        super().__init__()
+        self.k = k
+        self.batch_size = batch_size
+        self.vocab_size = vocab_size
+        with self.name_scope():
+            offset = mx.nd.repeat(mx.nd.arange(0, batch_size * k, k, dtype='int32'), k)
+            self.offset = self.params.get_constant(name='offset', value=offset)
+
+    def hybrid_forward(self, F, scores, offset):
+        """
+        Get the lowest k elements per sentence from a `scores` matrix.
+
+        :param scores: Vocabulary scores for the next beam step. (batch_size * beam_size, target_vocabulary_size)
+        :param offset: Array to add to the hypothesis indices for offsetting in batch decoding.
+        :return: The row indices, column indices and values of the k smallest items in the matrix.
+        """
+        folded_scores = F.reshape(scores, shape=(self.batch_size, self.k * self.vocab_size))
+        values, indices = F.topk(folded_scores, axis=1, k=self.k, ret_typ='both', is_ascend=True)
+        indices = F.reshape(F.cast(indices, 'int32'), shape=(-1,))
+        unraveled = F.unravel_index(indices, shape=(self.batch_size * self.k, self.vocab_size))
+        best_hyp_indices, best_word_indices = F.split(unraveled, axis=0, num_outputs=2, squeeze_axis=True)
+        best_hyp_indices = best_hyp_indices + offset
+        values = F.reshape(values, shape=(-1, 1))
+        return best_hyp_indices, best_word_indices, values
+
+
+class Top1(mx.gluon.HybridBlock):
+    """
+    A HybridBlock for a statically-shaped batch-wise first-best operation.
+
+    Get the single lowest element per sentence from a `scores` matrix. Expects that
+    beam size is 1, for greedy decoding.
+
+    NOTE(mathmu): The current implementation of argmin in MXNet is much slower than topk with k=1.
+    """
+    def __init__(self, k: int, batch_size: int) -> None:
+        """
+        :param k: The number of smallest scores to return.
+        :param batch_size: Number of sentences being decoded at once.
+        """
+        super().__init__()
+        with self.name_scope():
+            offset = mx.nd.repeat(mx.nd.arange(0, batch_size * k, k, dtype='int32'), k)
+            self.offset = self.params.get_constant(name='offset', value=offset)
+
+    def hybrid_forward(self, F, scores, offset):
+        """
+        Get the single lowest element per sentence from a `scores` matrix. Expects that
+        beam size is 1, for greedy decoding.
+
+        :param scores: Vocabulary scores for the next beam step. (batch_size * beam_size, target_vocabulary_size)
+        :param offset: Array to add to the hypothesis indices for offsetting in batch decoding.
+        :return: The row indices, column indices and values of the smallest items in the matrix.
+        """
+        best_word_indices = F.cast(F.argmin(scores, axis=1), dtype='int32')
+        values = F.pick(scores, best_word_indices, axis=1)
+        values = F.reshape(values, shape=(-1, 1))
+
+        # for top1, the best hyp indices are equal to the plain offset
+        best_hyp_indices = offset
+
+        return best_hyp_indices, best_word_indices, values
+
+
 class NormalizeAndUpdateFinished(mx.gluon.HybridBlock):
     """
     A HybridBlock for normalizing newly finished hypotheses scores with LengthPenalty.
@@ -1860,8 +1947,8 @@ def __init__(self, pad_id: int,
         self.length_penalty = LengthPenalty(alpha=length_penalty_alpha, beta=length_penalty_beta)
 
     def hybrid_forward(self, F, best_word_indices, max_output_lengths, finished, scores_accumulated, lengths):
-        all_finished = ((best_word_indices == self.pad_id) + (best_word_indices == self.eos_id))
-        newly_finished = all_finished - finished
+        all_finished = F.broadcast_logical_or(best_word_indices == self.pad_id, best_word_indices == self.eos_id)
+        newly_finished = F.broadcast_logical_xor(all_finished, finished)
         scores_accumulated = F.where(newly_finished,
                                      scores_accumulated / self.length_penalty(lengths),
                                      scores_accumulated)
@@ -1874,11 +1961,9 @@ def hybrid_forward(self, F, best_word_indices, max_output_lengths, finished, sco
         # - extended with <pad>, or
         # - extended with <eos>, or
         # - at their maximum length.
-        finished = F.clip(
-            (best_word_indices == self.pad_id) +
-            (best_word_indices == self.eos_id) +
-            (F.cast(F.reshape(lengths, shape=(-1,)), 'int32') >= max_output_lengths),
-            a_min=0, a_max=1)
+        finished = F.broadcast_logical_or(F.broadcast_logical_or(best_word_indices == self.pad_id,
+                                                                 best_word_indices == self.eos_id),
+                                          (F.cast(F.reshape(lengths, shape=(-1,)), 'int32') >= max_output_lengths))
 
         return finished, scores_accumulated, lengths
@@ -1901,10 +1986,8 @@ def hybrid_forward(self, F, scores, finished, inactive, scores_accumulated, inf_
         # infinity otherwise.
         scores = F.broadcast_add(scores, scores_accumulated)
         # pylint: disable=invalid-sequence-index
-        pad_id_scores = F.where(F.clip(finished - inactive, 0, 1),
-                                scores_accumulated,
-                                inf_array)
+        pad_id_scores = F.where(F.broadcast_logical_and(finished, F.logical_not(inactive)), scores_accumulated, inf_array)
         # pad_dist. Shape: (batch*beam, vocab_size)
         pad_dist = F.concat(pad_id_scores, pad_dist)
-        scores = F.where(finished + inactive, pad_dist, scores)
+        scores = F.where(F.broadcast_logical_or(finished, inactive), pad_dist, scores)
         return scores
diff --git a/sockeye/utils.py b/sockeye/utils.py
index 1a0b31311..e6d2a2ae3 100644
--- a/sockeye/utils.py
+++ b/sockeye/utils.py
@@ -33,8 +33,8 @@
 import mxnet as mx
 import numpy as np
 
-from sockeye import __version__, constants as C
-from sockeye.log import log_sockeye_version, log_mxnet_version
+from . import __version__, constants as C
+from .log import log_sockeye_version, log_mxnet_version
 
 logger = logging.getLogger(__name__)
 
@@ -276,7 +276,6 @@ def top1(scores: mx.nd.NDArray,
 
 def topk(scores: mx.nd.NDArray,
          k: int,
-         batch_size: int,
          offset: mx.nd.NDArray,
         use_mxnet_topk: bool) -> Tuple[mx.nd.NDArray, mx.nd.NDArray, mx.nd.NDArray]:
@@ -284,21 +283,19 @@ def topk(scores: mx.nd.NDArray,
 
     :param scores: Vocabulary scores for the next beam step. (batch_size * beam_size, target_vocabulary_size)
     :param k: The number of smallest scores to return.
-    :param batch_size: Number of sentences being decoded at once.
     :param offset: Array to add to the hypothesis indices for offsetting in batch decoding.
     :param use_mxnet_topk: True to use the mxnet implementation or False to use the numpy one.
     :return: The row indices, column indices and values of the k smallest items in the matrix.
""" # (batch_size, beam_size * target_vocab_size) - folded_scores = scores.reshape((batch_size, k * scores.shape[-1])) + folded_scores = scores.reshape((-1, k * scores.shape[-1])) + batch_size = folded_scores.shape[0] if use_mxnet_topk: # pylint: disable=unbalanced-tuple-unpacking values, indices = mx.nd.topk(folded_scores, axis=1, k=k, ret_typ='both', is_ascend=True) - best_hyp_indices, best_word_indices = mx.nd.array(np.unravel_index(indices.astype(np.int32).asnumpy().ravel(), - scores.shape), - dtype='int32', - ctx=scores.context) + indices = mx.nd.cast(indices, 'int32').reshape((-1,)) + best_hyp_indices, best_word_indices = mx.nd.unravel_index(indices, scores.shape) else: folded_scores = folded_scores.asnumpy() @@ -455,14 +452,12 @@ def average_arrays(arrays: List[mx.nd.NDArray]) -> mx.nd.NDArray: :param arrays: A list of NDArrays with the same shape that will be averaged. :return: The average of the NDArrays in the same context as arrays[0]. """ + if not arrays: + raise ValueError("arrays is empty.") if len(arrays) == 1: return arrays[0] check_condition(all(arrays[0].shape == a.shape for a in arrays), "nd array shapes do not match") - new_array = mx.nd.zeros(arrays[0].shape, dtype=arrays[0].dtype, ctx=arrays[0].context) - for a in arrays: - new_array += a.as_in_context(new_array.context) - new_array /= len(arrays) - return new_array + return mx.nd.add_n(*arrays) / len(arrays) def get_num_gpus() -> int: @@ -471,14 +466,7 @@ def get_num_gpus() -> int: :return: The number of GPUs on the system. """ - # TODO (domhant): Switch to mx.context.num_gpus() with mxnet version 1.3 - for device_id in itertools.count(): - try: - mx.nd.zeros((1,), ctx=mx.gpu(device_id)) - except mx.MXNetError: - return device_id - # Note: Return statement to make mypy happy, the for loop is infinite, so an exception is the only way out. 
-    return device_id + 1
+    return mx.context.num_gpus()
 
 
 def get_gpu_memory_usage(ctx: List[mx.context.Context]) -> Dict[int, Tuple[int, int]]:
@@ -852,8 +840,8 @@ def infer_type(self, in_type):
 
     def create_operator(self, ctx, shapes, dtypes):
         return PrintValue(self.print_name,
-                          print_grad=self.print_grad,
-                          use_logger=self.use_logger)
+                          print_grad=str(self.print_grad),
+                          use_logger=str(self.use_logger))
 
 
 def grouper(iterable: Iterable, size: int) -> Iterable:
diff --git a/test/unit/test_inference.py b/test/unit/test_inference.py
index d3612c471..ad76bd8f5 100644
--- a/test/unit/test_inference.py
+++ b/test/unit/test_inference.py
@@ -387,17 +387,34 @@ def test_topk_func(batch_size, beam_size, target_vocab_size):
     # offset for batch sizes > 1
     offset = mx.nd.array(np.repeat(np.arange(0, batch_size * beam_size, beam_size), beam_size), dtype='int32')
 
-    np_hyp, np_word, np_values = sockeye.utils.topk(scores, k=beam_size, batch_size=batch_size,
+    np_hyp, np_word, np_values = sockeye.utils.topk(scores, k=beam_size,
                                                     offset=offset, use_mxnet_topk=False)
     np_hyp, np_word, np_values = np_hyp.asnumpy(), np_word.asnumpy(), np_values.asnumpy()
 
-    mx_hyp, mx_word, mx_values = sockeye.utils.topk(scores, k=beam_size, batch_size=batch_size,
+    mx_hyp, mx_word, mx_values = sockeye.utils.topk(scores, k=beam_size,
                                                     offset=offset, use_mxnet_topk=True)
     mx_hyp, mx_word, mx_values = mx_hyp.asnumpy(), mx_word.asnumpy(), mx_values.asnumpy()
     assert all(mx_hyp == np_hyp)
     assert all(mx_word == np_word)
     assert all(mx_values == np_values)
 
+    topk = sockeye.inference.TopK(k=beam_size, batch_size=batch_size, vocab_size=target_vocab_size)
+    topk.initialize()
+    assert all(topk.offset.data() == offset)
+
+    mx_hyp, mx_word, mx_values = topk(scores)
+    mx_hyp, mx_word, mx_values = mx_hyp.asnumpy(), mx_word.asnumpy(), mx_values.asnumpy()
+    assert all(mx_hyp == np_hyp)
+    assert all(mx_word == np_word)
+    assert all(mx_values == np_values)
+
+    topk.hybridize()
+    mx_hyp, mx_word, mx_values = topk(scores)
+    mx_hyp, mx_word, mx_values = mx_hyp.asnumpy(), mx_word.asnumpy(), mx_values.asnumpy()
+    assert all(mx_hyp == np_hyp)
+    assert all(mx_word == np_word)
+    assert all(mx_values == np_values)
+
 
 def test_get_best_word_indeces_for_kth_hypotheses():
     # data
@@ -426,6 +443,7 @@ def test_get_best_word_indeces_for_kth_hypotheses():
     assert result.shape == expected_indices.shape
     assert (result == expected_indices).all()
 
+
 @pytest.mark.parametrize("raw_constraints, beam_histories, expected_best_ids, expected_best_indices",
                          [([[], [], [], []], [None, None], np.array([0, 2], dtype='int32'),
                            np.array([[1, 1, 1], [3, 3, 3]], dtype='int32')),
                           ([[[1]], [], [[3]], []], [None, None], np.array([1, 3], dtype='int32'),
                            np.array([[1, 0, 0], [3, 2, 2]], dtype='int32'))
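
For reference, the batched topk pattern that the `TopK` block and `utils.topk` implement above can be sketched in isolation. The following is a minimal, illustrative example, not part of the patch: the sizes and the `pad_id`/`eos_id` values are made up, and it assumes `mxnet>=1.3.0`, the first release providing `unravel_index` and the `broadcast_logical_*` operators. Lower scores are better, as in Sockeye's beam search.

```python
import mxnet as mx

batch_size, beam_size, vocab_size = 2, 3, 5
pad_id, eos_id = 0, 3  # illustrative vocabulary ids

# (batch_size * beam_size, vocab_size) scores; lower is better during beam search.
scores = mx.nd.random.uniform(shape=(batch_size * beam_size, vocab_size))

# Fold each sentence's beam into a single row and take the k smallest entries.
folded = scores.reshape((batch_size, beam_size * vocab_size))
values, indices = mx.nd.topk(folded, axis=1, k=beam_size, ret_typ='both', is_ascend=True)

# unravel_index (new in MXNet 1.3) maps the flat indices back to
# (hypothesis row, word column) pairs without a numpy round-trip.
flat_indices = mx.nd.cast(indices, 'int32').reshape((-1,))
best_hyp_indices, best_word_indices = mx.nd.unravel_index(flat_indices,
                                                          shape=(batch_size * beam_size, vocab_size))

# Offset the hypothesis rows so every sentence points into its own beam slice.
offset = mx.nd.repeat(mx.nd.arange(0, batch_size * beam_size, beam_size, dtype='int32'), beam_size)
best_hyp_indices = best_hyp_indices + offset

# The logical operators (also new in 1.3) replace the old add-and-clip idiom for masks.
finished = mx.nd.broadcast_logical_or(best_word_indices == pad_id, best_word_indices == eos_id)
```

Because every shape in this pattern is a function of fixed `batch_size`, `beam_size`, and `vocab_size`, it can be wrapped in a HybridBlock and hybridized with `static_alloc=True, static_shape=True`; the lexicon-restricted path keeps the imperative `utils.topk` precisely because vocabulary selection changes the vocabulary size per request.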