From ff2cf337e4ad8e2c56bca042445037f3121be196 Mon Sep 17 00:00:00 2001
From: v-chen_data
Date: Mon, 23 Sep 2024 16:44:28 -0700
Subject: [PATCH] test

---
 .../inference/test_convert_composer_to_hf.py | 56 ++++++++++++++++++-
 1 file changed, 54 insertions(+), 2 deletions(-)

diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index a3916be26c..1a343b0c08 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -15,6 +15,7 @@
 import catalogue
 import pytest
 import torch
+import torch.nn as nn
 import transformers
 from composer import ComposerModel, Trainer
 from composer.loggers import MLFlowLogger
@@ -23,7 +24,12 @@
 from omegaconf import OmegaConf as om
 from torch.distributed._tensor.api import DTensor
 from torch.utils.data import DataLoader
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import (
+    AutoConfig,
+    PretrainedConfig,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+)
 
 from llmfoundry.callbacks import HuggingFaceCheckpointer
 from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename
@@ -1636,4 +1642,50 @@ def test_license_file_finder(
 
     found_path = _maybe_get_license_filename(str(tmp_path))
     assert (found_path == license_file_name
-           ) if license_file_name is not None else (found_path is None)
\ No newline at end of file
+           ) if license_file_name is not None else (found_path is None)
+
+
+@pytest.mark.parametrize('generation_config', [None, {}, {'max_length': 200}])
+def test_generation_config_variants(
+    generation_config: Optional[dict[str, Any]],
+):
+
+    class MockModel(nn.Module):
+
+        def __init__(self, config: PretrainedConfig):
+            super().__init__()
+            self.config = config
+
+    # Mock a configuration and model with varying generation_config values
+    config = AutoConfig.from_pretrained('gpt2')
+    if generation_config is not None:
+        config.generation_config = generation_config
+    else:
+        config.generation_config = None
+
+    mock_model = MockModel(config)
+
+    # Instantiate the callback
+    checkpointer = HuggingFaceCheckpointer(
+        save_folder='test',
+        save_interval='1ba',
+    )
+
+    # Mock state and model structure
+    state = MagicMock()
+    state.timestamp.batch = 1
+    state.is_model_ddp = False
+    state.model.model = mock_model
+    state.model.tokenizer = None
+
+    # Mock logger
+    logger = MagicMock()
+
+    # _save_checkpoint should handle every generation_config variant;
+    # any exception propagates and fails the test.
+    checkpointer._save_checkpoint(
+        state=state,
+        logger=logger,
+        upload_to_save_folder=False,
+        register_to_mlflow=False,
+    )
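
For reference, the contract this test pins down is that the checkpointer tolerates `model.config.generation_config` being None, an empty dict, or a populated dict. Below is a minimal sketch of such a defensive read; `resolve_generation_config` is a hypothetical helper name for illustration, not an llm-foundry API.

from typing import Any, Optional


def resolve_generation_config(config: Any) -> dict[str, Any]:
    """Return generation kwargs, tolerating a missing or None attribute."""
    # Hypothetical helper, not llm-foundry code: getattr covers configs
    # that never set the attribute; `or {}` collapses None and {} into
    # the same empty-dict case, mirroring the three parametrized variants.
    generation_config: Optional[dict[str, Any]] = getattr(
        config,
        'generation_config',
        None,
    )
    return generation_config or {}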