
Commit

test
v-chen_data committed Sep 23, 2024
1 parent 6b95617 commit ff2cf33
Showing 1 changed file with 59 additions and 2 deletions.
61 changes: 59 additions & 2 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -15,6 +15,7 @@
import catalogue
import pytest
import torch
import torch.nn as nn
import transformers
from composer import ComposerModel, Trainer
from composer.loggers import MLFlowLogger
@@ -23,7 +24,12 @@
from omegaconf import OmegaConf as om
from torch.distributed._tensor.api import DTensor
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers import (
    AutoConfig,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)

from llmfoundry.callbacks import HuggingFaceCheckpointer
from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename
@@ -1636,4 +1642,55 @@ def test_license_file_finder(

    found_path = _maybe_get_license_filename(str(tmp_path))
    assert (found_path == license_file_name
            ) if license_file_name is not None else (found_path is None)


@pytest.mark.parametrize('generation_config', [None, {}, {'max_length': 200}])
def test_generation_config_variants(
    generation_config: Optional[dict[str, Any]],
):
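    """The HuggingFaceCheckpointer should tolerate a missing (None), empty,
    and populated generation_config on the wrapped model's config."""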

    # Minimal stand-in for a pretrained model: exposes only the .config
    # attribute that the checkpointer inspects
    class MockModel(nn.Module):

        def __init__(self, config: PretrainedConfig):
            super().__init__()
            self.config = config

    # Mock a configuration and model with the parametrized generation_config
    config = AutoConfig.from_pretrained('gpt2')
    config.generation_config = generation_config

    mock_model = MockModel(config)

    # Instantiate the callback
    checkpointer = HuggingFaceCheckpointer(
        save_folder='test',
        save_interval='1ba',
    )

    # Mock the Trainer state and the wrapped model structure that
    # _save_checkpoint reads (state.model.model and state.model.tokenizer)
    state = MagicMock()
    state.timestamp.batch = 1
    state.is_model_ddp = False
    state.model.model = mock_model
    state.model.tokenizer = None

    # Mock logger
    logger = MagicMock()

    # _save_checkpoint should handle each generation_config variant without
    # raising; fail the test explicitly if it does not, rather than
    # swallowing the exception
    try:
        checkpointer._save_checkpoint(
            state=state,
            logger=logger,
            upload_to_save_folder=False,
            register_to_mlflow=False,
        )
    except Exception as e:
        pytest.fail(
            f'_save_checkpoint raised {e!r} when generation_config is '
            f'{generation_config}',
        )

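For context, the behavior these cases probe: a populated generation_config on the inner model's config presumably flows into the saved HF checkpoint, while None or an empty dict must not break the save. A minimal sketch of that kind of guard, with a hypothetical helper name (the real logic lives in llmfoundry.callbacks.hf_checkpointer and may differ):

    from typing import Any, Optional

    import torch.nn as nn

    def _get_generation_config(model: nn.Module) -> Optional[dict[str, Any]]:
        # Hypothetical helper: generation_config may be absent, None, {},
        # or populated; only a non-empty mapping is worth propagating.
        gen_config = getattr(model.config, 'generation_config', None)
        return gen_config if gen_config else None

Assuming a dev install of llm-foundry, the new cases can be run with pytest tests/a_scripts/inference/test_convert_composer_to_hf.py -k test_generation_config_variants.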