From 1e7f909f8c6f2251be84c762c95344e92feb1412 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Sun, 17 Sep 2023 18:47:18 -0700 Subject: [PATCH] Fixes a typo default arg (#604) --- llmfoundry/callbacks/hf_checkpointer.py | 2 +- tests/test_hf_conversion_script.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index fe3028ab19..492816ea07 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -46,7 +46,7 @@ def __init__( save_folder: str, save_interval: Union[str, int, Time], huggingface_folder_name: str = 'ba{batch}', - precision: str = 'fp32', + precision: str = 'float32', overwrite: bool = False, ): self.backend, self.bucket_name, self.save_dir_format_str = parse_uri( diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index 2a175a04e9..c944dcfc97 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -176,6 +176,10 @@ def get_config( return cast(DictConfig, test_cfg) +def test_callback_inits_with_defaults(): + _ = HuggingFaceCheckpointer(save_folder='test', save_interval='1ba') + + @pytest.mark.world_size(2) @pytest.mark.gpu @pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])