From 98a5cfe2da5a7d81072ab32adea5d7f7a74c8f40 Mon Sep 17 00:00:00 2001 From: Anna Pfohl Date: Tue, 28 Nov 2023 13:00:29 -0800 Subject: [PATCH] fix --- scripts/inference/endpoint_generate.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/scripts/inference/endpoint_generate.py b/scripts/inference/endpoint_generate.py index afacb74e20..79a66d38b0 100644 --- a/scripts/inference/endpoint_generate.py +++ b/scripts/inference/endpoint_generate.py @@ -7,7 +7,7 @@ import os import time from argparse import ArgumentParser, Namespace -from typing import cast +from typing import List, cast import pandas as pd import requests @@ -20,6 +20,8 @@ ENDPOINT_API_KEY_ENV: str = 'ENDPOINT_API_KEY' ENDPOINT_URL_ENV: str = 'ENDPOINT_URL' +PROMPT_DELIMITER = "\n" + def parse_args() -> Namespace: """Parse commandline arguments.""" @@ -78,7 +80,7 @@ def parse_args() -> Namespace: return parser.parse_args() -def load_prompt_string_from_file(prompt_path_str: str): +def load_prompts_from_file(prompt_path_str: str) -> List[str]: if not prompt_path_str.startswith('file::'): raise ValueError('prompt_path_str must start with "file::".') _, prompt_file_path = prompt_path_str.split('file::', maxsplit=1) @@ -87,8 +89,8 @@ def load_prompt_string_from_file(prompt_path_str: str): raise FileNotFoundError( f'{prompt_file_path=} does not match any existing files.') with open(prompt_file_path, 'r') as f: - prompt_string = ''.join(f.readlines()) - return prompt_string + prompt_string = f.read() + return prompt_string.split(PROMPT_DELIMITER) async def main(args: Namespace) -> None: @@ -148,8 +150,8 @@ async def generate(session: aiohttp.ClientSession, batch: int, response = await resp.json() except requests.JSONDecodeError: raise Exception( - f'Bad response: {resp.status_code} {resp.reason}' - ) # type: ignore + f'Bad response: {resp.status_code} {resp.reason}' # type: ignore + ) else: raise Exception( f'Bad response: {resp.status_code} {resp.content.decode().strip()}' # type: ignore