Generate predict/evaluate checkpoints in BaseCheckpointer (pytorch#914)
Summary: Pull Request resolved: pytorch#914

Reviewed By: JKSenthil

Differential Revision: D63013008
diego-urgell authored and facebook-github-bot committed Oct 9, 2024
1 parent 7bfdee4 commit d54fe58
Showing 4 changed files with 196 additions and 23 deletions.
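Before the per-file diffs, here is a minimal, self-contained sketch of what this commit enables, modeled on the new tests below. The `SimpleCheckpointSaver` subclass and the toy predict unit are illustrative stand-ins, not part of the change; only the `save_every_n_predict_steps` argument, the `callbacks=[...]` wiring, and the checkpoint id format come from the diff, and the stubbed `restore` exists only to satisfy the abstract interface in this sketch.

```python
# Illustrative sketch (not part of the commit): a trivial BaseCheckpointer subclass
# used to show the new predict-entrypoint checkpointing added in this change.
import tempfile
from typing import List

import torch
from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
from torchtnt.framework.predict import predict
from torchtnt.framework.state import State
from torchtnt.framework.unit import AppStateMixin, PredictUnit


class SimpleCheckpointSaver(BaseCheckpointer):
    """Records checkpoint ids instead of writing real snapshots."""

    def __init__(self, dirpath: str, **kwargs) -> None:
        super().__init__(dirpath, **kwargs)
        self.saved_ids: List[str] = []

    def _checkpoint_impl(
        self, state: State, unit: AppStateMixin, checkpoint_id: str, hook: str
    ) -> bool:
        # For predict runs, hook is "on_predict_step_end"; for eval runs, "on_eval_step_end".
        self.saved_ids.append(checkpoint_id)
        return True

    @staticmethod
    def restore(path: str, unit: AppStateMixin, **kwargs) -> None:
        pass  # nothing to restore in this sketch


class ToyPredictUnit(PredictUnit[torch.Tensor]):
    def __init__(self) -> None:
        super().__init__()
        self.module = torch.nn.Linear(2, 2)

    def predict_step(self, state: State, data: torch.Tensor) -> torch.Tensor:
        return self.module(data)


dataloader = [torch.randn(1, 2) for _ in range(10)]  # any iterable of batches works

with tempfile.TemporaryDirectory() as temp_dir:
    checkpointer = SimpleCheckpointSaver(temp_dir, save_every_n_predict_steps=2)
    predict(ToyPredictUnit(), dataloader, callbacks=[checkpointer])
    # Per the new tests below, ids follow "<dirpath>/epoch_0_predict_step_<n>":
    # [".../epoch_0_predict_step_2", ".../epoch_0_predict_step_4", ..., ".../epoch_0_predict_step_10"]
    print(checkpointer.saved_ids)
```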
84 changes: 84 additions & 0 deletions tests/framework/callbacks/test_base_checkpointer.py
@@ -25,6 +25,7 @@
Batch,
DummyAutoUnit,
DummyFitUnit,
DummyPredictUnit,
DummyTrainUnit,
generate_random_dataloader,
get_dummy_fit_state,
@@ -35,7 +36,9 @@
)
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
from torchtnt.framework.callbacks.lambda_callback import Lambda
from torchtnt.framework.evaluate import evaluate
from torchtnt.framework.fit import fit
from torchtnt.framework.predict import predict
from torchtnt.framework.state import ActivePhase, State

from torchtnt.framework.train import train
@@ -57,7 +60,9 @@ def __init__(
*,
save_every_n_train_steps: Optional[int] = None,
save_every_n_epochs: Optional[int] = None,
save_every_n_eval_steps: Optional[int] = None,
save_every_n_eval_epochs: Optional[int] = None,
save_every_n_predict_steps: Optional[int] = None,
keep_last_n_checkpoints: Optional[int] = None,
best_checkpoint_config: Optional[BestCheckpointConfig] = None,
process_group: Optional[dist.ProcessGroup] = None,
@@ -66,7 +71,9 @@
dirpath,
save_every_n_train_steps=save_every_n_train_steps,
save_every_n_epochs=save_every_n_epochs,
save_every_n_eval_steps=save_every_n_eval_steps,
save_every_n_eval_epochs=save_every_n_eval_epochs,
save_every_n_predict_steps=save_every_n_predict_steps,
keep_last_n_checkpoints=keep_last_n_checkpoints,
best_checkpoint_config=best_checkpoint_config,
process_group=process_group,
@@ -243,6 +250,83 @@ def test_save_fit_entrypoint(self) -> None:
checkpointer._latest_checkpoint_path,
)

@patch.object(BaseCheckpointSaver, "_checkpoint_impl")
def test_save_eval_entrypoint(self, mock_checkpoint_impl: MagicMock) -> None:
my_unit = DummyFitUnit(input_dim=2)
with tempfile.TemporaryDirectory() as temp_dir:
checkpointer = BaseCheckpointSaver(
temp_dir,
save_every_n_eval_steps=2,
best_checkpoint_config=BestCheckpointConfig(
monitored_metric="val_loss", mode="min"
),
keep_last_n_checkpoints=1,
)

ckpt_container: List[str] = []

def _checkpoint_impl_side_effect(
state: State, unit: AppStateMixin, checkpoint_id: str, hook: str
) -> bool:
ckpt_container.append(checkpoint_id)
return True

mock_checkpoint_impl.side_effect = _checkpoint_impl_side_effect

eval_dataloader = generate_random_dataloader(10, 2, 1)

warning_container: List[str] = []
with patch(
"torchtnt.framework.callbacks.base_checkpointer.logging.Logger.warning",
side_effect=warning_container.append,
):
evaluate(my_unit, eval_dataloader, callbacks=[checkpointer])

# Verify that checkpoint optimality tracking was disabled
self.assertIn(
"Disabling best_checkpoint_config, since it is not supported for eval or predict entrypoints.",
warning_container,
)
self.assertIn(
"Disabling keep_last_n_checkpoints, since is not supported for eval or predict entrypoints.",
warning_container,
)

# Make sure that the correct checkpoints were saved, without tracked metrics
expected_ckpts = [
f"{temp_dir}/epoch_0_eval_step_{i*2}" for i in range(1, 6)
]
self.assertEqual(ckpt_container, expected_ckpts)

@patch.object(BaseCheckpointSaver, "_checkpoint_impl")
def test_save_predict_entrypoint(self, mock_checkpoint_impl: MagicMock) -> None:
my_unit = DummyPredictUnit(input_dim=2)
with tempfile.TemporaryDirectory() as temp_dir:
checkpointer = BaseCheckpointSaver(
temp_dir,
save_every_n_predict_steps=1,
)

ckpt_container: List[str] = []

def _checkpoint_impl_side_effect(
state: State, unit: AppStateMixin, checkpoint_id: str, hook: str
) -> bool:
ckpt_container.append(checkpoint_id)
return True

mock_checkpoint_impl.side_effect = _checkpoint_impl_side_effect

predict_dataloader = generate_random_dataloader(10, 2, 1)

predict(my_unit, predict_dataloader, callbacks=[checkpointer])

# Make sure that the correct checkpoints were saved
expected_ckpts = [
f"{temp_dir}/epoch_0_predict_step_{i}" for i in range(1, 11)
]
self.assertEqual(ckpt_container, expected_ckpts)

@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)
def test_restore_from_latest(self, mock_stdout: MagicMock) -> None:
my_unit = DummyTrainUnit(input_dim=2)
123 changes: 101 additions & 22 deletions torchtnt/framework/callbacks/base_checkpointer.py
@@ -15,10 +15,19 @@
import torch.distributed as dist
from pyre_extensions import none_throws
from torchtnt.framework.callback import Callback
from torchtnt.framework.callbacks._checkpoint_utils import _get_step_phase_mapping
from torchtnt.framework.callbacks._checkpoint_utils import (
_get_epoch,
_get_step_phase_mapping,
)
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit
from torchtnt.framework.state import ActivePhase, EntryPoint, State
from torchtnt.framework.unit import (
AppStateMixin,
TEvalUnit,
TPredictUnit,
TTrainData,
TTrainUnit,
)
from torchtnt.utils.checkpoint import (
BestCheckpointConfig,
CheckpointManager,
@@ -51,8 +60,11 @@ class BaseCheckpointer(Callback, metaclass=abc.ABCMeta):
save_every_n_train_steps: Frequency of steps with which to save checkpoints during the train epoch. If None, no intra-epoch checkpoints are generated.
save_every_n_epochs: Frequency of epochs with which to save checkpoints during training. If None, no end-of-epoch checkpoints are generated.
save_every_n_eval_epochs: Frequency of evaluation epochs with which to save checkpoints during training. Use this if wanting to save checkpoints after every eval epoch during fit.
keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead.
best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint.
save_every_n_eval_steps: Frequency of evaluation steps with which to save checkpoints. Use this to save checkpoints when using the evaluate entrypoint, or during the eval phases of fit.
save_every_n_predict_steps: Frequency of prediction steps with which to save checkpoints. Use this to save checkpoints when using the predict entrypoint.
keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted
to clean the difference. If best checkpoint config is enabled, this param will manage the top n checkpoints instead. Only supported for train or fit entrypoints.
best_checkpoint_config: Configuration for saving the best checkpoint based on a monitored metric. The metric is read off the attribute of the unit prior to checkpoint. This param is ignored if not in train or fit entrypoints.
process_group: The process group on which the ranks will communicate on. If the process group is not gloo-based, a new gloo-based process group will be created.
Note:
@@ -78,6 +90,8 @@ def __init__(
save_every_n_train_steps: Optional[int] = None,
save_every_n_epochs: Optional[int] = None,
save_every_n_eval_epochs: Optional[int] = None,
save_every_n_eval_steps: Optional[int] = None,
save_every_n_predict_steps: Optional[int] = None,
keep_last_n_checkpoints: Optional[int] = None,
best_checkpoint_config: Optional[BestCheckpointConfig] = None,
process_group: Optional[dist.ProcessGroup] = None,
@@ -90,12 +104,23 @@ def __init__(
raise ValueError(
f"Invalid value passed for save_every_n_epochs. Expected to receive either None or positive number, but received {save_every_n_epochs}"
)
if save_every_n_eval_steps is not None and save_every_n_eval_steps <= 0:
raise ValueError(
f"Invalid value passed for save_every_n_eval_steps. Expected to receive either None or positive number, but received {save_every_n_eval_steps}"
)
if save_every_n_eval_epochs is not None and save_every_n_eval_epochs <= 0:
raise ValueError(
f"Invalid value passed for save_every_n_eval_epochs. Expected to receive either None or positive number, but received {save_every_n_eval_epochs}"
)
if save_every_n_predict_steps is not None and save_every_n_predict_steps <= 0:
raise ValueError(
f"Invalid value passed for save_every_n_predict_steps. Expected to receive either None or positive number, but received {save_every_n_predict_steps}"
)
if keep_last_n_checkpoints is not None and keep_last_n_checkpoints <= 0:
raise ValueError(
f"Invalid value passed for keep_last_n_checkpoints. Expected to receive either None or positive number, but received {keep_last_n_checkpoints}"
)

self._best_checkpoint_config = best_checkpoint_config
if best_checkpoint_config and best_checkpoint_config.mode not in {"min", "max"}:
raise ValueError(
f"Invalid value passed for best_checkpoint_config.mode. Expected to receive 'min' or 'max', but received {best_checkpoint_config.mode}"
@@ -104,7 +129,10 @@ def __init__(
self._save_every_n_train_steps = save_every_n_train_steps
self._save_every_n_epochs = save_every_n_epochs
self._save_every_n_eval_epochs = save_every_n_eval_epochs
self._save_every_n_eval_steps = save_every_n_eval_steps
self._save_every_n_predict_steps = save_every_n_predict_steps
self._keep_last_n_checkpoints = keep_last_n_checkpoints
self._best_checkpoint_config = best_checkpoint_config

self._process_group: Optional[dist.ProcessGroup] = None
self._setup_gloo_pg(process_group)
@@ -147,7 +175,7 @@ def dirpath(self) -> str:
return self._checkpoint_manager.dirpath

def _generate_checkpoint_and_upkeep(
self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str
self, state: State, unit: Union[TTrainUnit, TEvalUnit, TPredictUnit], hook: str
) -> bool:
"""
Implementation for saving checkpoint while taking care of checkpoint
@@ -162,11 +190,16 @@ def _generate_checkpoint_and_upkeep(
True if checkpoint was successfully saved. False otherwise.
"""
# 1) generate checkpoint name
epoch = cast(TTrainUnit, unit).train_progress.num_epochs_completed
epoch = _get_epoch(state, unit)
step_mapping = _get_step_phase_mapping(state, unit)

# 1.1) append metric data only for train checkpoints, if best_checkpoint_config is defined
metric_data: Optional[MetricData] = None
if metric_value := self._get_tracked_metric_value(unit):
if (
self._best_checkpoint_config
and state.active_phase == ActivePhase.TRAIN
and (metric_value := self._get_tracked_metric_value(cast(TTrainUnit, unit)))
):
metric_data = MetricData(
name=none_throws(self._best_checkpoint_config).monitored_metric,
value=metric_value,
@@ -179,7 +212,8 @@
process_group=self._process_group,
)

# 2) Determine if we should save checkpoint
# 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints
# since neither best_checkpoint_config nor keep_last_n_checkpoints are supported.
if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
return False

@@ -222,9 +256,7 @@ def _generate_checkpoint_and_upkeep(

return True

def _get_tracked_metric_value(
self, unit: Union[TTrainUnit, TEvalUnit]
) -> Optional[float]:
def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]:
"""
If the checkpointer has a tracked metric, look the value in the unit using reflection, and cast to float.
Expand Down Expand Up @@ -271,33 +303,80 @@ def on_train_start(self, state: State, unit: TTrainUnit) -> None:

def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
num_steps_completed = unit.train_progress.num_steps_completed
save_every_n_train_steps = self._save_every_n_train_steps
if (
save_every_n_train_steps is None
or num_steps_completed % save_every_n_train_steps != 0
not self._save_every_n_train_steps
or num_steps_completed % self._save_every_n_train_steps != 0
):
return

self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_step_end")

def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
epoch = unit.train_progress.num_epochs_completed
save_every_n_epochs = self._save_every_n_epochs
if save_every_n_epochs is None or epoch % save_every_n_epochs != 0:
if not self._save_every_n_epochs or epoch % self._save_every_n_epochs != 0:
return

self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_epoch_end")

def on_train_end(self, state: State, unit: TTrainUnit) -> None:
self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_end")

def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
if state.entry_point == EntryPoint.EVALUATE:
self._disable_ckpt_optimality_tracking()

def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
num_steps_completed = unit.eval_progress.num_steps_completed
if (
not self._save_every_n_eval_steps
or num_steps_completed % self._save_every_n_eval_steps != 0
):
return

self._generate_checkpoint_and_upkeep(state, unit, hook="on_eval_step_end")

def on_eval_epoch_end(self, state: State, unit: TEvalUnit) -> None:
epoch = unit.eval_progress.num_epochs_completed
save_every_n_eval_epochs = self._save_every_n_eval_epochs
if save_every_n_eval_epochs is None or epoch % save_every_n_eval_epochs != 0:
if (
not self._save_every_n_eval_epochs
or epoch % self._save_every_n_eval_epochs != 0
):
return

self._generate_checkpoint_and_upkeep(state, unit, hook="on_eval_epoch_end")

def on_train_end(self, state: State, unit: TTrainUnit) -> None:
self._generate_checkpoint_and_upkeep(state, unit, hook="on_train_end")
def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
self._disable_ckpt_optimality_tracking()

def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
num_steps_completed = unit.predict_progress.num_steps_completed
if (
not self._save_every_n_predict_steps
or num_steps_completed % self._save_every_n_predict_steps != 0
):
return

self._generate_checkpoint_and_upkeep(state, unit, hook="on_predict_step_end")

def _disable_ckpt_optimality_tracking(self) -> None:
"""
Disables checkpoint optimality tracking. This means that best_checkpoint and keep_last_n_checkpoints
will not be used. This is useful for eval and predict entrypoints, since checkpoints do not include
model parameters.
"""
if self._best_checkpoint_config:
logger.warning(
"Disabling best_checkpoint_config, since it is not supported for eval or predict entrypoints."
)
self._best_checkpoint_config = None
self._checkpoint_manager._best_checkpoint_config = None

if self._keep_last_n_checkpoints:
logger.warning(
"Disabling keep_last_n_checkpoints, since is not supported for eval or predict entrypoints."
)
self._keep_last_n_checkpoints = None
self._checkpoint_manager._keep_last_n_checkpoints = None

@abc.abstractmethod
def _checkpoint_impl(
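The constructor arguments and the `_disable_ckpt_optimality_tracking` behavior shown in this file can be seen end to end in a short sketch. As before, the recording saver and the toy eval unit are illustrative stand-ins; only the argument names, the warning messages, and the clear-on-eval behavior come from the diff above.

```python
# Illustrative sketch: best_checkpoint_config and keep_last_n_checkpoints are accepted at
# construction time, but under the evaluate()/predict() entrypoints the callback warns and
# clears them before any checkpoint is taken (see _disable_ckpt_optimality_tracking above).
import tempfile
from typing import List

import torch
from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
from torchtnt.framework.evaluate import evaluate
from torchtnt.framework.state import State
from torchtnt.framework.unit import AppStateMixin, EvalUnit
from torchtnt.utils.checkpoint import BestCheckpointConfig


class RecordingCheckpointSaver(BaseCheckpointer):
    def __init__(self, dirpath: str, **kwargs) -> None:
        super().__init__(dirpath, **kwargs)
        self.saved_ids: List[str] = []

    def _checkpoint_impl(
        self, state: State, unit: AppStateMixin, checkpoint_id: str, hook: str
    ) -> bool:
        self.saved_ids.append(checkpoint_id)
        return True

    @staticmethod
    def restore(path: str, unit: AppStateMixin, **kwargs) -> None:
        pass  # nothing to restore in this sketch


class ToyEvalUnit(EvalUnit[torch.Tensor]):
    def __init__(self) -> None:
        super().__init__()
        self.module = torch.nn.Linear(2, 2)
        self.val_loss = 0.0  # only consulted by best-checkpoint logic during train/fit

    def eval_step(self, state: State, data: torch.Tensor) -> None:
        self.val_loss = torch.nn.functional.mse_loss(self.module(data), data).item()


with tempfile.TemporaryDirectory() as temp_dir:
    checkpointer = RecordingCheckpointSaver(
        temp_dir,
        save_every_n_eval_steps=2,
        # Both options below are unsupported outside train/fit: on_eval_start the callback
        # logs "Disabling best_checkpoint_config, ..." and "Disabling keep_last_n_checkpoints, ..."
        # and sets them to None.
        best_checkpoint_config=BestCheckpointConfig(monitored_metric="val_loss", mode="min"),
        keep_last_n_checkpoints=1,
    )
    evaluate(ToyEvalUnit(), [torch.randn(1, 2) for _ in range(10)], callbacks=[checkpointer])
    # ids carry no metric suffix: "<dirpath>/epoch_0_eval_step_2" ... "<dirpath>/epoch_0_eval_step_10"
    print(checkpointer.saved_ids)
```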
2 changes: 1 addition & 1 deletion torchtnt/framework/callbacks/dcp_saver.py
@@ -318,7 +318,7 @@ def restore_with_id(
)

def _generate_checkpoint_and_upkeep(
self, state: State, unit: Union[TTrainUnit, TEvalUnit], hook: str
self, state: State, unit: Union[TTrainUnit, TEvalUnit, TPredictUnit], hook: str
) -> bool:
# if we are still checkpointing, this might cause a collective hang, since several
# operations in the base class use the process group. So wait here instead.
10 changes: 10 additions & 0 deletions torchtnt/framework/unit.py
@@ -586,6 +586,16 @@ def on_predict_epoch_end(self, state: State) -> None:
"""
pass

def on_checkpoint_save(self, state: State, checkpoint_id: str) -> None:
"""Hook called after successfully saving a checkpoint.
Args:
state: a :class:`~torchtnt.framework.state.State` object containing metadata about the training run.
checkpoint_id: the ID of the checkpoint that was saved. Depending on the storage type, this may be
a path, a URL or a unique identifier.
"""
pass

def on_predict_end(self, state: State) -> None:
"""Hook called after prediction ends.
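To round out the unit.py change, a short sketch of a unit overriding the new on_checkpoint_save hook. Its placement next to the other predict hooks suggests it is available on PredictUnit, which this example assumes; the hook body (recording the id) is illustrative.

```python
# Sketch: a predict unit that records every checkpoint id it is notified about.
# Assumes on_checkpoint_save is available on PredictUnit, per its placement in the diff above.
from typing import List

import torch
from torchtnt.framework.state import State
from torchtnt.framework.unit import PredictUnit


class CheckpointAwarePredictUnit(PredictUnit[torch.Tensor]):
    def __init__(self) -> None:
        super().__init__()
        self.module = torch.nn.Linear(2, 2)
        self.seen_checkpoint_ids: List[str] = []

    def predict_step(self, state: State, data: torch.Tensor) -> torch.Tensor:
        return self.module(data)

    def on_checkpoint_save(self, state: State, checkpoint_id: str) -> None:
        # Called after a checkpoint is successfully saved; depending on the storage type,
        # checkpoint_id may be a path, a URL, or another unique identifier.
        self.seen_checkpoint_ids.append(checkpoint_id)
```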
