diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index f05e7322a8..4e6a501f2f 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -10,6 +10,7 @@
 import shutil
 import tempfile
 import time
+import warnings
 from multiprocessing.context import SpawnProcess
 from pathlib import Path
 from typing import Any, Optional, Sequence, Union
@@ -18,6 +19,7 @@
 import torch
 import torch.nn as nn
 from composer.core import Callback, Event, Precision, State, Time, TimeUnit
+from composer.devices import Device
 from composer.loggers import Logger, MLFlowLogger
 from composer.models import HuggingFaceModel
 from composer.utils import (
@@ -161,6 +163,10 @@ class HuggingFaceCheckpointer(Callback):
             keys ``input_example`` and ``signature``.
         flatten_imports (Sequence[str]): A sequence of import prefixes that will
             be flattened when editing MPT files.
+        final_register_only (bool): If true, only register the model in the MLFlow
+            registry on the last batch and do not save the HuggingFace checkpoint. If
+            registration fails or mlflow_registered_model_name is not set, then we will
+            fall back to saving the HuggingFace checkpoint.
     """

     def __init__(
@@ -173,6 +179,7 @@ def __init__(
         mlflow_registered_model_name: Optional[str] = None,
         mlflow_logging_config: Optional[dict] = None,
         flatten_imports: Sequence[str] = ('llmfoundry',),
+        final_register_only: bool = False,
     ):
         _, _, self.save_dir_format_str = parse_uri(save_folder)
         self.overwrite = overwrite
@@ -185,8 +192,18 @@ def __init__(
         self.flatten_imports = flatten_imports
         self.using_peft = False

-        # mlflow config setup
+        self.final_register_only = final_register_only
+
         self.mlflow_registered_model_name = mlflow_registered_model_name
+        if self.final_register_only and self.mlflow_registered_model_name is None:
+            self.final_register_only = False
+            warnings.warn(
+                'final_register_only is set to True, but mlflow_registered_model_name is not set. '
+                +
+                f'Defaulting to final_register_only=False and saving the HuggingFace checkpoint to {save_folder=}.',
+            )
+
+        # mlflow config setup
         if mlflow_logging_config is None:
             mlflow_logging_config = {}
         if self.mlflow_registered_model_name is not None:
@@ -249,7 +266,7 @@ def __init__(
         self.last_checkpoint_batch: Optional[Time] = None

         self.mlflow_loggers = []

-        self.child_processes: list[SpawnProcess] = []
+        self.register_processes: list[SpawnProcess] = []
         # Temporary save directory used by child_processes.
         self.temp_save_dir = None

@@ -259,7 +276,17 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
             state,
             event,
         ) and self.last_checkpoint_batch != state.timestamp.batch:
-            self._save_checkpoint(state, logger)
+            is_last_batch = self._is_last_batch(state)
+            self._save_checkpoint(
+                state,
+                logger,
+                register_to_mlflow=(
+                    self.mlflow_registered_model_name is not None and
+                    is_last_batch
+                ),
+                upload_to_save_folder=not self.final_register_only or
+                not is_last_batch,
+            )
         elif event == Event.INIT:
             if not isinstance(state.model, HuggingFaceModel):
                 raise ValueError(
@@ -300,7 +327,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
             # Wait for all child processes spawned by the callback to finish.
             timeout = 3600
             wait_start = time.time()
-            while not self._all_child_processes_done():
+            while not self._all_register_processes_done(state.device):
                 wait_time = time.time() - wait_start
                 if wait_time > timeout:
                     raise TimeoutError(
@@ -308,6 +335,19 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
                     )
                 time.sleep(2)

+            if self._any_register_processes_error(
+                state.device,
+            ) and self.final_register_only:
+                log.error(
+                    'An error occurred in one or more registration processes. Fallback to saving the HuggingFace checkpoint.',
+                )
+                self._save_checkpoint(
+                    state,
+                    logger,
+                    upload_to_save_folder=True,
+                    register_to_mlflow=False,
+                )
+
             # Clean up temporary save directory; all processes are done with it.
             if self.temp_save_dir is not None:
                 shutil.rmtree(self.temp_save_dir)
@@ -339,12 +379,23 @@ def _is_last_batch(self, state: State):

         return False

-    def _all_child_processes_done(self) -> bool:
-        not_done = any(process.is_alive() for process in self.child_processes)
-        x = torch.tensor(1 if not_done else 0).to(device='cuda')
+    def _all_register_processes_done(self, device: Device) -> bool:
+        not_done = any(
+            process.is_alive() for process in self.register_processes
+        )
+        x = device.tensor_to_device(torch.tensor(1 if not_done else 0))
         dist.all_reduce(x, reduce_operation='MAX')
         return x.item() == 0

+    def _any_register_processes_error(self, device: Device) -> bool:
+        has_errors = any(
+            process.exitcode is not None and process.exitcode != 0
+            for process in self.register_processes
+        )
+        x = device.tensor_to_device(torch.tensor(1 if has_errors else 0))
+        dist.all_reduce(x, reduce_operation='MAX')
+        return x.item() == 1
+
     def transform_model_and_tokenizer(
         self,
         model: PreTrainedModel,
@@ -412,7 +463,21 @@ def transform_model_pre_registration(
         """
         return model

-    def _save_checkpoint(self, state: State, logger: Logger):
+    def _save_checkpoint(
+        self,
+        state: State,
+        logger: Logger,
+        upload_to_save_folder: bool,
+        register_to_mlflow: bool,
+    ):
+        """Save a HuggingFace formatted checkpoint.
+
+        Args:
+            state (State): The training state.
+            logger (Logger): The logger.
+            upload_to_save_folder (bool): Whether to upload the HF checkpoint to the save folder.
+            register_to_mlflow (bool): Whether to register the model to MLFlow
+        """
         del logger  # unused

         self.last_checkpoint_batch = state.timestamp.batch
@@ -548,50 +613,53 @@ def tensor_hook(
                     ].base_model_name_or_path = self.pretrained_model_name

             log.debug('Saving Hugging Face checkpoint to disk')

-            # This context manager casts the TE extra state in io.BytesIO format to tensor format
-            # Needed for proper hf ckpt saving.
-            context_manager = te.onnx_export(
-                True,
-            ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
-            )
-            with context_manager:
-                new_model_instance.save_pretrained(temp_save_dir)
-            if original_tokenizer is not None:
-                assert isinstance(
-                    original_tokenizer,
-                    PreTrainedTokenizerBase,
-                )
-                original_tokenizer.save_pretrained(temp_save_dir)
-
-            # Only need to edit files for MPT because it has custom code
-            if new_model_instance.config.model_type == 'mpt':
-                log.debug('Editing MPT files for HuggingFace compatibility')
-                edit_files_for_hf_compatibility(
-                    temp_save_dir,
-                    self.flatten_imports,
-                )
-            if self.remote_ud is not None:
-                for filename in os.listdir(temp_save_dir):
-                    remote_file_name = os.path.join(save_dir, filename)
-                    remote_file_uri = self.remote_ud.remote_backend.get_uri(
-                        remote_file_name,
-                    )
-                    log.info(
-                        f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}',
+            if upload_to_save_folder:
+                # This context manager casts the TE extra state in io.BytesIO format to tensor format
+                # Needed for proper hf ckpt saving.
+                context_manager = te.onnx_export(
+                    True,
+                ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
+                )
+                with context_manager:
+                    new_model_instance.save_pretrained(temp_save_dir)
+                if original_tokenizer is not None:
+                    assert isinstance(
+                        original_tokenizer,
+                        PreTrainedTokenizerBase,
                     )
-                    self.remote_ud.upload_file(
-                        state=state,
-                        remote_file_name=remote_file_name,
-                        file_path=Path(os.path.join(temp_save_dir, filename)),
-                        overwrite=self.overwrite,
+                    original_tokenizer.save_pretrained(temp_save_dir)
+
+                # Only need to edit files for MPT because it has custom code
+                if new_model_instance.config.model_type == 'mpt':
+                    log.debug('Editing MPT files for HuggingFace compatibility')
+                    edit_files_for_hf_compatibility(
+                        temp_save_dir,
+                        self.flatten_imports,
                     )
+                if self.remote_ud is not None:
+                    for filename in os.listdir(temp_save_dir):
+                        remote_file_name = os.path.join(save_dir, filename)
+                        remote_file_uri = self.remote_ud.remote_backend.get_uri(
+                            remote_file_name,
+                        )
+                        log.info(
+                            f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}',
+                        )
+                        self.remote_ud.upload_file(
+                            state=state,
+                            remote_file_name=remote_file_name,
+                            file_path=Path(
+                                os.path.join(temp_save_dir, filename),
+                            ),
+                            overwrite=self.overwrite,
+                        )
+
         dist.barrier()

         if dist.get_global_rank() == 0:
-            if self.mlflow_registered_model_name and self._is_last_batch(state):
-
+            if register_to_mlflow:
                 new_model_instance = self.transform_model_pre_registration(
                     new_model_instance,
                 )
@@ -680,7 +748,7 @@ def tensor_hook(
                     # Restore the monitor process.
                     if monitor_process is not None:
                         mlflow_logger.monitor_process = monitor_process  # type: ignore
-                    self.child_processes.append(process)
+                    self.register_processes.append(process)

                     # Save the temporary directory to be cleaned up later.
                     if use_temp_dir:
diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py
index 73fa4c8d5a..14b7980d57 100644
--- a/llmfoundry/command_utils/train.py
+++ b/llmfoundry/command_utils/train.py
@@ -584,7 +584,12 @@ def train(cfg: DictConfig) -> Trainer:
             )

         hf_checkpointer_callback = hf_checkpointer_callbacks[0]
-        hf_checkpointer_callback._save_checkpoint(trainer.state, trainer.logger)
+        hf_checkpointer_callback._save_checkpoint(
+            trainer.state,
+            trainer.logger,
+            upload_to_save_folder=True,
+            register_to_mlflow=True,
+        )
         return trainer

     if train_cfg.only_composer_checkpoint:
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index b863e1d0a8..4f1bd63c62 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -9,6 +9,7 @@
 import shutil
 from argparse import Namespace
 from typing import Any, Callable, Optional, cast
+from unittest import mock
 from unittest.mock import ANY, MagicMock, patch

 import catalogue
@@ -314,9 +315,15 @@ class MockSpawnProcess:
     multiprocessing, so we need to patch SpawnProcess for tests.
     """

-    def __init__(self, target: Callable, kwargs: dict[str, Any]):
+    def __init__(
+        self,
+        target: Callable,
+        kwargs: dict[str, Any],
+        exitcode: int = 0,
+    ):
         self.target = target
         self.kwargs = kwargs
+        self.exitcode = exitcode

     def start(self):
         self.target(**self.kwargs)
@@ -325,6 +332,133 @@ def is_alive(self) -> bool:
         return False


+def _create_mlflow_logger_mock() -> MagicMock:
+    mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
+    mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
+    mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock)
+    mlflow_logger_mock.register_model_with_run_id = MagicMock()
+    mlflow_logger_mock.model_registry_prefix = ''
+    mlflow_logger_mock._experiment_id = 'mlflow-experiment-id'
+    mlflow_logger_mock._run_id = 'mlflow-run-id'
+    mlflow_logger_mock._enabled = True
+    mlflow_logger_mock.run_url = 'fake-url'
+    return mlflow_logger_mock
+
+
+def _create_optimizer(original_model: torch.nn.Module) -> torch.optim.Optimizer:
+    optimizer_config = _OPTIMIZER_CFG()
+    optimizer_name = optimizer_config.pop('name')
+    return build_optimizer(
+        original_model,
+        optimizer_name,
+        optimizer_config,
+    )
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize('mlflow_registry_error', [True, False])
+@pytest.mark.parametrize(
+    'mlflow_registered_model_name',
+    [None, 'dummy-registered-name'],
+)
+@patch('os.cpu_count', MagicMock(return_value=1))
+@patch(
+    'llmfoundry.callbacks.hf_checkpointer.SpawnProcess',
+    new=MockSpawnProcess,
+)
+def test_final_register_only(
+    mlflow_registry_error: bool,
+    mlflow_registered_model_name: Optional[str],
+    tiny_ft_dataloader: DataLoader,
+    tmp_path: pathlib.Path,
+    build_tiny_mpt: Callable,
+):
+    if mlflow_registry_error and mlflow_registered_model_name is None:
+        pytest.skip(
+            'Cannot test mlflow_registry_error without mlflow_registered_model_name',
+        )
+
+    delete_transformers_cache()
+
+    dist.initialize_dist(get_device('gpu'))
+
+    precision_str = 'bfloat16'
+
+    checkpointer_callback = HuggingFaceCheckpointer(
+        save_folder=os.path.join(tmp_path, 'checkpoints'),
+        save_interval='1dur',
+        precision=precision_str,
+        mlflow_registered_model_name=mlflow_registered_model_name,
+        final_register_only=True,
+    )
+
+    original_model = build_tiny_mpt()
+
+    optimizer = _create_optimizer(original_model)
+
+    mlflow_logger_mock = _create_mlflow_logger_mock()
+
+    checkpointer_callback._save_checkpoint = MagicMock(
+        wraps=checkpointer_callback._save_checkpoint,
+    )
+    trainer = Trainer(
+        model=original_model,
+        device='gpu',
+        train_dataloader=tiny_ft_dataloader,
+        max_duration='1ba',
+        callbacks=[checkpointer_callback],
+        loggers=[mlflow_logger_mock],
+        optimizers=optimizer,
+        save_latest_filename=None,
+    )
+
+    with mock.patch(
+        'llmfoundry.callbacks.hf_checkpointer.SpawnProcess',
+        new=lambda target,
+        kwargs: MockSpawnProcess(
+            target,
+            kwargs,
+            exitcode=1 if mlflow_registry_error else 0,
+        ),
+    ):
+        trainer.fit()
+
+    if mlflow_registered_model_name is not None:
+        # We should always attempt to register the model once
+        assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
+        if mlflow_registry_error:
+            # If the registry fails, we should still save the model
+            assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
+            assert checkpointer_callback._save_checkpoint.call_count == 2
+            assert checkpointer_callback._save_checkpoint.call_args_list[
+                0].kwargs == {
+                    'register_to_mlflow': True,
+                    'upload_to_save_folder': False,
+                }
+            assert checkpointer_callback._save_checkpoint.call_args_list[
+                1].kwargs == {
+                    'register_to_mlflow': False,
+                    'upload_to_save_folder': True,
+                }
+        else:
+            # No mlflow_registry_error, so we should only register the model
+            assert checkpointer_callback._save_checkpoint.call_count == 1
+            assert checkpointer_callback._save_checkpoint.call_args_list[
+                0].kwargs == {
+                    'register_to_mlflow': True,
+                    'upload_to_save_folder': False,
+                }
+    else:
+        # No mlflow_registered_model_name, so we should only save the checkpoint
+        assert mlflow_logger_mock.register_model_with_run_id.call_count == 0
+        assert checkpointer_callback._save_checkpoint.call_count == 1
+        assert checkpointer_callback._save_checkpoint.call_args_list[
+            0].kwargs == {
+                'register_to_mlflow': False,
+                'upload_to_save_folder': True,
+            }
+
+
 @pytest.mark.gpu
 @pytest.mark.parametrize('log_to_mlflow', [True, False])
 @pytest.mark.parametrize(
@@ -368,23 +502,9 @@ def test_huggingface_conversion_callback_interval(

     original_model = build_tiny_mpt()

-    optimizer_config = _OPTIMIZER_CFG()
-    optimizer_name = optimizer_config.pop('name')
-    optimizer = build_optimizer(
-        original_model,
-        optimizer_name,
-        optimizer_config,
-    )
+    optimizer = _create_optimizer(original_model)

-    mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
-    mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
-    mlflow_logger_mock.save_model = MagicMock(wraps=_save_model_mock)
-    mlflow_logger_mock.register_model_with_run_id = MagicMock()
-    mlflow_logger_mock.model_registry_prefix = ''
-    mlflow_logger_mock._experiment_id = 'mlflow-experiment-id'
-    mlflow_logger_mock._run_id = 'mlflow-run-id'
-    mlflow_logger_mock._enabled = True
-    mlflow_logger_mock.run_url = 'fake-url'
+    mlflow_logger_mock = _create_mlflow_logger_mock()
     checkpointer_callback.transform_model_pre_registration = MagicMock(
         wraps=checkpointer_callback.transform_model_pre_registration,
     )
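For reference, a minimal sketch of how the new option could be wired up in user code, assuming HuggingFaceCheckpointer is importable from llmfoundry.callbacks as elsewhere in the repo; the save_folder URI and registered model name below are made-up placeholders, not values from this diff:

    from llmfoundry.callbacks import HuggingFaceCheckpointer

    # With final_register_only=True, the final-batch checkpoint is only registered
    # to the MLflow registry; the HuggingFace checkpoint is written to save_folder
    # on that batch only if registration fails or no registered model name is set.
    checkpointer = HuggingFaceCheckpointer(
        save_folder='s3://my-bucket/hf-checkpoints',  # placeholder destination
        save_interval='1dur',
        precision='bfloat16',
        mlflow_registered_model_name='my-registered-model',  # placeholder name
        final_register_only=True,
    )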
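The new _any_register_processes_error helper mirrors _all_register_processes_done: each rank reduces a local 0/1 flag with a MAX all-reduce, so every rank agrees on whether any registration subprocess exited non-zero before deciding to fall back. A standalone sketch of that agreement pattern, written against plain torch.distributed rather than Composer's dist utilities, purely for illustration:

    import torch
    import torch.distributed as dist

    def any_rank_failed(local_failure: bool) -> bool:
        # Every rank contributes 0 or 1; after a MAX all-reduce the value is 1
        # on all ranks iff at least one rank observed a failure, so all ranks
        # take the same fallback path.
        flag = torch.tensor(1 if local_failure else 0)
        if torch.cuda.is_available():
            # NCCL requires the reduced tensor to live on the GPU.
            flag = flag.cuda()
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
        return int(flag.item()) == 1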