Skip to content

Commit

Permalink
precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
v-chen_data committed Sep 24, 2024
1 parent ff2cf33 commit e36ba4c
Showing 1 changed file with 6 additions and 15 deletions.
21 changes: 6 additions & 15 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,7 +1647,7 @@ def test_license_file_finder(

@pytest.mark.parametrize('generation_config', [None, {}, {'max_length': 200}])
def test_generation_config_variants(
generation_config: Optional[dict[str, Any]]
generation_config: Optional[dict[str, Any]],
):

class MockModel(nn.Module):
Expand All @@ -1664,33 +1664,24 @@ def __init__(self, config: PretrainedConfig):
config.generation_config = None

mock_model = MockModel(config)

# Instantiate the callback
checkpointer = HuggingFaceCheckpointer(
save_folder='test',
save_interval='1ba',
)

# Mock state and model structure
logger = MagicMock()
state = MagicMock()
state.timestamp.batch = 1
state.is_model_ddp = False
state.model.model = mock_model
state.model.tokenizer = None

# Mock logger
logger = MagicMock()
checkpointer = HuggingFaceCheckpointer(
save_folder='test',
save_interval='1ba',
)

# Call _save_checkpoint method to see if it handles different generation_config gracefully
try:
checkpointer._save_checkpoint(
state=state,
logger=logger,
upload_to_save_folder=False,
register_to_mlflow=False,
)
print(
f'Test passed: No error when generation_config is {generation_config}'
)
except Exception as e:
print(f'Test failed: {e} when generation_config is {generation_config}')

0 comments on commit e36ba4c

Please sign in to comment.