diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index 1a343b0c08..1c5b22e305 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -1647,7 +1647,7 @@ def test_license_file_finder(
 
 @pytest.mark.parametrize('generation_config', [None, {}, {'max_length': 200}])
 def test_generation_config_variants(
-    generation_config: Optional[dict[str, Any]]
+    generation_config: Optional[dict[str, Any]],
 ):
 
     class MockModel(nn.Module):
@@ -1664,24 +1664,18 @@ def __init__(self, config: PretrainedConfig):
     config.generation_config = None
 
     mock_model = MockModel(config)
-
-    # Instantiate the callback
-    checkpointer = HuggingFaceCheckpointer(
-        save_folder='test',
-        save_interval='1ba',
-    )
-
-    # Mock state and model structure
+    logger = MagicMock()
     state = MagicMock()
     state.timestamp.batch = 1
     state.is_model_ddp = False
     state.model.model = mock_model
     state.model.tokenizer = None
 
-    # Mock logger
-    logger = MagicMock()
+    checkpointer = HuggingFaceCheckpointer(
+        save_folder='test',
+        save_interval='1ba',
+    )
 
-    # Call _save_checkpoint method to see if it handles different generation_config gracefully
     try:
         checkpointer._save_checkpoint(
             state=state,
@@ -1689,8 +1683,5 @@ def __init__(self, config: PretrainedConfig):
             upload_to_save_folder=False,
             register_to_mlflow=False,
         )
-        print(
-            f'Test passed: No error when generation_config is {generation_config}'
-        )
     except Exception as e:
         print(f'Test failed: {e} when generation_config is {generation_config}')