diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2c6e36e..9d2d263 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -10,10 +10,19 @@ on: jobs: build: - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} strategy: matrix: - python-version: [3.7, 3.8, 3.9] + python-version: [3.8] + os: [ubuntu-latest, macos-latest] # test failed for windows(TODO) + include: + - python-version: 3.7 + os: ubuntu-latest + - python-version: 3.9 + os: ubuntu-latest + - python-version: '3.10' + os: ubuntu-latest + steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -39,4 +48,4 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} flags: unittests name: codecov-umbrella - fail_ci_if_error: true \ No newline at end of file + fail_ci_if_error: false \ No newline at end of file diff --git a/openai_api_call/__init__.py b/openai_api_call/__init__.py index f2b8df2..f0d4d98 100644 --- a/openai_api_call/__init__.py +++ b/openai_api_call/__init__.py @@ -2,23 +2,27 @@ __author__ = """Rex Wang""" __email__ = '1073853456@qq.com' -__version__ = '0.6.0' +__version__ = '1.0.0' import os, requests -from .chattool import Chat, Resp, chat_completion +from .chattool import Chat, Resp from .checkpoint import load_chats, process_chats from .proxy import proxy_on, proxy_off, proxy_status +from .async_process import async_chat_completion from . import request - # read API key from the environment variable -if os.environ.get('OPENAI_API_KEY') is not None: - api_key = os.environ.get('OPENAI_API_KEY') - # skip checking the validity of the API key - # if not api_key.startswith("sk-"): - # print("Warning: The default environment variable `OPENAI_API_KEY` is not a valid API key.") +api_key = os.environ.get('OPENAI_API_KEY') + +# Read base_url from the environment +if os.environ.get('OPENAI_BASE_URL') is not None: + base_url = os.environ.get("OPENAI_BASE_URL") +elif os.environ.get('OPENAI_API_BASE_URL') is not None: + # adapt to the environment variable of chatgpt-web + base_url = os.environ.get("OPENAI_API_BASE_URL") else: - api_key = None + base_url = "https://api.openai.com" +base_url = request.normalize_url(base_url) def show_apikey(): if api_key is not None: @@ -39,7 +43,7 @@ def default_prompt(msg:str): def show_base_url(): """Show the base url of the API call""" - print(f"Base url:\t{request.base_url}") + print(f"Base url:\t{base_url}") def debug_log( net_url:str="https://www.baidu.com" , timeout:int=5 diff --git a/openai_api_call/async_process.py b/openai_api_call/async_process.py new file mode 100644 index 0000000..37d6bcf --- /dev/null +++ b/openai_api_call/async_process.py @@ -0,0 +1,163 @@ +import asyncio, aiohttp +import time, random, warnings, json, os +from typing import List, Dict, Union +from openai_api_call import Chat, Resp, load_chats +import openai_api_call + +async def async_post( session + , sem + , url + , data:str + , max_requests:int=1 + , timeinterval=0 + , timeout=0): + """Asynchronous post request + + Args: + session : aiohttp session + sem : semaphore + url (str): chat completion url + data (str): payload of the request + max_requests (int, optional): maximum number of requests to make. Defaults to 1. + timeinterval (int, optional): time interval between two API calls. Defaults to 0. + timeout (int, optional): timeout for the API call. Defaults to 0(no timeout). + + Returns: + str: response text + """ + async with sem: + ntries = 0 + while max_requests > 0: + try: + async with session.post(url, data=data, timeout=timeout) as response: + return await response.text() + except Exception as e: + max_requests -= 1 + ntries += 1 + time.sleep(random.random() * timeinterval) + print(f"Request Failed({ntries}):{e}") + else: + warnings.warn("Maximum number of requests reached!") + return None + +async def async_process_msgs( chatlogs:List[List[Dict]] + , chkpoint:str + , api_key:str + , chat_url:str + , max_requests:int=1 + , ncoroutines:int=1 + , timeout:int=0 + , timeinterval:int=0 + , **options + )->List[bool]: + """Process messages asynchronously + + Args: + chatlogs (List[List[Dict]]): list of chat logs + chkpoint (str): checkpoint file + api_key (Union[str, None], optional): API key. Defaults to None. + max_requests (int, optional): maximum number of requests to make. Defaults to 1. + ncoroutines (int, optional): number of coroutines. Defaults to 5. + timeout (int, optional): timeout for the API call. Defaults to 0(no timeout). + timeinterval (int, optional): time interval between two API calls. Defaults to 0. + + Returns: + List[bool]: list of responses + """ + # load from checkpoint + chats = load_chats(chkpoint, withid=True) if os.path.exists(chkpoint) else [] + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer " + api_key + } + ncoroutines += 1 # add one for the main coroutine + sem = asyncio.Semaphore(ncoroutines) + locker = asyncio.Lock() + + async def chat_complete(ind, locker, chatlog, chkpoint, **options): + payload = {"messages": chatlog} + payload.update(options) + data = json.dumps(payload) + response = await async_post( session=session + , sem=sem + , url=chat_url + , data=data + , max_requests=max_requests + , timeinterval=timeinterval + , timeout=timeout) + resp = Resp(json.loads(response)) + if not resp.is_valid(): + warnings.warn(f"Invalid response: {resp.error_message}") + return False + ## saving files + chatlog.append(resp.message) + chat = Chat(chatlog) + async with locker: # locker | not necessary for normal IO + chat.savewithid(chkpoint, chatid=ind) + return True + + async with sem, aiohttp.ClientSession(headers=headers) as session: + tasks = [] + for ind, chatlog in enumerate(chatlogs): + if ind < len(chats) and chats[ind] is not None: # skip completed chats + continue + tasks.append( + asyncio.create_task( + chat_complete( ind=ind + , locker=locker + , chatlog=chatlog + , chkpoint=chkpoint + , **options))) + responses = await asyncio.gather(*tasks) + return responses + +def async_chat_completion( chatlogs:List[List[Dict]] + , chkpoint:str + , model:str='gpt-3.5-turbo' + , api_key:Union[str, None]=None + , chat_url:Union[str, None]=None + , max_requests:int=1 + , ncoroutines:int=1 + , timeout:int=0 + , timeinterval:int=0 + , clearfile:bool=False + , **options + ): + """Asynchronous chat completion + + Args: + chatlogs (List[List[Dict]]): list of chat logs + chkpoint (str): checkpoint file + model (str, optional): model to use. Defaults to 'gpt-3.5-turbo'. + api_key (Union[str, None], optional): API key. Defaults to None. + max_requests (int, optional): maximum number of requests to make. Defaults to 1. + ncoroutines (int, optional): number of coroutines. Defaults to 5. + timeout (int, optional): timeout for the API call. Defaults to 0(no timeout). + timeinterval (int, optional): time interval between two API calls. Defaults to 0. + clearfile (bool, optional): whether to clear the checkpoint file. Defaults to False. + + Returns: + List[Dict]: list of responses + """ + if clearfile and os.path.exists(chkpoint): + os.remove(chkpoint) + if api_key is None: + api_key = openai_api_call.api_key + assert api_key is not None, "API key is not provided!" + if chat_url is None: + chat_url = os.path.join(openai_api_call.base_url, "v1/chat/completions") + chat_url = openai_api_call.request.normalize_url(chat_url) + # run async process + assert ncoroutines > 0, "ncoroutines must be greater than 0!" + responses = asyncio.run( + async_process_msgs( chatlogs=chatlogs + , chkpoint=chkpoint + , api_key=api_key + , chat_url=chat_url + , max_requests=max_requests + , ncoroutines=ncoroutines + , timeout=timeout + , timeinterval=timeinterval + , model=model + , **options)) + return responses \ No newline at end of file diff --git a/openai_api_call/chattool.py b/openai_api_call/chattool.py index a25c431..621e1c4 100644 --- a/openai_api_call/chattool.py +++ b/openai_api_call/chattool.py @@ -4,13 +4,9 @@ import openai_api_call from .response import Resp from .request import chat_completion, valid_models -import signal, time, random +import time, random import json -# timeout handler -def handler(signum, frame): - raise Exception("API call timed out!") - class Chat(): def __init__( self , msg:Union[List[Dict], None, str]=None @@ -97,15 +93,12 @@ def getresponse( self # make request resp = None numoftries = 0 - # Set the timeout handler - signal.signal(signal.SIGALRM, handler) while max_requests: try: - # Set the alarm to trigger after `timeout` seconds - signal.alarm(timeout) # Make the API call response = chat_completion( - api_key=api_key, messages=msg, model=model, chat_url=self.chat_url, **options) + api_key=api_key, messages=msg, model=model, + chat_url=self.chat_url, timeout=timeout, **options) time.sleep(random.random() * timeinterval) resp = Resp(response) assert resp.is_valid(), "Invalid response with message: " + resp.error_message @@ -114,9 +107,6 @@ def getresponse( self max_requests -= 1 numoftries += 1 print(f"Try again ({numoftries}):{e}\n") - finally: - # Disable the alarm after execution - signal.alarm(0) else: raise Exception("Request failed! Try using `debug_log()` to find out the problem " + "or increase the `max_requests`.") diff --git a/openai_api_call/request.py b/openai_api_call/request.py index 6481d7f..90d84b1 100644 --- a/openai_api_call/request.py +++ b/openai_api_call/request.py @@ -1,18 +1,9 @@ # rewrite the request function from typing import List, Dict, Union -import requests, json -import os +import requests, json, os from urllib.parse import urlparse, urlunparse - -# Read base_url from the environment -if os.environ.get('OPENAI_BASE_URL') is not None: - base_url = os.environ.get("OPENAI_BASE_URL") -elif os.environ.get('OPENAI_API_BASE_URL') is not None: - # adapt to the environment variable of chatgpt-web - base_url = os.environ.get("OPENAI_API_BASE_URL") -else: - base_url = "https://api.openai.com" +import openai_api_call def is_valid_url(url: str) -> bool: """Check if the given URL is valid. @@ -48,12 +39,11 @@ def normalize_url(url: str) -> str: parsed_url = parsed_url._replace(scheme="https") return urlunparse(parsed_url).replace("///", "//") -base_url = normalize_url(base_url) # normalize base_url - def chat_completion( api_key:str , messages:List[Dict] , model:str , chat_url:Union[str, None]=None + , timeout:int = 0 , **options) -> Dict: """Chat completion API call @@ -81,16 +71,21 @@ def chat_completion( api_key:str } # initialize chat url if chat_url is None: + base_url = openai_api_call.base_url chat_url = os.path.join(base_url, "v1/chat/completions") chat_url = normalize_url(chat_url) # get response - response = requests.post(chat_url, headers=headers, data=json.dumps(payload)) + if timeout <= 0: timeout = None + response = requests.post( + chat_url, headers=headers, + data=json.dumps(payload), timeout=timeout) + if response.status_code != 200: raise Exception(response.text) return response.json() -def valid_models(api_key:str, gpt_only:bool=True, url:Union[str, None]=None): +def valid_models(api_key:str, gpt_only:bool=True, base_url:Union[str, None]=None): """Get valid models Request url: https://api.openai.com/v1/models @@ -106,12 +101,11 @@ def valid_models(api_key:str, gpt_only:bool=True, url:Union[str, None]=None): "Authorization": "Bearer " + api_key, "Content-Type": "application/json" } - if url is None: url = base_url - models_url = normalize_url(os.path.join(url, "v1/models")) + if base_url is None: base_url = openai_api_call.base_url + models_url = normalize_url(os.path.join(base_url, "v1/models")) models_response = requests.get(models_url, headers=headers) if models_response.status_code == 200: data = models_response.json() - # model_list = data.get("data") model_list = [model.get("id") for model in data.get("data")] return [model for model in model_list if "gpt" in model] if gpt_only else model_list else: diff --git a/setup.py b/setup.py index d9d5a98..d07bb15 100644 --- a/setup.py +++ b/setup.py @@ -7,10 +7,9 @@ with open('README.md') as readme_file: readme = readme_file.read() -VERSION = '0.6.0' - -requirements = ['Click>=7.0', 'requests>=2.20', 'tqdm>=4.60', 'docstring_parser>=0.10'] +VERSION = '1.0.0' +requirements = ['Click>=7.0', 'requests>=2.20', 'tqdm>=4.60', 'docstring_parser>=0.10', 'aiohttp>=3.8'] test_requirements = ['pytest>=3', 'unittest'] setup( diff --git a/test.py b/test.py deleted file mode 100644 index c4c1a23..0000000 --- a/test.py +++ /dev/null @@ -1,2 +0,0 @@ -from openai_api_call import * -debug_log() \ No newline at end of file diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 0000000..ba93de4 --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,29 @@ +import openai_api_call, time +from openai_api_call import Chat, process_chats +from openai_api_call.async_process import async_chat_completion +openai_api_call.api_key="free-123" +openai_api_call.base_url = "https://api.wzhecnu.cn" + + +# langs = ["Python", "Julia", "C++", "C", "Java", "JavaScript", "C#", "Go", "R", "Ruby"] +langs = ["Python", "Julia", "C++"] +chatlogs = [ + [{"role": "user", "content": f"Print hello using {lang}"}] for lang in langs +] + +def test_async_process(): + chkpoint = "test_async.jsonl" + t = time.time() + resp = async_chat_completion(chatlogs, chkpoint, clearfile=True, ncoroutines=3) + assert all(resp) + print(f"Time elapsed: {time.time() - t:.2f}s") + +def test_normal_process(): + chkpoint = "test_nomal.jsonl" + def data2chat(data): + chat = Chat(data) + chat.getresponse() + return chat + t = time.time() + process_chats(chatlogs, data2chat, chkpoint, clearfile=True) + print(f"Time elapsed: {time.time() - t:.2f}s") diff --git a/tests/test_request.py b/tests/test_request.py index 3b1da5d..86dde6a 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -1,8 +1,6 @@ from openai_api_call import debug_log, Resp from openai_api_call.request import normalize_url, is_valid_url, valid_models import openai_api_call -openai_api_call.api_key="free-123" -openai_api_call.request.base_url = "api.wzhecnu.cn" api_key = openai_api_call.api_key def test_valid_models():