Skip to content

Commit

Permalink
Merge pull request #38 from RexWzh/rex
Browse files Browse the repository at this point in the history
add async feature
  • Loading branch information
RexWzh authored Aug 23, 2023
2 parents 4983ebd + 0f3e190 commit 10bdcdb
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 51 deletions.
15 changes: 12 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,19 @@ on:

jobs:
build:
runs-on: ubuntu-latest
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: [3.8]
os: [ubuntu-latest, macos-latest] # test failed for windows(TODO)
include:
- python-version: 3.7
os: ubuntu-latest
- python-version: 3.9
os: ubuntu-latest
- python-version: '3.10'
os: ubuntu-latest

steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -39,4 +48,4 @@ jobs:
token: ${{ secrets.CODECOV_TOKEN }}
flags: unittests
name: codecov-umbrella
fail_ci_if_error: true
fail_ci_if_error: false
24 changes: 14 additions & 10 deletions openai_api_call/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,27 @@

__author__ = """Rex Wang"""
__email__ = '[email protected]'
__version__ = '0.6.0'
__version__ = '1.0.0'

import os, requests
from .chattool import Chat, Resp, chat_completion
from .chattool import Chat, Resp
from .checkpoint import load_chats, process_chats
from .proxy import proxy_on, proxy_off, proxy_status
from .async_process import async_chat_completion
from . import request


# read API key from the environment variable
if os.environ.get('OPENAI_API_KEY') is not None:
api_key = os.environ.get('OPENAI_API_KEY')
# skip checking the validity of the API key
# if not api_key.startswith("sk-"):
# print("Warning: The default environment variable `OPENAI_API_KEY` is not a valid API key.")
api_key = os.environ.get('OPENAI_API_KEY')

# Read base_url from the environment
if os.environ.get('OPENAI_BASE_URL') is not None:
base_url = os.environ.get("OPENAI_BASE_URL")
elif os.environ.get('OPENAI_API_BASE_URL') is not None:
# adapt to the environment variable of chatgpt-web
base_url = os.environ.get("OPENAI_API_BASE_URL")
else:
api_key = None
base_url = "https://api.openai.com"
base_url = request.normalize_url(base_url)

def show_apikey():
if api_key is not None:
Expand All @@ -39,7 +43,7 @@ def default_prompt(msg:str):

def show_base_url():
"""Show the base url of the API call"""
print(f"Base url:\t{request.base_url}")
print(f"Base url:\t{base_url}")

def debug_log( net_url:str="https://www.baidu.com"
, timeout:int=5
Expand Down
163 changes: 163 additions & 0 deletions openai_api_call/async_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import asyncio, aiohttp
import time, random, warnings, json, os
from typing import List, Dict, Union
from openai_api_call import Chat, Resp, load_chats
import openai_api_call

async def async_post( session
                    , sem
                    , url
                    , data:str
                    , max_requests:int=1
                    , timeinterval=0
                    , timeout=0):
    """Asynchronous POST request with retries.

    Args:
        session : aiohttp client session used to issue the request
        sem : asyncio.Semaphore bounding the number of concurrent requests
        url (str): chat completion url
        data (str): JSON payload of the request
        max_requests (int, optional): maximum number of attempts. Defaults to 1.
        timeinterval (int, optional): upper bound (seconds) of the random
            backoff between two attempts. Defaults to 0.
        timeout (int, optional): timeout for the API call.
            Defaults to 0 (no timeout).

    Returns:
        Union[str, None]: response text, or None if every attempt failed.
    """
    async with sem:
        ntries = 0
        while max_requests > 0:
            try:
                async with session.post(url, data=data, timeout=timeout) as response:
                    return await response.text()
            except Exception as e:
                max_requests -= 1
                ntries += 1
                # Must be asyncio.sleep: time.sleep would block the whole
                # event loop and stall every other in-flight coroutine.
                await asyncio.sleep(random.random() * timeinterval)
                print(f"Request Failed({ntries}):{e}")
        else:
            warnings.warn("Maximum number of requests reached!")
            return None

async def async_process_msgs( chatlogs:List[List[Dict]]
                            , chkpoint:str
                            , api_key:str
                            , chat_url:str
                            , max_requests:int=1
                            , ncoroutines:int=1
                            , timeout:int=0
                            , timeinterval:int=0
                            , **options
                            )->List[bool]:
    """Process messages asynchronously, checkpointing each finished chat.

    Args:
        chatlogs (List[List[Dict]]): list of chat logs
        chkpoint (str): checkpoint file; completed chats found here are skipped
        api_key (str): API key
        chat_url (str): chat completion url
        max_requests (int, optional): maximum number of attempts per chat. Defaults to 1.
        ncoroutines (int, optional): number of worker coroutines. Defaults to 1.
        timeout (int, optional): timeout for the API call. Defaults to 0 (no timeout).
        timeinterval (int, optional): time interval between two API calls. Defaults to 0.

    Returns:
        List[bool]: one success flag per chat that was actually processed.
    """
    # Resume support: load already-completed chats so their indices are skipped.
    chats = load_chats(chkpoint, withid=True) if os.path.exists(chkpoint) else []
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key
    }
    ncoroutines += 1 # reserve one semaphore slot for the main coroutine below
    sem = asyncio.Semaphore(ncoroutines)
    locker = asyncio.Lock()

    async def chat_complete(ind, locker, chatlog, chkpoint, **options):
        payload = {"messages": chatlog}
        payload.update(options)
        data = json.dumps(payload)
        response = await async_post( session=session
                                   , sem=sem
                                   , url=chat_url
                                   , data=data
                                   , max_requests=max_requests
                                   , timeinterval=timeinterval
                                   , timeout=timeout)
        # async_post returns None when every attempt failed; report this chat
        # as failed instead of crashing the whole gather on json.loads(None).
        if response is None:
            warnings.warn(f"Request failed for chat {ind}")
            return False
        resp = Resp(json.loads(response))
        if not resp.is_valid():
            warnings.warn(f"Invalid response: {resp.error_message}")
            return False
        ## saving files
        chatlog.append(resp.message)
        chat = Chat(chatlog)
        async with locker: # locker | not necessary for normal IO
            chat.savewithid(chkpoint, chatid=ind)
        return True

    async with sem, aiohttp.ClientSession(headers=headers) as session:
        tasks = []
        for ind, chatlog in enumerate(chatlogs):
            if ind < len(chats) and chats[ind] is not None: # skip completed chats
                continue
            tasks.append(
                asyncio.create_task(
                    chat_complete( ind=ind
                                 , locker=locker
                                 , chatlog=chatlog
                                 , chkpoint=chkpoint
                                 , **options)))
        responses = await asyncio.gather(*tasks)
    return responses

def async_chat_completion( chatlogs:List[List[Dict]]
                         , chkpoint:str
                         , model:str='gpt-3.5-turbo'
                         , api_key:Union[str, None]=None
                         , chat_url:Union[str, None]=None
                         , max_requests:int=1
                         , ncoroutines:int=1
                         , timeout:int=0
                         , timeinterval:int=0
                         , clearfile:bool=False
                         , **options
                         ):
    """Run chat completions for many conversations concurrently.

    Args:
        chatlogs (List[List[Dict]]): list of chat logs
        chkpoint (str): checkpoint file
        model (str, optional): model to use. Defaults to 'gpt-3.5-turbo'.
        api_key (Union[str, None], optional): API key; falls back to the
            package-level key when None. Defaults to None.
        chat_url (Union[str, None], optional): completion endpoint; derived
            from the package base url when None. Defaults to None.
        max_requests (int, optional): maximum number of attempts per chat. Defaults to 1.
        ncoroutines (int, optional): number of worker coroutines. Defaults to 1.
        timeout (int, optional): timeout for the API call. Defaults to 0 (no timeout).
        timeinterval (int, optional): time interval between two API calls. Defaults to 0.
        clearfile (bool, optional): whether to clear the checkpoint file. Defaults to False.

    Returns:
        List[bool]: one success flag per processed chat.
    """
    # Optionally start from a clean checkpoint.
    if clearfile and os.path.exists(chkpoint):
        os.remove(chkpoint)
    # Resolve credentials, falling back to the package-level default.
    if api_key is None:
        api_key = openai_api_call.api_key
    assert api_key is not None, "API key is not provided!"
    # Resolve and normalize the completion endpoint.
    if chat_url is None:
        chat_url = os.path.join(openai_api_call.base_url, "v1/chat/completions")
    chat_url = openai_api_call.request.normalize_url(chat_url)
    assert ncoroutines > 0, "ncoroutines must be greater than 0!"
    # Drive the async worker pool to completion and hand back its results.
    return asyncio.run(
        async_process_msgs( chatlogs=chatlogs
                          , chkpoint=chkpoint
                          , api_key=api_key
                          , chat_url=chat_url
                          , max_requests=max_requests
                          , ncoroutines=ncoroutines
                          , timeout=timeout
                          , timeinterval=timeinterval
                          , model=model
                          , **options))
16 changes: 3 additions & 13 deletions openai_api_call/chattool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,9 @@
import openai_api_call
from .response import Resp
from .request import chat_completion, valid_models
import signal, time, random
import time, random
import json

# timeout handler
def handler(signum, frame):
raise Exception("API call timed out!")

class Chat():
def __init__( self
, msg:Union[List[Dict], None, str]=None
Expand Down Expand Up @@ -97,15 +93,12 @@ def getresponse( self
# make request
resp = None
numoftries = 0
# Set the timeout handler
signal.signal(signal.SIGALRM, handler)
while max_requests:
try:
# Set the alarm to trigger after `timeout` seconds
signal.alarm(timeout)
# Make the API call
response = chat_completion(
api_key=api_key, messages=msg, model=model, chat_url=self.chat_url, **options)
api_key=api_key, messages=msg, model=model,
chat_url=self.chat_url, timeout=timeout, **options)
time.sleep(random.random() * timeinterval)
resp = Resp(response)
assert resp.is_valid(), "Invalid response with message: " + resp.error_message
Expand All @@ -114,9 +107,6 @@ def getresponse( self
max_requests -= 1
numoftries += 1
print(f"Try again ({numoftries}):{e}\n")
finally:
# Disable the alarm after execution
signal.alarm(0)
else:
raise Exception("Request failed! Try using `debug_log()` to find out the problem " +
"or increase the `max_requests`.")
Expand Down
30 changes: 12 additions & 18 deletions openai_api_call/request.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
# rewrite the request function

from typing import List, Dict, Union
import requests, json
import os
import requests, json, os
from urllib.parse import urlparse, urlunparse

# Read base_url from the environment
if os.environ.get('OPENAI_BASE_URL') is not None:
base_url = os.environ.get("OPENAI_BASE_URL")
elif os.environ.get('OPENAI_API_BASE_URL') is not None:
# adapt to the environment variable of chatgpt-web
base_url = os.environ.get("OPENAI_API_BASE_URL")
else:
base_url = "https://api.openai.com"
import openai_api_call

def is_valid_url(url: str) -> bool:
"""Check if the given URL is valid.
Expand Down Expand Up @@ -48,12 +39,11 @@ def normalize_url(url: str) -> str:
parsed_url = parsed_url._replace(scheme="https")
return urlunparse(parsed_url).replace("///", "//")

base_url = normalize_url(base_url) # normalize base_url

def chat_completion( api_key:str
, messages:List[Dict]
, model:str
, chat_url:Union[str, None]=None
, timeout:int = 0
, **options) -> Dict:
"""Chat completion API call
Expand Down Expand Up @@ -81,16 +71,21 @@ def chat_completion( api_key:str
}
# initialize chat url
if chat_url is None:
base_url = openai_api_call.base_url
chat_url = os.path.join(base_url, "v1/chat/completions")

chat_url = normalize_url(chat_url)
# get response
response = requests.post(chat_url, headers=headers, data=json.dumps(payload))
if timeout <= 0: timeout = None
response = requests.post(
chat_url, headers=headers,
data=json.dumps(payload), timeout=timeout)

if response.status_code != 200:
raise Exception(response.text)
return response.json()

def valid_models(api_key:str, gpt_only:bool=True, url:Union[str, None]=None):
def valid_models(api_key:str, gpt_only:bool=True, base_url:Union[str, None]=None):
"""Get valid models
Request url: https://api.openai.com/v1/models
Expand All @@ -106,12 +101,11 @@ def valid_models(api_key:str, gpt_only:bool=True, url:Union[str, None]=None):
"Authorization": "Bearer " + api_key,
"Content-Type": "application/json"
}
if url is None: url = base_url
models_url = normalize_url(os.path.join(url, "v1/models"))
if base_url is None: base_url = openai_api_call.base_url
models_url = normalize_url(os.path.join(base_url, "v1/models"))
models_response = requests.get(models_url, headers=headers)
if models_response.status_code == 200:
data = models_response.json()
# model_list = data.get("data")
model_list = [model.get("id") for model in data.get("data")]
return [model for model in model_list if "gpt" in model] if gpt_only else model_list
else:
Expand Down
5 changes: 2 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '0.6.0'

requirements = ['Click>=7.0', 'requests>=2.20', 'tqdm>=4.60', 'docstring_parser>=0.10']
VERSION = '1.0.0'

requirements = ['Click>=7.0', 'requests>=2.20', 'tqdm>=4.60', 'docstring_parser>=0.10', 'aiohttp>=3.8']
test_requirements = ['pytest>=3', 'unittest']

setup(
Expand Down
2 changes: 0 additions & 2 deletions test.py

This file was deleted.

Loading

0 comments on commit 10bdcdb

Please sign in to comment.