Option to not reload the best training checkpoint when reducing the learning rate (#1045)
mjdenkowski authored Apr 28, 2022
1 parent 23ffd29 commit 63286ff
Showing 6 changed files with 18 additions and 3 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,12 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.1.13]

### Added

- Added `sockeye-train` argument `--no-reload-on-learning-rate-reduce` that disables reloading the best training checkpoint when reducing the learning rate. This currently only applies to the `plateau-reduce` learning rate scheduler since other schedulers do not reload checkpoints.

## [3.1.12]

### Fixed
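The new `--no-reload-on-learning-rate-reduce` switch is a standard `store_true` flag. As a point of reference (this is not part of the commit), the standalone sketch below mirrors the argument definition added to `sockeye/arguments.py` in this commit and shows the value the training code sees; it does not import Sockeye's actual parser.

```python
import argparse

# Standalone mirror of the new sockeye-train switch; illustrative only,
# not Sockeye's actual argument parser.
parser = argparse.ArgumentParser()
parser.add_argument('--no-reload-on-learning-rate-reduce',
                    action='store_true',
                    default=False,
                    help='Do not reload the best training checkpoint when reducing the learning rate.')

# Default behavior: the flag is off, so the best checkpoint is still reloaded
# when the plateau-reduce scheduler lowers the learning rate.
print(parser.parse_args([]).no_reload_on_learning_rate_reduce)  # False

# With the flag set, training simply continues from the current parameters.
print(parser.parse_args(['--no-reload-on-learning-rate-reduce']).no_reload_on_learning_rate_reduce)  # True
```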
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

-__version__ = '3.1.12'
+__version__ = '3.1.13'
6 changes: 6 additions & 0 deletions sockeye/arguments.py
@@ -958,6 +958,12 @@ def add_training_args(params):
default=0,
help="Number of warmup steps. If set to x, linearly increases learning rate from 10%% "
"to 100%% of the initial learning rate. Default: %(default)s.")
train_params.add_argument('--no-reload-on-learning-rate-reduce',
action='store_true',
default=False,
help='Do not reload the best training checkpoint when reducing the learning rate. '
'Default: %(default)s.')


train_params.add_argument('--fixed-param-strategy',
default=None,
3 changes: 2 additions & 1 deletion sockeye/train.py
@@ -967,7 +967,8 @@ def train(args: argparse.Namespace, custom_metrics_logger: Optional[Callable] =
max_epochs=args.max_num_epochs,
max_seconds=args.max_seconds,
update_interval=args.update_interval,
-stop_training_on_decoder_failure=args.stop_training_on_decoder_failure)
+stop_training_on_decoder_failure=args.stop_training_on_decoder_failure,
+no_reload_on_learning_rate_reduce=args.no_reload_on_learning_rate_reduce)
if trainer_config.min_epochs is not None and trainer_config.max_epochs is not None:
check_condition(trainer_config.min_epochs <= trainer_config.max_epochs,
"Minimum number of epochs must be smaller than maximum number of epochs")
3 changes: 2 additions & 1 deletion sockeye/training.py
@@ -73,6 +73,7 @@ class TrainerConfig(Config):
max_seconds: Optional[int] = None
update_interval: int = 1
stop_training_on_decoder_failure: bool = False
no_reload_on_learning_rate_reduce: bool = False


class TrainState:
@@ -549,7 +550,7 @@ def _adjust_learning_rate(self, has_improved: bool):
lr_adjusted = scheduler.new_evaluation_result(has_improved) # type: ignore
else:
lr_adjusted = False
-if lr_adjusted and not has_improved:
+if lr_adjusted and not has_improved and not self.config.no_reload_on_learning_rate_reduce:
logger.info("Loading model parameters and optimizer states from best checkpoint: %d",
self.state.best_checkpoint)
if os.path.exists(self.best_params_fname):
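To make the control flow concrete, here is a minimal sketch (not Sockeye's actual implementation) of the decision in `_adjust_learning_rate`: a plateau-reduce style scheduler reports when it has lowered the learning rate after non-improving checkpoints, and the trainer reloads the best parameters and optimizer state unless `no_reload_on_learning_rate_reduce` is set. The `PlateauScheduler` class and `reload_best_checkpoint` callback below are simplified placeholders.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class TrainerConfig:
    # Mirrors the field added to Sockeye's TrainerConfig above.
    no_reload_on_learning_rate_reduce: bool = False


class PlateauScheduler:
    """Toy stand-in for a plateau-reduce scheduler: lowers the learning rate
    after `patience` consecutive non-improving checkpoints."""

    def __init__(self, lr: float, reduce_factor: float = 0.9, patience: int = 2) -> None:
        self.lr = lr
        self.reduce_factor = reduce_factor
        self.patience = patience
        self.num_not_improved = 0

    def new_evaluation_result(self, has_improved: bool) -> bool:
        """Returns True if the learning rate was just reduced."""
        if has_improved:
            self.num_not_improved = 0
            return False
        self.num_not_improved += 1
        if self.num_not_improved >= self.patience:
            self.lr *= self.reduce_factor
            self.num_not_improved = 0
            return True
        return False


def adjust_learning_rate(config: TrainerConfig, scheduler: PlateauScheduler,
                         has_improved: bool,
                         reload_best_checkpoint: Callable[[], None]) -> None:
    lr_adjusted = scheduler.new_evaluation_result(has_improved)
    # The reload of the best checkpoint now happens only when it has not been
    # explicitly disabled via --no-reload-on-learning-rate-reduce.
    if lr_adjusted and not has_improved and not config.no_reload_on_learning_rate_reduce:
        reload_best_checkpoint()


# Two non-improving checkpoints trigger an LR reduction; with the flag set,
# no reload happens and training continues from the current parameters.
config = TrainerConfig(no_reload_on_learning_rate_reduce=True)
scheduler = PlateauScheduler(lr=0.0002)
for improved in [True, False, False]:
    adjust_learning_rate(config, scheduler, improved,
                         reload_best_checkpoint=lambda: print("reloading best checkpoint"))
print(scheduler.lr)  # reduced learning rate (~0.00018); no reload was triggered
```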
1 change: 1 addition & 0 deletions test/unit/test_arguments.py
@@ -207,6 +207,7 @@ def test_inference_args(test_params, expected_params):
learning_rate_reduce_factor=0.9,
learning_rate_reduce_num_not_improved=8,
learning_rate_warmup=0,
no_reload_on_learning_rate_reduce=False,
fixed_param_names=[],
fixed_param_strategy=None,
decode_and_evaluate=500,
