Commit 68f078e: lint

samhavens committed Dec 8, 2023
1 parent c13d9fd
Showing 2 changed files with 31 additions and 26 deletions.
52 changes: 27 additions & 25 deletions llmfoundry/utils/prompt_files.py
@@ -71,31 +71,33 @@ def load_prompts_from_file(prompt_path: str,
 
 
 def load_prompts_from_remote(prompt_path: str,
-                             prompt_delimiter: Optional[str] = None) -> List[str]:
+                             prompt_delimiter: Optional[str] = None
+                            ) -> List[str]:
     """Load a set of prompts from object storage.
 
     Args:
         prompt_path (str): Path for text file
         prompt_delimiter (Optional str): Delimiter for text file
             If not provided, assumes the prompt file is a single prompt (non-delimited)
 
     Returns:
         List of prompt string(s)
     """
     backend, _, _ = parse_uri(prompt_path)
     if backend in ['', None]:
         raise ValueError(
-            f'prompt_path_str must start with s3:// etc if using object storage')
+            f'prompt_path_str must start with s3:// etc if using object storage'
+        )
 
     local_path = prompt_path.split('/')[-1]
     get_file(path=prompt_path, destination=local_path, overwrite=True)
 
     with open(local_path, 'r') as f:
         prompt_string = f.read()
 
     if prompt_delimiter is None:
         return [prompt_string]
     return [i for i in prompt_string.split(prompt_delimiter) if i]
 
 
 def load_prompts_from_dataset(dataset_path: str,
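For context, a minimal usage sketch of the function this hunk reformats. The bucket, key, and delimiter below are hypothetical, and the sketch assumes llm-foundry and its composer dependency (the source of parse_uri and get_file) are installed, with S3 credentials configured:

    from llmfoundry.utils.prompt_files import load_prompts_from_remote

    # Downloads prompts.txt into the current working directory, reads it, and
    # splits on the delimiter; with prompt_delimiter=None the whole file is
    # returned as a single prompt.
    prompts = load_prompts_from_remote(
        's3://my-bucket/prompts.txt',  # hypothetical path
        prompt_delimiter='\n====\n')   # hypothetical delimiter
    print(f'Loaded {len(prompts)} prompt(s)')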
5 changes: 4 additions & 1 deletion scripts/inference/hf_generate.py
@@ -278,7 +278,10 @@ def _generate(encoded_inp: Dict[str, torch.Tensor]):
         print(f'\nTokenizing prompts...')
         maybe_synchronize()
         encode_start = time.time()
-        encoded_inp = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
+        encoded_inp = tokenizer(batch,
+                                return_tensors='pt',
+                                padding=True,
+                                truncation=True)
         for key, value in encoded_inp.items():
             encoded_inp[key] = value.to(model.device)
         maybe_synchronize()
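The reflowed call is the standard Hugging Face batched-tokenization pattern. A self-contained sketch of the same call, using 'gpt2' as an arbitrary stand-in for the user-supplied checkpoint that hf_generate.py actually loads:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('gpt2')  # stand-in checkpoint
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    batch = ['short prompt', 'a noticeably longer prompt that forces padding']
    encoded_inp = tokenizer(batch,
                            return_tensors='pt',
                            padding=True,     # pad to the longest prompt in the batch
                            truncation=True)  # clip anything past model_max_length
    print(encoded_inp['input_ids'].shape)  # torch.Size([2, <longest prompt length>])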
