diff --git a/composer/checkpoint/save.py b/composer/checkpoint/save.py index 0750b77972..61a467ef63 100644 --- a/composer/checkpoint/save.py +++ b/composer/checkpoint/save.py @@ -30,11 +30,10 @@ def save_state_dict_to_disk( Args: state_dict (Dict[str,Any]): The state dict to save. - destination_dir (str): The directory to save the state dict to. - filename (str): The name of the file to save the state dict to. + destination_file_path (str): The path to save the state dict to. If sharded, + this should be the pth to a directory. Otherwise, it should be a path to a file. overwrite (bool): If True, the file will be overwritten if it exists. save_format (str): The format to save the state dict in. One of 'pt', 'hf', or 'safetensor'. - async_save (bool): If True, the save will be done asynchronously and the function will return with the path of where it was going to be saved Returns: str: The full path to the saved state dict if (sharded is false and rank 0) or if sharded is true, otherwise None.