Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
aspfohl committed Nov 28, 2023
1 parent 4b6543e commit 98a5cfe
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions scripts/inference/endpoint_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
import time
from argparse import ArgumentParser, Namespace
from typing import cast
from typing import List, cast

import pandas as pd
import requests
Expand All @@ -20,6 +20,8 @@
ENDPOINT_API_KEY_ENV: str = 'ENDPOINT_API_KEY'
ENDPOINT_URL_ENV: str = 'ENDPOINT_URL'

PROMPT_DELIMITER = "\n"


def parse_args() -> Namespace:
"""Parse commandline arguments."""
Expand Down Expand Up @@ -78,7 +80,7 @@ def parse_args() -> Namespace:
return parser.parse_args()


def load_prompt_string_from_file(prompt_path_str: str):
def load_prompts_from_file(prompt_path_str: str) -> List[str]:
if not prompt_path_str.startswith('file::'):
raise ValueError('prompt_path_str must start with "file::".')
_, prompt_file_path = prompt_path_str.split('file::', maxsplit=1)
Expand All @@ -87,8 +89,8 @@ def load_prompt_string_from_file(prompt_path_str: str):
raise FileNotFoundError(
f'{prompt_file_path=} does not match any existing files.')
with open(prompt_file_path, 'r') as f:
prompt_string = ''.join(f.readlines())
return prompt_string
prompt_string = f.read()
return prompt_string.split(PROMPT_DELIMITER)


async def main(args: Namespace) -> None:
Expand Down Expand Up @@ -148,8 +150,8 @@ async def generate(session: aiohttp.ClientSession, batch: int,
response = await resp.json()
except requests.JSONDecodeError:
raise Exception(
f'Bad response: {resp.status_code} {resp.reason}'
) # type: ignore
f'Bad response: {resp.status_code} {resp.reason}' # type: ignore
)
else:
raise Exception(
f'Bad response: {resp.status_code} {resp.content.decode().strip()}' # type: ignore
Expand Down

0 comments on commit 98a5cfe

Please sign in to comment.