Skip to content

Commit

Permalink
Ignore false-positive missing keys for traced modules (#1042)
Browse files Browse the repository at this point in the history
  • Loading branch information
mjdenkowski authored Apr 12, 2022
1 parent b2d5e85 commit 30c3913
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 2 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.1.10]

### Fixed

- When loading parameters, SockeyeModel now ignores false positive missing parameters for traced modules. These modules use the same parameters as their original non-traced versions.

## [3.1.9]

### Changed
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.1.9'
__version__ = '3.1.10'
5 changes: 5 additions & 0 deletions sockeye/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,11 @@ def load_parameters(self,
# Earlier versions of Sockeye may have saved parameters for traced
# modules. These parameters can be safely ignored.
unexpected = [key for key in unexpected if 'traced' not in key]
# We also ignore cases where traced modules exist and appear to be
# missing parameters. These modules actually use the same parameters as
# their original non-traced versions so there are no separate parameters
# to load.
missing = [key for key in missing if 'traced' not in key]
if not allow_missing:
utils.check_condition(not missing, f"missing keys: {missing}")
if not ignore_extra:
Expand Down
7 changes: 6 additions & 1 deletion test/integration/test_seq_copy_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,10 @@ def _test_parameter_averaging(model_path: str):

def _test_checkpoint_decoder(dev_source_path: str, dev_target_path: str, model_path: str):
"""
Runs checkpoint decoder on 10% of the dev data and checks whether metric keys are present in the result dict.
Runs checkpoint decoder on 10% of the dev data and checks whether metric
keys are present in the result dict. Also checks that we can reload model
parameters after running the checkpoint decoder (case when using the
plateau-reduce scheduler).
"""
with open(dev_source_path) as dev_fd:
num_dev_sent = sum(1 for _ in dev_fd)
Expand All @@ -254,3 +257,5 @@ def _test_checkpoint_decoder(dev_source_path: str, dev_target_path: str, model_p
assert 'bleu' in cp_metrics
assert 'chrf' in cp_metrics
assert 'decode-walltime' in cp_metrics

model.load_parameters(os.path.join(model_path, C.PARAMS_BEST_NAME), device=pt.device('cpu'))

0 comments on commit 30c3913

Please sign in to comment.