diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index baa72a7f66..f899206add 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -12,7 +12,7 @@
 import time
 from multiprocessing.context import SpawnProcess
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -273,6 +273,23 @@ def _all_child_processes_done(self) -> bool:
         dist.all_reduce(x, reduce_operation='MAX')
         return x.item() == 0
 
+    def transform_model_and_tokenizer(
+        self, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase
+    ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
+        """Transform the model and tokenizer before saving.
+
+        This allows a subclass to modify the model and tokenizer before saving.
+        The base class implementation will make no modifications.
+
+        Args:
+            model (PreTrainedModel): The model to be transformed.
+            tokenizer (PreTrainedTokenizerBase): The tokenizer to be transformed.
+
+        Returns:
+            Tuple[PreTrainedModel, PreTrainedTokenizerBase]: The transformed model and tokenizer.
+        """
+        return model, tokenizer
+
     def _save_checkpoint(self, state: State, logger: Logger):
         del logger  # unused
 
@@ -405,6 +422,10 @@ def dtensor_to_tensor_hook(
             new_model_instance.load_state_dict(state_dict, assign=True)
             del state_dict
 
+            # Transform the model and tokenizer before saving
+            new_model_instance, original_tokenizer = self.transform_model_and_tokenizer(
+                new_model_instance, original_tokenizer)
+
             log.debug('Saving Hugging Face checkpoint to disk')
             new_model_instance.save_pretrained(temp_save_dir)
             if original_tokenizer is not None:
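
Usage sketch (not part of the diff above): one way a downstream callback might override the new transform_model_and_tokenizer hook. HuggingFaceCheckpointer is the class modified by this diff; the PaddedVocabCheckpointer subclass and the pad-token/embedding-resize logic are purely illustrative assumptions, not code from this PR.

from typing import Tuple

from transformers import PreTrainedModel, PreTrainedTokenizerBase

from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer


class PaddedVocabCheckpointer(HuggingFaceCheckpointer):
    """Hypothetical subclass that adds a pad token before the checkpoint is saved."""

    def transform_model_and_tokenizer(
        self, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase
    ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
        # Illustrative transform: add a pad token if the tokenizer lacks one,
        # then resize the model's embeddings so the saved checkpoint stays
        # consistent with the tokenizer's vocabulary size.
        if tokenizer is not None and tokenizer.pad_token is None:
            tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
            model.resize_token_embeddings(len(tokenizer))
        return model, tokenizer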