Skip to content

Commit

Permalink
Clarify use of Translator.batch_size in code (#1033)
Browse files Browse the repository at this point in the history
  • Loading branch information
fhieber authored Apr 5, 2022
1 parent 1ff2a1c commit 71914bb
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.1.9]

### Changed

- Clarified usage of `batch_size` in Translator code.

## [3.1.8]

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.1.8'
__version__ = '3.1.9'
2 changes: 1 addition & 1 deletion sockeye/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,7 @@ def _get_best_translations(self, result: SearchResult) -> List[Translation]:
batch_size = best_hyp_indices.shape[0] // self.beam_size
nbest_translations = [] # type: List[List[Translation]]
reference_lengths = estimated_reference_lengths \
if estimated_reference_lengths is not None else np.zeros((self.batch_size * self.beam_size, 1))
if estimated_reference_lengths is not None else np.zeros((batch_size * self.beam_size, 1))
for n in range(0, self.nbest_size):

# Initialize the best_ids to the first item in each batch, plus current nbest index
Expand Down
5 changes: 2 additions & 3 deletions sockeye/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ def read_and_translate(translator: inference.Translator,
:param input_factors: Optional list of paths to files that contain source factors.
:param input_is_json: Whether the input is in json format.
"""
batch_size = translator.max_batch_size
if chunk_size is None:
if translator.max_batch_size == 1:
# No batching, therefore there is no need to read segments in chunks.
Expand All @@ -222,8 +221,8 @@ def read_and_translate(translator: inference.Translator,
else:
if chunk_size < translator.max_batch_size:
logger.warning("You specified a chunk size (%d) smaller than the max batch size (%d). This will lead to "
"a reduction in translation speed. Consider choosing a larger chunk size." % (chunk_size,
batch_size))
"a reduction in translation speed. Consider choosing a larger chunk size.",
chunk_size, translator.max_batch_size)

logger.info("Translating...")

Expand Down

0 comments on commit 71914bb

Please sign in to comment.