From 63a7f125330d930dbc3a3166211076a4a5253e2a Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Fri, 19 Apr 2024 11:08:56 -0700
Subject: [PATCH] Add option for subclasses to convert model and tokenizer in
 hf checkpointer (#1121)

---
 llmfoundry/callbacks/hf_checkpointer.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index baa72a7f66..f899206add 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -12,7 +12,7 @@
 import time
 from multiprocessing.context import SpawnProcess
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -273,6 +273,23 @@ def _all_child_processes_done(self) -> bool:
         dist.all_reduce(x, reduce_operation='MAX')
         return x.item() == 0
 
+    def transform_model_and_tokenizer(
+        self, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase
+    ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
+        """Transform the model and tokenizer before saving.
+
+        This allows a subclass to modify the model and tokenizer before
+        saving. The base class implementation will make no modifications.
+
+        Args:
+            model (PreTrainedModel): The model to be transformed.
+            tokenizer (PreTrainedTokenizerBase): The tokenizer to be transformed.
+
+        Returns:
+            Tuple[PreTrainedModel, PreTrainedTokenizerBase]: The transformed model and tokenizer.
+        """
+        return model, tokenizer
+
     def _save_checkpoint(self, state: State, logger: Logger):
         del logger  # unused
 
@@ -405,6 +422,10 @@ def dtensor_to_tensor_hook(
             new_model_instance.load_state_dict(state_dict, assign=True)
             del state_dict
 
+            # Transform the model and tokenizer before saving
+            new_model_instance, original_tokenizer = self.transform_model_and_tokenizer(
+                new_model_instance, original_tokenizer)
+
             log.debug('Saving Hugging Face checkpoint to disk')
             new_model_instance.save_pretrained(temp_save_dir)
             if original_tokenizer is not None:
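
For illustration, a subclass can override the new hook to adjust the model and
tokenizer right before the checkpoint is written. Below is a minimal sketch:
the ConfigTaggingCheckpointer class and the specific edits it makes are
hypothetical, while HuggingFaceCheckpointer, transform_model_and_tokenizer,
and the transformers types come from the patched file.

    # Hypothetical subclass built on the hook added in this patch.
    from typing import Tuple

    from transformers import PreTrainedModel, PreTrainedTokenizerBase

    from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer


    class ConfigTaggingCheckpointer(HuggingFaceCheckpointer):
        """Example subclass: tags the config before each Hugging Face save."""

        def transform_model_and_tokenizer(
            self, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase
        ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
            # This hook runs after the gathered state dict is loaded and
            # immediately before save_pretrained(), so any edits made here
            # land in the saved checkpoint. PretrainedConfig.update() merges
            # a dict of attributes into the config. The tag below is a
            # made-up example, not an llm-foundry convention.
            model.config.update({'finetuned_with': 'llm-foundry'})
            if tokenizer is not None and hasattr(model.config,
                                                 'max_position_embeddings'):
                tokenizer.model_max_length = model.config.max_position_embeddings
            return model, tokenizer

Since _save_checkpoint invokes the hook right before save_pretrained, whatever
the override returns is exactly what is written to disk.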