Commit
Merge branch 'main' into replace-fsdp-args
KuuCi committed Sep 16, 2024
2 parents e1aefa5 + a862d6e commit dab8791
Showing 10 changed files with 369 additions and 73 deletions.
160 changes: 114 additions & 46 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -10,6 +10,7 @@
import shutil
import tempfile
import time
import warnings
from multiprocessing.context import SpawnProcess
from pathlib import Path
from typing import Any, Optional, Sequence, Union
@@ -18,6 +19,7 @@
import torch
import torch.nn as nn
from composer.core import Callback, Event, Precision, State, Time, TimeUnit
from composer.devices import Device
from composer.loggers import Logger, MLFlowLogger
from composer.models import HuggingFaceModel
from composer.utils import (
@@ -161,6 +163,10 @@ class HuggingFaceCheckpointer(Callback):
keys ``input_example`` and ``signature``.
flatten_imports (Sequence[str]): A sequence of import prefixes that will
be flattened when editing MPT files.
final_register_only (bool): If True, only register the model in the MLFlow
registry on the last batch and do not save the HuggingFace checkpoint. If
registration fails or mlflow_registered_model_name is not set, then we will
fall back to saving the HuggingFace checkpoint.
"""

def __init__(
@@ -173,6 +179,7 @@ def __init__(
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
flatten_imports: Sequence[str] = ('llmfoundry',),
final_register_only: bool = False,
):
_, _, self.save_dir_format_str = parse_uri(save_folder)
self.overwrite = overwrite
@@ -185,8 +192,18 @@ def __init__(
self.flatten_imports = flatten_imports
self.using_peft = False

# mlflow config setup
self.final_register_only = final_register_only

self.mlflow_registered_model_name = mlflow_registered_model_name
if self.final_register_only and self.mlflow_registered_model_name is None:
self.final_register_only = False
warnings.warn(
'final_register_only is set to True, but mlflow_registered_model_name is not set. '
+
f'Defaulting to final_register_only=False and saving the HuggingFace checkpoint to {save_folder=}.',
)

# mlflow config setup
if mlflow_logging_config is None:
mlflow_logging_config = {}
if self.mlflow_registered_model_name is not None:
@@ -249,7 +266,7 @@ def __init__(
self.last_checkpoint_batch: Optional[Time] = None
self.mlflow_loggers = []

self.child_processes: list[SpawnProcess] = []
self.register_processes: list[SpawnProcess] = []
# Temporary save directory used by child_processes.
self.temp_save_dir = None

@@ -259,7 +276,17 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
state,
event,
) and self.last_checkpoint_batch != state.timestamp.batch:
self._save_checkpoint(state, logger)
is_last_batch = self._is_last_batch(state)
self._save_checkpoint(
state,
logger,
register_to_mlflow=(
self.mlflow_registered_model_name is not None and
is_last_batch
),
upload_to_save_folder=not self.final_register_only or
not is_last_batch,
)
elif event == Event.INIT:
if not isinstance(state.model, HuggingFaceModel):
raise ValueError(
@@ -300,14 +327,27 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
# Wait for all child processes spawned by the callback to finish.
timeout = 3600
wait_start = time.time()
while not self._all_child_processes_done():
while not self._all_register_processes_done(state.device):
wait_time = time.time() - wait_start
if wait_time > timeout:
raise TimeoutError(
f'Waited {wait_time} seconds for child processes to complete. Exceeded timeout of {timeout} seconds.',
)
time.sleep(2)

if self._any_register_processes_error(
state.device,
) and self.final_register_only:
log.error(
'An error occurred in one or more registration processes. Falling back to saving the HuggingFace checkpoint.',
)
self._save_checkpoint(
state,
logger,
upload_to_save_folder=True,
register_to_mlflow=False,
)

# Clean up temporary save directory; all processes are done with it.
if self.temp_save_dir is not None:
shutil.rmtree(self.temp_save_dir)
@@ -339,12 +379,23 @@ def _is_last_batch(self, state: State):

return False

def _all_child_processes_done(self) -> bool:
not_done = any(process.is_alive() for process in self.child_processes)
x = torch.tensor(1 if not_done else 0).to(device='cuda')
def _all_register_processes_done(self, device: Device) -> bool:
not_done = any(
process.is_alive() for process in self.register_processes
)
x = device.tensor_to_device(torch.tensor(1 if not_done else 0))
dist.all_reduce(x, reduce_operation='MAX')
return x.item() == 0

def _any_register_processes_error(self, device: Device) -> bool:
has_errors = any(
process.exitcode is not None and process.exitcode != 0
for process in self.register_processes
)
x = device.tensor_to_device(torch.tensor(1 if has_errors else 0))
dist.all_reduce(x, reduce_operation='MAX')
return x.item() == 1

def transform_model_and_tokenizer(
self,
model: PreTrainedModel,
@@ -412,7 +463,21 @@ def transform_model_pre_registration(
"""
return model

def _save_checkpoint(self, state: State, logger: Logger):
def _save_checkpoint(
self,
state: State,
logger: Logger,
upload_to_save_folder: bool,
register_to_mlflow: bool,
):
"""Save a HuggingFace formatted checkpoint.
Args:
state (State): The training state.
logger (Logger): The logger.
upload_to_save_folder (bool): Whether to upload the HF checkpoint to the save folder.
register_to_mlflow (bool): Whether to register the model to MLFlow.
"""
del logger # unused

self.last_checkpoint_batch = state.timestamp.batch
@@ -548,50 +613,53 @@ def tensor_hook(
].base_model_name_or_path = self.pretrained_model_name

log.debug('Saving Hugging Face checkpoint to disk')
# This context manager casts the TE extra state in io.BytesIO format to tensor format
# Needed for proper hf ckpt saving.
context_manager = te.onnx_export(
True,
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
assert isinstance(
original_tokenizer,
PreTrainedTokenizerBase,
)
original_tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if new_model_instance.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(
temp_save_dir,
self.flatten_imports,
)

if self.remote_ud is not None:
for filename in os.listdir(temp_save_dir):
remote_file_name = os.path.join(save_dir, filename)
remote_file_uri = self.remote_ud.remote_backend.get_uri(
remote_file_name,
)
log.info(
f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}',
if upload_to_save_folder:
# This context manager casts the TE extra state in io.BytesIO format to tensor format
# Needed for proper hf ckpt saving.
context_manager = te.onnx_export(
True,
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
assert isinstance(
original_tokenizer,
PreTrainedTokenizerBase,
)
self.remote_ud.upload_file(
state=state,
remote_file_name=remote_file_name,
file_path=Path(os.path.join(temp_save_dir, filename)),
overwrite=self.overwrite,
original_tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if new_model_instance.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(
temp_save_dir,
self.flatten_imports,
)

if self.remote_ud is not None:
for filename in os.listdir(temp_save_dir):
remote_file_name = os.path.join(save_dir, filename)
remote_file_uri = self.remote_ud.remote_backend.get_uri(
remote_file_name,
)
log.info(
f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}',
)
self.remote_ud.upload_file(
state=state,
remote_file_name=remote_file_name,
file_path=Path(
os.path.join(temp_save_dir, filename),
),
overwrite=self.overwrite,
)

dist.barrier()

if dist.get_global_rank() == 0:
if self.mlflow_registered_model_name and self._is_last_batch(state):

if register_to_mlflow:
new_model_instance = self.transform_model_pre_registration(
new_model_instance,
)
@@ -680,7 +748,7 @@ def tensor_hook(
# Restore the monitor process.
if monitor_process is not None:
mlflow_logger.monitor_process = monitor_process # type: ignore
self.child_processes.append(process)
self.register_processes.append(process)

# Save the temporary directory to be cleaned up later.
if use_temp_dir:
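The net effect of the hf_checkpointer.py changes is that each qualifying checkpoint event now decides separately whether to upload a HuggingFace checkpoint and whether to register the model to MLflow. A minimal sketch of that decision logic, mirroring the booleans computed in run_event above (the helper function and the example calls are illustrative, not part of this commit):

```python
from typing import Optional, Tuple


def checkpoint_actions(
    final_register_only: bool,
    mlflow_registered_model_name: Optional[str],
    is_last_batch: bool,
) -> Tuple[bool, bool]:
    """Return (upload_to_save_folder, register_to_mlflow) for one event."""
    # Registration only happens on the final batch, and only if a registered
    # model name was configured.
    register_to_mlflow = (
        mlflow_registered_model_name is not None and is_last_batch
    )
    # The HF upload is skipped only on the final batch of a register-only run;
    # intermediate checkpoints are always uploaded.
    upload_to_save_folder = not final_register_only or not is_last_batch
    return upload_to_save_folder, register_to_mlflow


# Register-only run, final batch: no upload, register to MLflow.
assert checkpoint_actions(True, 'my-registered-model', True) == (False, True)
# Same run, intermediate batch: upload as usual, no registration yet.
assert checkpoint_actions(True, 'my-registered-model', False) == (True, False)
```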
7 changes: 7 additions & 0 deletions llmfoundry/command_utils/data_prep/convert_delta_to_json.py
@@ -22,6 +22,7 @@
ClusterDoesNotExistError,
FailedToConnectToDatabricksError,
FailedToCreateSQLConnectionError,
InsufficientPermissionsError,
)

if TYPE_CHECKING:
@@ -454,6 +455,12 @@ def fetch(
sparkSession,
)
except Exception as e:
from pyspark.errors import AnalysisException
if isinstance(e, AnalysisException):
if 'INSUFFICIENT_PERMISSIONS' in e.message: # pyright: ignore
raise InsufficientPermissionsError(
action=f'reading from {tablename}',
) from e
raise RuntimeError(
f'Error in get rows from {tablename}. Restart sparkSession and try again',
) from e
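The convert_delta_to_json.py change adds a special case to the broad exception handler in fetch: when the underlying Spark AnalysisException signals INSUFFICIENT_PERMISSIONS, it is re-raised as the more actionable InsufficientPermissionsError rather than the generic RuntimeError. A rough sketch of the pattern, using stand-in exception classes so it runs without pyspark or llm-foundry installed:

```python
class AnalysisException(Exception):
    """Stand-in for pyspark.errors.AnalysisException."""

    def __init__(self, message: str) -> None:
        super().__init__(message)
        self.message = message


class InsufficientPermissionsError(Exception):
    """Stand-in for llmfoundry.utils.exceptions.InsufficientPermissionsError."""

    def __init__(self, action: str) -> None:
        super().__init__(f'Insufficient permissions when {action}.')


def fetch_rows(tablename: str) -> None:
    try:
        # Placeholder for the real Spark read, which fails here for the demo.
        raise AnalysisException('[INSUFFICIENT_PERMISSIONS] cannot read table')
    except Exception as e:
        if isinstance(e, AnalysisException) and 'INSUFFICIENT_PERMISSIONS' in e.message:
            raise InsufficientPermissionsError(
                action=f'reading from {tablename}',
            ) from e
        raise RuntimeError(
            f'Error in get rows from {tablename}. Restart sparkSession and try again',
        ) from e
```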
4 changes: 3 additions & 1 deletion llmfoundry/command_utils/data_prep/convert_text_to_mds.py
@@ -478,7 +478,9 @@ def convert_text_to_mds(
index_path = os.path.join(local_output_folder, 'index.json')
with open(index_path, 'r') as index_file:
if not json.load(index_file)['shards']:
raise DatasetTooSmallError()
raise DatasetTooSmallError(
reason='No shards were created when converting text to MDS.',
)

# Write a done file with the args and object names
write_done_file(local_output_folder, args_str, object_names)
7 changes: 6 additions & 1 deletion llmfoundry/command_utils/train.py
@@ -584,7 +584,12 @@ def train(cfg: DictConfig) -> Trainer:
)

hf_checkpointer_callback = hf_checkpointer_callbacks[0]
hf_checkpointer_callback._save_checkpoint(trainer.state, trainer.logger)
hf_checkpointer_callback._save_checkpoint(
trainer.state,
trainer.logger,
upload_to_save_folder=True,
register_to_mlflow=True,
)
return trainer

if train_cfg.only_composer_checkpoint:
22 changes: 20 additions & 2 deletions llmfoundry/data/finetuning/tasks.py
@@ -73,6 +73,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
ALLOWED_RESPONSE_KEYS,
ChatTemplateError,
ConsecutiveRepeatedChatRolesError,
DatasetTooSmallError,
IncorrectMessageKeyQuantityError,
InvalidContentTypeError,
InvalidExampleTypeError,
@@ -1013,7 +1014,7 @@ def dataset_mapper(example: dict):
raise MisconfiguredHfDatasetError(
dataset_name=dataset_name,
split=split,
)
) from error
if error is not None:
log.error('Error during data prep')
raise error
@@ -1033,7 +1034,24 @@ def build_from_streaming(
*args: Any,
**kwargs: Any,
) -> StreamingFinetuningDataset:
return self.streaming_dataset_class(*args, **kwargs)
dataset = self.streaming_dataset_class(*args, **kwargs)
num_canonical_nodes = dataset.num_canonical_nodes
num_samples = dataset.num_samples
if num_canonical_nodes is None:
num_physical_nodes = dist.get_world_size(
) // dist.get_local_world_size()
if num_samples < num_physical_nodes:
raise DatasetTooSmallError(
f'{num_samples=} is less than {dist.get_world_size() // dist.get_local_world_size()}, the number of physical nodes. ',
)

if num_canonical_nodes is not None and num_samples < num_canonical_nodes:
raise DatasetTooSmallError(
f'{num_samples=} is less than {num_canonical_nodes=}. ' +
'Please check your index.json file and ensure that your dataset has been written out correctly. '
+ 'If this was intended, reduce num_canonical_nodes.',
)
return dataset


dataset_constructor = DatasetConstructor()
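The tasks.py change adds a size guard when building a StreamingFinetuningDataset: the dataset must contain at least one sample per physical node (or per canonical node, when num_canonical_nodes is set). A condensed sketch of that rule as a pure function (the stand-in exception and the example call are for illustration only):

```python
from typing import Optional


class DatasetTooSmallError(ValueError):
    """Stand-in for llmfoundry.utils.exceptions.DatasetTooSmallError."""


def check_streaming_dataset_size(
    num_samples: int,
    num_canonical_nodes: Optional[int],
    num_physical_nodes: int,
) -> None:
    if num_canonical_nodes is None:
        # Streaming derives num_canonical_nodes from the physical topology,
        # so there must be at least one sample per physical node.
        if num_samples < num_physical_nodes:
            raise DatasetTooSmallError(
                f'{num_samples=} is less than {num_physical_nodes=}.',
            )
    elif num_samples < num_canonical_nodes:
        raise DatasetTooSmallError(
            f'{num_samples=} is less than {num_canonical_nodes=}. '
            'Check index.json, or reduce num_canonical_nodes if this was intended.',
        )


# 64 samples across 2 physical nodes passes the guard without raising.
check_streaming_dataset_size(num_samples=64, num_canonical_nodes=None, num_physical_nodes=2)
```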
8 changes: 5 additions & 3 deletions llmfoundry/models/hf/hf_base.py
@@ -69,7 +69,7 @@ def __init__(
config_overrides: Optional[dict[str, Any]] = None,
use_logits: bool = True,
shift_labels: bool = False,
peft_config: Optional['PeftConfig'] = None,
peft_config: Optional[dict[str, Any]] = None,
allow_embedding_resizing: bool = False,
use_train_metrics: bool = True,
additional_train_metrics: Optional[list] = None,
@@ -92,8 +92,6 @@ def __init__(

model = self.transform_model(model)

self.prepare_inner_model(model, init_device)

metrics, eval_metrics = self.build_metrics(
use_train_metrics=use_train_metrics,
additional_train_metrics=additional_train_metrics,
@@ -121,6 +119,10 @@ def __init__(
should_save_peft_only=should_save_peft_only,
)

# Prepare for FSDP needs to happen after the super init, so that any model
# architecture changes are completed
self.prepare_inner_model(self.model, init_device)

def loss(self, outputs: ModelOutput, batch: Mapping):
if self.config.use_return_dict:
return outputs['loss']
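The hf_base.py change reorders initialization: prepare_inner_model (the FSDP preparation step) now runs after super().__init__, so any architecture changes made during the base init (for example PEFT wrapping or embedding resizing) are already applied to the model that gets prepared. A toy illustration of why the ordering matters (the classes below are stand-ins, not Composer's):

```python
class ToyComposerModel:
    """Stand-in for HuggingFaceModel: its init may replace the inner model."""

    def __init__(self, model: object) -> None:
        # Simulate an architecture change during the base init (e.g. PEFT wrapping).
        self.model = ('peft_wrapped', model)


class ToyHFBaseModel(ToyComposerModel):
    def __init__(self, model: object, init_device: str = 'cpu') -> None:
        super().__init__(model)
        # Prepare for FSDP only after the super init, so the prepared module
        # already reflects any changes the base class made.
        self.prepare_inner_model(self.model, init_device)

    def prepare_inner_model(self, model: object, init_device: str) -> None:
        self.prepared = (model, init_device)


m = ToyHFBaseModel('base_model')
assert m.prepared[0] == ('peft_wrapped', 'base_model')
```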
(4 more changed files not shown.)
