Don't take final ckpt if no more training was done in FIT (pytorch#846)
Summary: Pull Request resolved: pytorch#846

Reviewed By: JKSenthil

Differential Revision: D58397317

fbshipit-source-id: 31b8f7382059f04cd35f26eafbd957bd53e7f3f0
diego-urgell authored and facebook-github-bot committed Jun 14, 2024
1 parent f9f566b commit 0f72333
Showing 2 changed files with 62 additions and 7 deletions.
38 changes: 38 additions & 0 deletions tests/framework/callbacks/test_base_checkpointer.py
@@ -500,6 +500,44 @@ def test_save_on_train_end(self) -> None:
             ],
         )
 
+    def test_save_on_train_end_on_fit(self) -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+        max_epochs = 6
+
+        for save_every_n_eval_epochs, expected_last_ckpt in [
+            (None, "epoch_6_train_step_30_eval_step_25"),
+            (2, "epoch_6_train_step_30_eval_step_30"),
+        ]:
+            my_unit = DummyAutoUnit(module=nn.Linear(input_dim, 2))
+            train_dataloader = generate_random_dataloader(
+                dataset_len, input_dim, batch_size
+            )
+            eval_dataloader = generate_random_dataloader(
+                dataset_len, input_dim, batch_size
+            )
+            with tempfile.TemporaryDirectory() as temp_dir:
+                checkpoint_cb = BaseCheckpointSaver(
+                    temp_dir,
+                    save_every_n_epochs=2,
+                    save_every_n_eval_epochs=save_every_n_eval_epochs,
+                )
+                fit(
+                    my_unit,
+                    train_dataloader=train_dataloader,
+                    eval_dataloader=eval_dataloader,
+                    max_epochs=max_epochs,
+                    evaluate_every_n_epochs=1,
+                    callbacks=[checkpoint_cb],
+                )
+                expected_path = os.path.join(temp_dir, expected_last_ckpt)
+                self.assertTrue(os.path.exists(expected_path))
+                self.assertEqual(
+                    checkpoint_cb._checkpoint_manager._ckpt_paths[-1].path,
+                    expected_path,
+                )
+
     @skip_if_not_distributed
     def test_directory_sync_collective(self) -> None:
         spawn_multi_process(
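The expected checkpoint names in the new test follow from its configuration: dataset_len=10 with batch_size=2 gives 5 steps per epoch, so 6 train epochs complete 30 train steps and, evaluating every epoch, 30 eval steps overall. With only save_every_n_epochs=2 set, the last checkpoint is taken at the end of train epoch 6, before the final eval epoch runs, hence epoch_6_train_step_30_eval_step_25; previously on_train_end would then have written another checkpoint at the same 30 train steps, which this commit now skips. With save_every_n_eval_epochs=2 also set, a checkpoint after the final eval epoch is expected, hence epoch_6_train_step_30_eval_step_30.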
31 changes: 24 additions & 7 deletions torchtnt/framework/callbacks/base_checkpointer.py
@@ -16,7 +16,7 @@
 from torchtnt.framework.callback import Callback
 from torchtnt.framework.callbacks._checkpoint_utils import _get_step_phase_mapping
 from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
-from torchtnt.framework.state import State
+from torchtnt.framework.state import EntryPoint, State
 from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit
 from torchtnt.utils.checkpoint import (
     BestCheckpointConfig,
@@ -25,6 +25,7 @@
     get_best_checkpoint_path,
     get_latest_checkpoint_path,
     MetricData,
+    Phase,
 )
 from torchtnt.utils.distributed import PGWrapper
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
@@ -177,12 +178,28 @@ def _generate_checkpoint_and_upkeep(
         if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
             return False
 
-        # 2.1) Make sure that last checkpoint does not already exist
-        if hook == "on_train_end" and self._does_checkpoint_exist(
-            checkpoint_path, self._process_group
-        ):
-            rank_zero_warn("Final checkpoint already exists, skipping.", logger=logger)
-            return False
+        if hook == "on_train_end":
+            # 2.1) Make sure that last checkpoint does not already exist
+            if self._does_checkpoint_exist(checkpoint_path, self._process_group):
+                rank_zero_warn(
+                    "Final checkpoint already exists, skipping.", logger=logger
+                )
+                return False
+
+            # 2.2) If doing fit without eval checkpointing, only consider training progress when
+            # checking if last checkpoint exists.
+            if (
+                state.entry_point == EntryPoint.FIT
+                and self._save_every_n_eval_epochs is None
+                and self._checkpoint_manager._ckpt_paths
+                and self._checkpoint_manager._ckpt_paths[-1].step[Phase.TRAIN]
+                == cast(TTrainUnit, unit).train_progress.num_steps_completed
+            ):
+                rank_zero_info(
+                    "Omitting final checkpoint since train progress is unchanged, and eval checkpointing is not configured.",
+                    logger=logger,
+                )
+                return False
 
         # 3) try to save checkpoint
         if not self._checkpoint_impl(
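Restated outside the diff, the new guard in _generate_checkpoint_and_upkeep skips the on_train_end checkpoint during a FIT entry point when eval-epoch checkpointing is not configured and the most recent checkpoint already covers the unit's completed train steps. Below is a minimal standalone sketch of that decision; should_take_final_checkpoint and its parameters are hypothetical illustrations, not torchtnt APIs.

from typing import Optional


def should_take_final_checkpoint(
    is_fit: bool,
    save_every_n_eval_epochs: Optional[int],
    latest_ckpt_train_step: Optional[int],
    train_steps_completed: int,
) -> bool:
    # Skip the final (on_train_end) checkpoint when no training happened
    # since the last saved checkpoint and eval checkpointing is off.
    if (
        is_fit
        and save_every_n_eval_epochs is None
        and latest_ckpt_train_step is not None
        and latest_ckpt_train_step == train_steps_completed
    ):
        return False
    return True


# The last epoch checkpoint already covers train step 30, so the redundant
# final checkpoint is skipped.
assert should_take_final_checkpoint(True, None, 30, 30) is False
# With eval-epoch checkpointing enabled, the final checkpoint is still taken,
# since it can capture new eval progress.
assert should_take_final_checkpoint(True, 2, 30, 30) is True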
