Skip to content

Commit

Permalink
Optional CheckpointSaver instantiation inside the Trainer (mosaicml#3334)
Browse files Browse the repository at this point in the history
  • Loading branch information
antoinebrl committed Jun 3, 2024
1 parent 63b6eda commit 1e1c04d
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 7 deletions.
23 changes: 22 additions & 1 deletion composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,7 +1397,28 @@ def __init__(
# Checkpoint Saving
self._checkpoint_saver = None
latest_remote_file_name = None
if save_folder is not None:

_checkpoint_savers = [cb for cb in self.state.callbacks if isinstance(cb, CheckpointSaver)]
if len(_checkpoint_savers) >= 1:
if len(_checkpoint_savers) > 1:
log.info('Multiple CheckpointSaver provided as callbacks. Using the first one as reference.')
self._checkpoint_saver = _checkpoint_savers[0]

if self._checkpoint_saver.folder != save_folder:
log.info(f'Using {self._checkpoint_saver.folder} as save_folder.')
save_folder = self._checkpoint_saver.folder

if self._checkpoint_saver.latest_filename is None:
save_latest_filename = None
log.info(f'Using {save_latest_filename} as latest_filename.')
elif self._checkpoint_saver.latest_filename.filename != save_latest_filename:
save_latest_filename = str(self._checkpoint_saver.latest_filename.filename)
log.info(f'Using {save_latest_filename} as latest_filename.')

if self._checkpoint_saver.latest_remote_file_name is not None:
latest_remote_file_name = str(self._checkpoint_saver.latest_remote_file_name.filename)

if self._checkpoint_saver is None and save_folder is not None:
if save_weights_only:
log.info(
'save_weights_only=True now also saves metadata and integrations! Please adjust your workflow accordingly.',
Expand Down
84 changes: 78 additions & 6 deletions tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from composer.utils import dist, is_tar, reproducibility
from composer.utils.checkpoint import (
_COMPOSER_STATES_FILENAME,
PartialFilePath,
_ensure_valid_checkpoint,
_write_checkpoint_file,
glob_filter,
Expand Down Expand Up @@ -394,9 +395,9 @@ def test_checkpoint_saver_properly_constructed(
):
mock_validate_credentials = MagicMock()
monkeypatch.setattr(remote_uploader_downloader, '_validate_credentials', mock_validate_credentials)
mock_checkpoint_saver = MagicMock()
monkeypatch.setattr(trainer, 'CheckpointSaver', mock_checkpoint_saver)
self.get_trainer(save_folder=save_folder)

trainer = self.get_trainer(save_folder=save_folder)

expected_prefix = expected_path + '/' if expected_path != '' else expected_path
rest_of_checkpoint_saver_kwargs = {
'filename': 'ep{epoch}-ba{batch}-rank{rank}.pt',
Expand All @@ -409,8 +410,14 @@ def test_checkpoint_saver_properly_constructed(
'num_checkpoints_to_keep': -1,
'ignore_keys': None,
}
expected_folder = expected_path.rstrip('/') if expected_path != '' else '.'
mock_checkpoint_saver.assert_called_once_with(folder=expected_folder, **rest_of_checkpoint_saver_kwargs)
for attr_name, value in rest_of_checkpoint_saver_kwargs.items():
attr = getattr(trainer._checkpoint_saver, attr_name)
if attr_name == 'save_interval':
assert attr.__closure__[-1].cell_contents == Time.from_timestring(value)
elif isinstance(attr, PartialFilePath):
assert attr.filename == value
else:
assert attr == value

@pytest.mark.parametrize('save_interval', ['1tok', '64tok', '65tok'])
@pytest.mark.parametrize('batch_size', [1, 4])
Expand Down Expand Up @@ -616,6 +623,29 @@ def test_checkpoint_intervals(
# we should have one extra call from the fit end checkpoint
assert trainer._checkpoint_saver._save_checkpoint.call_count == expected_save_calls

@pytest.mark.parametrize('save_folder', [None, 'local_checkpoints'])
@pytest.mark.parametrize('save_latest_filename', [None, 'latest.pt'])
def test_checkpoint_multiple_callbacks(
    self,
    save_folder: Optional[str],
    save_latest_filename: Optional[str],
    tmp_path: pathlib.Path,
):
    """When several ``CheckpointSaver`` callbacks are supplied, the Trainer must
    adopt the *first* one as its reference saver (``trainer._checkpoint_saver``)
    while keeping every provided saver registered in ``state.callbacks``.
    """
    checkpoint_savers = [
        CheckpointSaver(str(tmp_path / 'checkpoints1')),
        CheckpointSaver(str(tmp_path / 'checkpoints2')),
    ]

    trainer = self.get_trainer(
        max_duration='1ep',
        callbacks=checkpoint_savers,
        save_folder=save_folder,
        save_latest_filename=save_latest_filename,
    )

    # Identity, not equality: the Trainer must reuse the very first saver object,
    # not construct an equivalent copy. (`x is y` is the idiomatic spelling of
    # the original `id(x) == id(y)` check.)
    assert trainer._checkpoint_saver is checkpoint_savers[0]
    # No saver was dropped or duplicated during Trainer construction.
    assert sum(isinstance(cb, CheckpointSaver) for cb in trainer.state.callbacks) == len(checkpoint_savers)


class TestCheckpointLoading:

Expand Down Expand Up @@ -647,6 +677,11 @@ def get_trainer(
eval_dataset = RandomImageDataset()
train_batch_size = 2

callbacks = [DummyStatefulCallback()]
if 'callbacks' in kwargs:
callbacks += kwargs['callbacks']
del kwargs['callbacks']

return Trainer(
model=model,
train_dataloader=DataLoader(
Expand All @@ -670,7 +705,7 @@ def get_trainer(
max_duration=max_duration,
optimizers=optimizer,
schedulers=ExponentialScheduler(gamma=0.9),
callbacks=[DummyStatefulCallback()],
callbacks=callbacks,
**kwargs,
)

Expand Down Expand Up @@ -769,6 +804,43 @@ def test_autoresume(

assert trainer_1.state.run_name == trainer_2.state.run_name

@pytest.mark.parametrize('save_folder', [None, 'first'])
def test_autoresume_from_callback(
    self,
    save_folder: Optional[str],
    tmp_path: pathlib.Path,
):
    """Autoresume must work when the checkpoint location is supplied via a
    ``CheckpointSaver`` callback rather than (or in addition to) ``save_folder``:
    a second run with identical settings picks up the first run's weights.
    """
    saver = CheckpointSaver(str(tmp_path / 'checkpoints'), latest_filename='latest-rank{rank}.pt')

    # Both trainers are built with the exact same configuration; only the
    # presence of the checkpoint on disk distinguishes the second run.
    trainer_kwargs = dict(
        file_extension='.pt',
        save_folder=save_folder,
        device='cpu',
        run_name='big-chungus',
        autoresume=True,
        callbacks=[saver],
    )

    first_trainer = self.get_trainer(**trainer_kwargs)
    # Train and flush checkpoints so the resume has something to load.
    first_trainer.fit()
    first_trainer.close()

    resumed_trainer = self.get_trainer(**trainer_kwargs)

    self._assert_weights_equivalent(
        first_trainer.state.model,
        resumed_trainer.state.model,
    )

    assert first_trainer.state.run_name == resumed_trainer.state.run_name

@pytest.mark.parametrize(
'load_path,load_object_store',
[
Expand Down

0 comments on commit 1e1c04d

Please sign in to comment.