Reduce _generate_checkpoint_and_upkeep code complexity (#783)
Summary:
Pull Request resolved: #783

The function _generate_checkpoint_and_upkeep is important for the checkpointing logic, but it has a linter warning indicating that it is too complex, and there is a TODO to extract some of its logic into a separate function.

Let's do a small refactor to improve readability and reduce function complexity, while avoiding breaking changes or regressions.
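
To make the intent concrete, the logic being extracted boils down to: read the monitored metric attribute off the unit, cast it to a float (raising a clear error if that fails), and use it to suffix the checkpoint path. Below is a simplified, standalone sketch of that flow using only the standard library; the names are illustrative, and the real helper is _get_tracked_metric_value in the diff.

```python
from typing import Any, Optional


def get_tracked_metric_value(unit: Any, monitored_metric: Optional[str]) -> Optional[float]:
    """Simplified stand-in for the extracted helper: fetch and validate the tracked metric."""
    if monitored_metric is None:
        # No best-checkpoint config, so there is nothing to track.
        return None
    if not hasattr(unit, monitored_metric):
        raise RuntimeError(f"Unit does not have attribute {monitored_metric}")
    value = getattr(unit, monitored_metric)
    if value is None:
        return None
    try:
        # Accepts floats, ints, numeric strings, and single-element tensors.
        return float(value)
    except (TypeError, ValueError) as e:
        raise RuntimeError(
            f"Unable to convert monitored metric {monitored_metric} to a float"
        ) from e


# The caller appends the metric to the checkpoint path when one is tracked.
class _Unit:
    val_loss = 0.01


metric = get_tracked_metric_value(_Unit(), "val_loss")
checkpoint_path = "checkpoint/epoch_1_step_100"
if metric is not None:
    checkpoint_path += f"_val_loss={metric}"
print(checkpoint_path)  # checkpoint/epoch_1_step_100_val_loss=0.01
```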

**Potential Future Changes**

Note that while making this change, I found two small bugs that we could fix. They are documented in a Bento notebook (https://fburl.com/anp/gqoezved). I did not fix them here, to avoid mixing the refactor with logic changes.

Additionally, there is a user request that we could handle in this function. It was not addressed here either, but we can decide what to do and make the change later: https://fb.workplace.com/groups/cu.training.framework.users/permalink/1147131179616886/

Reviewed By: galrotem

Differential Revision: D55881050

fbshipit-source-id: 46d0777a2dc5208763628fecdac8c12a5573e407
diego-urgell authored and facebook-github-bot committed Apr 12, 2024
1 parent fd78425 commit 05d1458
Showing 2 changed files with 116 additions and 60 deletions.
45 changes: 45 additions & 0 deletions tests/framework/callbacks/test_base_checkpointer.py
@@ -852,6 +852,51 @@ def test_no_assert_error_in_on_train_end(self) -> None:
callbacks=[checkpoint_cb],
)

def test_get_tracked_metric_value(self) -> None:
"""
Tests that _get_tracked_metric_value returns the correct value
"""
val_loss_unit = MyValLossUnit()

val_loss_ckpt_cb = BaseCheckpointSaver(
dirpath="checkpoint",
best_checkpoint_config=BestCheckpointConfig("val_loss", "min"),
)
val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)
self.assertEqual(0.01, val_loss)

# pyre-ignore
val_loss_unit.val_loss = "0.01" # Test when returned as a string
val_loss_from_s = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)
self.assertEqual(0.01, val_loss_from_s)

# pyre-ignore
val_loss_unit.val_loss = "hola" # Test weird metric value
with self.assertRaisesRegex(
RuntimeError,
(
"Unable to convert monitored metric val_loss to a float. Please ensure the value "
"can be converted to float and is not a multi-element tensor value."
),
):
val_loss = val_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)

train_loss_ckpt_cb = BaseCheckpointSaver(
dirpath="checkpoint",
best_checkpoint_config=BestCheckpointConfig("train_loss", "max"),
)
with self.assertRaisesRegex(
RuntimeError,
"Unit does not have attribute train_loss, unable to retrieve metric to checkpoint.",
):
val_loss = train_loss_ckpt_cb._get_tracked_metric_value(val_loss_unit)

ckpt_cb = BaseCheckpointSaver(
dirpath="checkpoint",
)
no_metric = ckpt_cb._get_tracked_metric_value(val_loss_unit)
self.assertIsNone(no_metric)


class MyValLossUnit(TrainUnit[Batch]):
def __init__(self) -> None:
131 changes: 71 additions & 60 deletions torchtnt/framework/callbacks/base_checkpointer.py
@@ -14,7 +14,7 @@
from typing import Any, cast, Iterable, List, Literal, Optional, Union

import torch.distributed as dist

+from pyre_extensions import none_throws
from torchtnt.framework.callback import Callback
from torchtnt.framework.callbacks._checkpoint_utils import (
_delete_checkpoint,
@@ -197,85 +197,96 @@ def _generate_checkpoint_and_upkeep(
Returns:
True if checkpoint was successfully saved. False otherwise.
"""
-        unit = cast(TTrainUnit, unit)

        # 1) generate checkpoint name
+        unit = cast(TTrainUnit, unit)
        num_steps_completed = unit.train_progress.num_steps_completed
        if state.entry_point == EntryPoint.FIT:
-            num_steps_completed += cast(
-                TEvalUnit, unit
-            ).eval_progress.num_steps_completed
+            eval_unit = cast(TEvalUnit, unit)
+            num_steps_completed += eval_unit.eval_progress.num_steps_completed
        epoch = unit.train_progress.num_epochs_completed
        checkpoint_path = _get_save_path(self._dirpath, epoch, num_steps_completed)

-        # 1.5) Ensure the need to checkpoint again at the end of training
+        # 1.1) Make sure that last checkpoint does not already exist
        if hook == "on_train_end" and self._does_checkpoint_exist(
            checkpoint_path, process_group=self._process_group
        ):
            rank_zero_warn("Final checkpoint already exists, skipping.", logger=logger)
            return False

-        # 2) handle best checkpoint config on all hooks except `on_train_end`
-        # TODO: isolate this logic into its own function
-        metric_value_f: Optional[float] = None
-        best_checkpoint_config = self._best_checkpoint_config
-        if best_checkpoint_config:
-            if not hasattr(unit, best_checkpoint_config.monitored_metric):
-                raise RuntimeError(
-                    f"Unit does not have attribute {best_checkpoint_config.monitored_metric}, unable to retrieve metric to checkpoint."
-                )
+        # 1.2) If there is a tracked metric, add to the checkpoint path
+        metric_value = self._get_tracked_metric_value(unit)
+        if metric_value is not None:
+            metric_name = none_throws(self._best_checkpoint_config).monitored_metric
+            checkpoint_path += f"_{metric_name}={metric_value}"

-            metric_value = getattr(unit, best_checkpoint_config.monitored_metric)
-            if metric_value is not None:
-                try:
-                    metric_value_f = float(metric_value)
-                except Exception as e:
-                    raise RuntimeError(
-                        f"Unable to convert monitored metric {best_checkpoint_config.monitored_metric} to a float. Please ensure the value can be converted to float and is not a multi-element tensor value."
-                    ) from e
-
-                # update checkpoint path to include the metric value info
-                checkpoint_path += (
-                    f"_{best_checkpoint_config.monitored_metric}={metric_value_f}"
-                )
-
-        should_checkpoint = self._should_save_checkpoint(metric_value_f)
-        if not should_checkpoint:
+        # 2) Determine if checkpoint should be saved
+        if not self._should_save_checkpoint(metric_value):
            return False

        # 3) try to save checkpoint
-        success = self._checkpoint_impl(
-            state,
-            unit,
-            checkpoint_path=checkpoint_path,
-            hook=hook,
-        )
+        if not self._checkpoint_impl(
+            state, unit, checkpoint_path=checkpoint_path, hook=hook
+        ):
+            return False

-        if success:
-            # remove the checkpoint if applicable
-            # and update the tracked list of checkpoint paths
+        # 4) remove the oldest/worst checkpoint if applicable
+        if self._should_remove_checkpoint():
+            self._remove_checkpoint(state)

+        # 5) update the tracked list of checkpoint paths
+        if self._best_checkpoint_config and (metric_value is not None):
+            metric_mode = none_throws(self._best_checkpoint_config).mode
+            # insert the checkpoint path at the correct index to preserve ordering
+            keys = [
+                float(os.path.basename(x).split("=")[-1]) for x in self._ckpt_dirpaths
+            ]
+            if metric_mode == "min":
+                keys.reverse()
+            # Use bisect.bisect() to find the insertion point
+            idx = bisect.bisect(keys, metric_value)
+            if metric_mode == "min":
+                idx = len(self._ckpt_dirpaths) - idx
+            self._ckpt_dirpaths.insert(idx, checkpoint_path)

+        elif not self._best_checkpoint_config:  # no metric to track
+            self._ckpt_dirpaths.append(checkpoint_path)

-            if self._should_remove_checkpoint():
-                self._remove_checkpoint(state)
+        return True

-            if best_checkpoint_config:
-                if metric_value_f:
-                    # insert the checkpoint path at the right index to preserve ordering
-                    keys = [
-                        float(os.path.basename(x).split("=")[-1])
-                        for x in self._ckpt_dirpaths
-                    ]
-                    if best_checkpoint_config.mode == "min":
-                        keys.reverse()
-                    # Use bisect.bisect() to find the insertion point
-                    idx = bisect.bisect(keys, metric_value_f)
-                    if best_checkpoint_config.mode == "min":
-                        idx = len(self._ckpt_dirpaths) - idx
-                    self._ckpt_dirpaths.insert(idx, checkpoint_path)
-                else:
-                    self._ckpt_dirpaths.append(checkpoint_path)
+    def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]:
+        """
+        If the checkpointer has a tracked metric, look up its value in the unit using reflection, and cast it to float.
+        Args:
+            unit: The training unit to look for the tracked metric in.
+        Returns:
+            The value of the tracked metric, or None if there is no best_checkpoint_config defined.
+        Raises:
+            RuntimeError: If the unit does not have the attribute specified in the best_checkpoint_config,
+                or if the value cannot be cast to a float.
+        """
+        if not self._best_checkpoint_config:
+            return None

+        monitored_metric_name = self._best_checkpoint_config.monitored_metric
+        if not hasattr(unit, monitored_metric_name):
+            raise RuntimeError(
+                f"Unit does not have attribute {monitored_metric_name}, unable to retrieve metric to checkpoint."
+            )

+        metric_value_f = None
+        if (metric_value := getattr(unit, monitored_metric_name)) is not None:
+            try:
+                metric_value_f = float(metric_value)
+            except ValueError as e:
+                raise RuntimeError(
+                    f"Unable to convert monitored metric {monitored_metric_name} to a float. Please ensure the value "
+                    "can be converted to float and is not a multi-element tensor value."
+                ) from e

-        return success
+        return metric_value_f

def on_train_start(self, state: State, unit: TTrainUnit) -> None:
# clean up the difference if surplus of checkpoints exist
