From 68f078e27007c6a50a20aafee27d28900eaf4e41 Mon Sep 17 00:00:00 2001
From: Sam Havens
Date: Thu, 7 Dec 2023 17:17:45 -0800
Subject: [PATCH] lint

---
 llmfoundry/utils/prompt_files.py | 52 +++++++++++++++++---------------
 scripts/inference/hf_generate.py |  5 ++-
 2 files changed, 31 insertions(+), 26 deletions(-)

diff --git a/llmfoundry/utils/prompt_files.py b/llmfoundry/utils/prompt_files.py
index 4cad94d607..14ec005346 100644
--- a/llmfoundry/utils/prompt_files.py
+++ b/llmfoundry/utils/prompt_files.py
@@ -71,31 +71,33 @@ def load_prompts_from_file(prompt_path: str,
 
 
 def load_prompts_from_remote(prompt_path: str,
-                             prompt_delimiter: Optional[str] = None) -> List[str]:
-    """Load a set of prompts from object storage.
-
-    Args:
-        prompt_path (str): Path for text file
-        prompt_delimiter (Optional str): Delimiter for text file
-            If not provided, assumes the prompt file is a single prompt (non-delimited)
-
-    Returns:
-        List of prompt string(s)
-    """
-    backend, _, _ = parse_uri(prompt_path)
-    if backend in ['', None]:
-        raise ValueError(
-            f'prompt_path_str must start with s3:// etc if using object storage')
-
-    local_path = prompt_path.split('/')[-1]
-    get_file(path=prompt_path, destination=local_path, overwrite=True)
-
-    with open(local_path, 'r') as f:
-        prompt_string = f.read()
-
-    if prompt_delimiter is None:
-        return [prompt_string]
-    return [i for i in prompt_string.split(prompt_delimiter) if i]
+                             prompt_delimiter: Optional[str] = None
+                            ) -> List[str]:
+    """Load a set of prompts from object storage.
+
+    Args:
+        prompt_path (str): Path for text file
+        prompt_delimiter (Optional str): Delimiter for text file
+            If not provided, assumes the prompt file is a single prompt (non-delimited)
+
+    Returns:
+        List of prompt string(s)
+    """
+    backend, _, _ = parse_uri(prompt_path)
+    if backend in ['', None]:
+        raise ValueError(
+            f'prompt_path_str must start with s3:// etc if using object storage'
+        )
+
+    local_path = prompt_path.split('/')[-1]
+    get_file(path=prompt_path, destination=local_path, overwrite=True)
+
+    with open(local_path, 'r') as f:
+        prompt_string = f.read()
+
+    if prompt_delimiter is None:
+        return [prompt_string]
+    return [i for i in prompt_string.split(prompt_delimiter) if i]
 
 
 def load_prompts_from_dataset(dataset_path: str,
diff --git a/scripts/inference/hf_generate.py b/scripts/inference/hf_generate.py
index 2a4becc8ad..a639941ec6 100644
--- a/scripts/inference/hf_generate.py
+++ b/scripts/inference/hf_generate.py
@@ -278,7 +278,10 @@ def _generate(encoded_inp: Dict[str, torch.Tensor]):
     print(f'\nTokenizing prompts...')
     maybe_synchronize()
     encode_start = time.time()
-    encoded_inp = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
+    encoded_inp = tokenizer(batch,
+                            return_tensors='pt',
+                            padding=True,
+                            truncation=True)
     for key, value in encoded_inp.items():
         encoded_inp[key] = value.to(model.device)
     maybe_synchronize()
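
For readers skimming the patch: the `prompt_files.py` hunk only reflows `load_prompts_from_remote` for the linter, so its behavior is unchanged. Below is a minimal sketch of the delimiter handling that function keeps, runnable without object storage; the helper name `split_prompts` and the sample strings are illustrative and not part of the patch.

```python
from typing import List, Optional


def split_prompts(prompt_string: str,
                  prompt_delimiter: Optional[str] = None) -> List[str]:
    """Mirror the splitting logic kept by the patched function.

    No delimiter: the whole file is treated as a single prompt. With a
    delimiter: split on it and drop empty fragments, as in the patch.
    """
    if prompt_delimiter is None:
        return [prompt_string]
    return [i for i in prompt_string.split(prompt_delimiter) if i]


if __name__ == '__main__':
    assert split_prompts('single prompt') == ['single prompt']
    # A trailing delimiter produces an empty fragment, which is dropped.
    assert split_prompts('first|second|', '|') == ['first', 'second']
```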
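The `hf_generate.py` hunk likewise only wraps an existing tokenizer call across lines. For context, here is a self-contained sketch of that batching step, assuming the standard Hugging Face tokenizer API the script already uses; the `gpt2` checkpoint and the CPU/GPU device pick are assumptions for illustration, since the script loads both from command-line arguments.

```python
import torch
from transformers import AutoTokenizer

# Assumed checkpoint for illustration; hf_generate.py loads the tokenizer
# from a command-line argument instead.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token

batch = ['first prompt', 'a somewhat longer second prompt']
# Same call as the patched line, just on one line here: pad every prompt
# to the longest in the batch and truncate to the model's max length.
encoded_inp = tokenizer(batch,
                        return_tensors='pt',
                        padding=True,
                        truncation=True)
# The script then moves each tensor to model.device before generating.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
for key, value in encoded_inp.items():
    encoded_inp[key] = value.to(device)
print(encoded_inp['input_ids'].shape)  # (batch_size, padded_seq_len)
```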