From 05d1458be1a87a08662d4475d0a9c0d31592d620 Mon Sep 17 00:00:00 2001 From: Diego Urgell Date: Fri, 12 Apr 2024 10:36:39 -0700 Subject: [PATCH] Reduce `_generate_checkpoint_and_upkeep` code complexity (#783) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/783 The function _generate_checkpoint_and_upkeep is important for the checkpointing logic. But it has a linter warning indicating that it is too complex, and a there is a TODO to extract some logic into a separate function. Let's do a small refactor to improve readability and reduce function complexity, but avoid breaking changes or regressions. **Potential Future Changes** Note that while doing this change, I found two small bugs that we can fix. They are documented in this Bento notebook: https://fburl.com/anp/gqoezved I did not fix any of them here to avoid having a refactor + logic changes. Additionally, there is this user request that we can handle in this function. Again it was not modified here but we can decide what to do and change later: https://fb.workplace.com/groups/cu.training.framework.users/permalink/1147131179616886/ Reviewed By: galrotem Differential Revision: D55881050 fbshipit-source-id: 46d0777a2dc5208763628fecdac8c12a5573e407 --- .../callbacks/test_base_checkpointer.py | 45 ++++++ .../framework/callbacks/base_checkpointer.py | 131 ++++++++++-------- 2 files changed, 116 insertions(+), 60 deletions(-) diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py index 1b99e4b64e..fb8fea71d7 100644 --- a/tests/framework/callbacks/test_base_checkpointer.py +++ b/tests/framework/callbacks/test_base_checkpointer.py @@ -852,6 +852,51 @@ def test_no_assert_error_in_on_train_end(self) -> None: callbacks=[checkpoint_cb], ) + def test_get_tracked_metric_value(self) -> None: + """ + Tests that _get_tracked_metric_value returns the correct value + """ + val_loss_unit = MyValLossUnit() + + val_loss_ckpt_cb = BaseCheckpointSaver( + dirpath="checkpoint", + best_checkpoint_config=BestCheckpointConfig("val_loss", "min"), + ) + val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit) + self.assertEqual(0.01, val_loss) + + # pyre-ignore + val_loss_unit.val_loss = "0.01" # Test when returned as a string + val_loss_from_s = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit) + self.assertEqual(0.01, val_loss_from_s) + + # pyre-ignore + val_loss_unit.val_loss = "hola" # Test weird metric value + with self.assertRaisesRegex( + RuntimeError, + ( + "Unable to convert monitored metric val_loss to a float. Please ensure the value " + "can be converted to float and is not a multi-element tensor value." + ), + ): + val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit) + + train_loss_ckpt_cb = BaseCheckpointSaver( + dirpath="checkpoint", + best_checkpoint_config=BestCheckpointConfig("train_loss", "max"), + ) + with self.assertRaisesRegex( + RuntimeError, + "Unit does not have attribute train_loss, unable to retrieve metric to checkpoint.", + ): + val_loss = train_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit) + + ckpt_cb = BaseCheckpointSaver( + dirpath="checkpoint", + ) + no_metric = ckpt_cb._get_tracked_metric_value(val_loss_unit) + self.assertIsNone(no_metric) + class MyValLossUnit(TrainUnit[Batch]): def __init__(self) -> None: diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index be12a8142d..9993fbd532 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -14,7 +14,7 @@ from typing import Any, cast, Iterable, List, Literal, Optional, Union import torch.distributed as dist - +from pyre_extensions import none_throws from torchtnt.framework.callback import Callback from torchtnt.framework.callbacks._checkpoint_utils import ( _delete_checkpoint, @@ -197,85 +197,96 @@ def _generate_checkpoint_and_upkeep( Returns: True if checkpoint was successfully saved. False otherwise. """ - unit = cast(TTrainUnit, unit) - # 1) generate checkpoint name + unit = cast(TTrainUnit, unit) num_steps_completed = unit.train_progress.num_steps_completed if state.entry_point == EntryPoint.FIT: - num_steps_completed += cast( - TEvalUnit, unit - ).eval_progress.num_steps_completed + eval_unit = cast(TEvalUnit, unit) + num_steps_completed += eval_unit.eval_progress.num_steps_completed epoch = unit.train_progress.num_epochs_completed checkpoint_path = _get_save_path(self._dirpath, epoch, num_steps_completed) - # 1.5) Ensure the need to checkpoint again at the end of training + # 1.1) Make sure that last checkpoint does not already exist if hook == "on_train_end" and self._does_checkpoint_exist( checkpoint_path, process_group=self._process_group ): rank_zero_warn("Final checkpoint already exists, skipping.", logger=logger) return False - # 2) handle best checkpoint config on all hooks except `on_train_end` - # TODO: isolate this logic into its own function - metric_value_f: Optional[float] = None - best_checkpoint_config = self._best_checkpoint_config - if best_checkpoint_config: - if not hasattr(unit, best_checkpoint_config.monitored_metric): - raise RuntimeError( - f"Unit does not have attribute {best_checkpoint_config.monitored_metric}, unable to retrieve metric to checkpoint." - ) + # 1.2) If there is a tracked metric, add to the checkpoint path + metric_value = self._get_tracked_metric_value(unit) + if metric_value is not None: + metric_name = none_throws(self._best_checkpoint_config).monitored_metric + checkpoint_path += f"_{metric_name}={metric_value}" - metric_value = getattr(unit, best_checkpoint_config.monitored_metric) - if metric_value is not None: - try: - metric_value_f = float(metric_value) - except Exception as e: - raise RuntimeError( - f"Unable to convert monitored metric {best_checkpoint_config.monitored_metric} to a float. Please ensure the value can be converted to float and is not a multi-element tensor value." - ) from e - - # update checkpoint path to include the metric value info - checkpoint_path += ( - f"_{best_checkpoint_config.monitored_metric}={metric_value_f}" - ) - - should_checkpoint = self._should_save_checkpoint(metric_value_f) - if not should_checkpoint: + # 2) Determine if checkpoint should be saved + if not self._should_save_checkpoint(metric_value): return False # 3) try to save checkpoint - success = self._checkpoint_impl( - state, - unit, - checkpoint_path=checkpoint_path, - hook=hook, - ) + if not self._checkpoint_impl( + state, unit, checkpoint_path=checkpoint_path, hook=hook + ): + return False - if success: - # remove the checkpoint if applicable - # and update the tracked list of checkpoint paths + # 4) remove the oldest/worst checkpoint if applicable + if self._should_remove_checkpoint(): + self._remove_checkpoint(state) + + # 5) update the tracked list of checkpoint paths + if self._best_checkpoint_config and (metric_value is not None): + metric_mode = none_throws(self._best_checkpoint_config).mode + # insert the checkpoint path at the correct index to preserve ordering + keys = [ + float(os.path.basename(x).split("=")[-1]) for x in self._ckpt_dirpaths + ] + if metric_mode == "min": + keys.reverse() + # Use bisect.bisect() to find the insertion point + idx = bisect.bisect(keys, metric_value) + if metric_mode == "min": + idx = len(self._ckpt_dirpaths) - idx + self._ckpt_dirpaths.insert(idx, checkpoint_path) + + elif not self._best_checkpoint_config: # no metric to track + self._ckpt_dirpaths.append(checkpoint_path) - if self._should_remove_checkpoint(): - self._remove_checkpoint(state) + return True - if best_checkpoint_config: - if metric_value_f: - # insert the checkpoint path at the right index to preserve ordering - keys = [ - float(os.path.basename(x).split("=")[-1]) - for x in self._ckpt_dirpaths - ] - if best_checkpoint_config.mode == "min": - keys.reverse() - # Use bisect.bisect() to find the insertion point - idx = bisect.bisect(keys, metric_value_f) - if best_checkpoint_config.mode == "min": - idx = len(self._ckpt_dirpaths) - idx - self._ckpt_dirpaths.insert(idx, checkpoint_path) - else: - self._ckpt_dirpaths.append(checkpoint_path) + def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]: + """ + If the checkpointer has a tracked metric, look the value in the unit using reflection, and cast to float. + + Args: + unit: The training unit to look for the tracked metric in. + + Returns: + The value of the tracked metric, or None if there is no best_checkpoint config defined. + + Raises: + RuntimeError: If the unit does not have the attribute specified in the best_checkpoint config, + or if the value cannot be cast to a float. + """ + if not self._best_checkpoint_config: + return None + + monitored_metric_name = self._best_checkpoint_config.monitored_metric + if not hasattr(unit, monitored_metric_name): + raise RuntimeError( + f"Unit does not have attribute {monitored_metric_name}, unable to retrieve metric to checkpoint." + ) + + metric_value_f = None + if (metric_value := getattr(unit, monitored_metric_name)) is not None: + try: + metric_value_f = float(metric_value) + except ValueError as e: + raise RuntimeError( + f"Unable to convert monitored metric {monitored_metric_name} to a float. Please ensure the value " + "can be converted to float and is not a multi-element tensor value." + ) from e - return success + return metric_value_f def on_train_start(self, state: State, unit: TTrainUnit) -> None: # clean up the difference if surplus of checkpoints exist