Add flag to save the final checkpoint as weights only #3613

Closed · wants to merge 2 commits
8 changes: 6 additions & 2 deletions composer/callbacks/checkpoint_saver.py
@@ -36,6 +36,7 @@
)
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME
from composer.utils.compression import get_compressor, is_compressed_pt
from composer.utils.misc import is_last_batch
from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY

log = logging.getLogger(__name__)
@@ -290,6 +291,7 @@ def __init__(
overwrite: bool = False,
num_checkpoints_to_keep: int = -1,
weights_only: bool = False,
final_weights_only: bool = False,
ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None,
num_concurrent_uploads: int = 1,
upload_timeout_in_seconds: int = 3600,
@@ -325,6 +327,7 @@ def __init__(
self.all_saved_checkpoints_to_timestamp: dict[str, Timestamp] = {}
self.num_checkpoints_to_keep = num_checkpoints_to_keep
self.weights_only = weights_only
self.final_weights_only = final_weights_only
self.ignore_keys = ignore_keys

self.start_batch = None
@@ -388,7 +391,7 @@ def fit_start(self, state: State, logger: Logger) -> None:

dist.barrier() # holds all ranks until folder check is done

if is_model_deepspeed(state.model) and self.weights_only:
if is_model_deepspeed(state.model) and (self.weights_only or self.final_weights_only):
raise NotImplementedError('weights_only=True is not supported when using DeepSpeed.')

self.start_batch = state.timestamp.batch
@@ -472,10 +475,11 @@ def _save_checkpoint(self, state: State, logger: Logger):
# Store before saving so state_dict in checkpoint has reference to latest checkpoint (itself)
self.all_saved_checkpoints_to_timestamp[save_filename] = state.timestamp

weights_only = self.final_weights_only if is_last_batch(state) else self.weights_only
saved_path = checkpoint.save_checkpoint(
state=state,
filename=filename_with_placeholders,
weights_only=self.weights_only,
weights_only=weights_only,
ignore_keys=self.ignore_keys,
)
log.debug(f'Checkpoint locally saved to {saved_path}')
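For reviewers, a minimal usage sketch of the new callback flag, assuming the `CheckpointSaver` constructor signature shown above; the folder and interval values are placeholders:

```python
from composer.callbacks import CheckpointSaver

# Intermediate checkpoints keep the full training state (model, optimizers, schedulers,
# RNG), while the checkpoint written on the last batch saves only the model weights.
saver = CheckpointSaver(
    folder='checkpoints',      # placeholder save location
    save_interval='1ep',
    weights_only=False,        # existing flag: applies to intermediate checkpoints
    final_weights_only=True,   # new flag from this PR: applies to the final checkpoint
)
```

The saver would then be passed to the `Trainer` via `callbacks=[saver]`, as with any other callback.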
5 changes: 5 additions & 0 deletions composer/trainer/trainer.py
@@ -935,6 +935,9 @@ class Trainer:
save_weights_only (bool, optional): Whether to save only the model weights instead of the entire training
state. This parameter has no effect if ``save_folder`` is ``None``. (default: ``False``)

save_final_weights_only (bool, optional): Whether to save only the model weights instead of the entire training
state for the final checkpoint. This parameter has no effect if ``save_folder`` is ``None``. (default: ``False``)

.. seealso:: :class:`~.CheckpointSaver`
save_ignore_keys (list[str] | (dict) -> None, optional): A list of paths for the ``state_dict`` of the checkpoint,
which, when provided, will be ignored from the state_dict before a checkpoint is saved. Each path is a list
@@ -1144,6 +1147,7 @@ def __init__(
save_overwrite: bool = False,
save_interval: Union[str, int, Time, Callable[[State, Event], bool]] = '1ep',
save_weights_only: bool = False,
save_final_weights_only: bool = False,
save_ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None,
save_num_checkpoints_to_keep: int = -1,
save_metrics: bool = False,
@@ -1559,6 +1563,7 @@ def __init__(
latest_remote_file_name=latest_remote_file_name,
overwrite=save_overwrite,
weights_only=save_weights_only,
final_weights_only=save_final_weights_only,
ignore_keys=save_ignore_keys,
save_interval=save_interval,
num_checkpoints_to_keep=save_num_checkpoints_to_keep,
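And the equivalent at the `Trainer` level, assuming `save_final_weights_only` is forwarded to the callback exactly as in the diff above; the model, dataloader, and durations are placeholders:

```python
from composer import Trainer

trainer = Trainer(
    model=model,                        # placeholder ComposerModel
    train_dataloader=train_dataloader,  # placeholder dataloader
    max_duration='10ep',
    save_folder='checkpoints',
    save_interval='1ep',
    save_weights_only=False,            # full state for intermediate checkpoints
    save_final_weights_only=True,       # weights only for the end-of-training checkpoint
)
trainer.fit()
```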
13 changes: 9 additions & 4 deletions composer/utils/misc.py
@@ -54,6 +54,12 @@ class ParallelismType(StringEnum):
TENSOR_PARALLEL = 'tensor_parallel'


def is_last_batch(state: 'State') -> bool:
"""Check if the current batch is the last batch in the epoch."""
elapsed_duration = state.get_elapsed_duration()
return elapsed_duration is not None and elapsed_duration >= 1.0


def create_interval_scheduler(
interval: Union[str, int, 'Time'],
include_end_of_training: bool = True,
@@ -114,10 +120,9 @@ def check_interval(state: State, event: Event):
if state.previous_timestamp is None:
return False

elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, 'elapsed_duration is set on the BATCH_CHECKPOINT and EPOCH_CHECKPOINT'

if include_end_of_training and event in final_events and elapsed_duration >= 1.0 and state.timestamp.batch != last_batch_seen:
if include_end_of_training and event in final_events and is_last_batch(
state,
) and state.timestamp.batch != last_batch_seen:
return True

if time_interval.unit in {
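For context on the new helper, a sketch of how `is_last_batch` could be used from a custom callback; the callback itself is hypothetical and assumes only the function added in this diff:

```python
from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils.misc import is_last_batch


class FinalBatchMarker(Callback):
    """Hypothetical callback that records when training has fully elapsed."""

    def batch_end(self, state: State, logger: Logger) -> None:
        # is_last_batch returns True once state.get_elapsed_duration() reaches 1.0,
        # i.e. on the batch that completes max_duration.
        if is_last_batch(state):
            logger.log_metrics({'trainer/reached_final_batch': 1})
```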