From c2ae387d8804f677f73c5beaae72b1e0905a3260 Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Wed, 27 Dec 2023 20:27:36 +0800 Subject: [PATCH 1/2] change params --- chattool/__init__.py | 2 +- chattool/asynctool.py | 50 +++++++++++++++++++----------------------- chattool/chattool.py | 11 ++++++---- tests/test_async.py | 12 +++++----- tests/test_function.py | 16 +++++++------- 5 files changed, 43 insertions(+), 48 deletions(-) diff --git a/chattool/__init__.py b/chattool/__init__.py index 9f5efb7..96fd3f1 100644 --- a/chattool/__init__.py +++ b/chattool/__init__.py @@ -129,7 +129,7 @@ def debug_log( net_url:str="https://www.baidu.com" if test_response: print("\nTest response:", message) chat = Chat(message) - chat.getresponse(max_requests=3) + chat.getresponse(max_tries=3) chat.print_log() print("\nDebug is finished.") diff --git a/chattool/asynctool.py b/chattool/asynctool.py index c474470..5e1cdf8 100644 --- a/chattool/asynctool.py +++ b/chattool/asynctool.py @@ -10,7 +10,7 @@ async def async_post( session , url , data:str , headers:Dict - , max_requests:int=1 + , max_tries:int=1 , timeinterval=0 , timeout=0): """Asynchronous post request @@ -21,7 +21,7 @@ async def async_post( session url (str): chat completion url data (str): payload of the request headers (Dict): request headers - max_requests (int, optional): maximum number of requests to make. Defaults to 1. + max_tries (int, optional): maximum number of requests to make. Defaults to 1. timeinterval (int, optional): time interval between two API calls. Defaults to 0. timeout (int, optional): timeout for the API call. Defaults to 0(no timeout). @@ -30,7 +30,7 @@ async def async_post( session """ async with sem: ntries = 0 - while max_requests > 0: + while max_tries > 0: try: async with session.post(url, headers=headers, data=data, timeout=timeout) as response: resp = await response.text() @@ -38,7 +38,7 @@ async def async_post( session assert resp.is_valid(), resp.error_message return resp except Exception as e: - max_requests -= 1 + max_tries -= 1 ntries += 1 time.sleep(random.random() * timeinterval) print(f"Request Failed({ntries}):{e}") @@ -50,11 +50,10 @@ async def async_process_msgs( chatlogs:List[List[Dict]] , chkpoint:str , api_key:str , chat_url:str - , max_requests:int=1 - , ncoroutines:int=1 + , max_tries:int=1 + , nproc:int=1 , timeout:int=0 , timeinterval:int=0 - , max_tokens:Union[Callable, None]=None , **options )->List[bool]: """Process messages asynchronously @@ -63,8 +62,8 @@ async def async_process_msgs( chatlogs:List[List[Dict]] chatlogs (List[List[Dict]]): list of chat logs chkpoint (str): checkpoint file api_key (Union[str, None], optional): API key. Defaults to None. - max_requests (int, optional): maximum number of requests to make. Defaults to 1. - ncoroutines (int, optional): number of coroutines. Defaults to 5. + max_tries (int, optional): maximum number of requests to make. Defaults to 1. + nproc (int, optional): number of coroutines. Defaults to 5. timeout (int, optional): timeout for the API call. Defaults to 0(no timeout). timeinterval (int, optional): time interval between two API calls. Defaults to 0. @@ -72,29 +71,27 @@ async def async_process_msgs( chatlogs:List[List[Dict]] List[bool]: list of responses """ # load from checkpoint - chats = load_chats(chkpoint, withid=True) if os.path.exists(chkpoint) else [] + chats = load_chats(chkpoint) if os.path.exists(chkpoint) else [] chats.extend([None] * (len(chatlogs) - len(chats))) costs = [0] * len(chatlogs) headers = { "Content-Type": "application/json", "Authorization": "Bearer " + api_key } - ncoroutines += 1 # add one for the main coroutine - sem = asyncio.Semaphore(ncoroutines) + nproc += 1 # add one for the main coroutine + sem = asyncio.Semaphore(nproc) locker = asyncio.Lock() async def chat_complete(ind, locker, chat_log, chkpoint, **options): payload = {"messages": chat_log} payload.update(options) - if max_tokens is not None: - payload['max_tokens'] = max_tokens(chat_log) data = json.dumps(payload) resp = await async_post( session=session , sem=sem , url=chat_url , data=data , headers=headers - , max_requests=max_requests + , max_tries=max_tries , timeinterval=timeinterval , timeout=timeout) ## saving files @@ -130,8 +127,7 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str] , model:str='gpt-3.5-turbo' , api_key:Union[str, None]=None , chat_url:Union[str, None]=None - , max_requests:int=1 - , ncoroutines:int=1 + , max_tries:int=1 , nproc:int=1 , timeout:int=0 , timeinterval:int=0 @@ -139,7 +135,8 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str] , notrun:bool=False , msg2log:Union[Callable, None]=None , data2chat:Union[Callable, None]=None - , max_tokens:Union[Callable, int, None]=None + , max_requests:int=-1 + , ncoroutines:int=1 , **options ): """Asynchronous chat completion @@ -149,8 +146,7 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str] chkpoint (str): checkpoint file model (str, optional): model to use. Defaults to 'gpt-3.5-turbo'. api_key (Union[str, None], optional): API key. Defaults to None. - max_requests (int, optional): maximum number of requests to make. Defaults to 1. - ncoroutines (int, optional): (Deprecated)number of coroutines. Defaults to 1. + max_tries (int, optional): maximum number of requests to make. Defaults to 1. nproc (int, optional): number of coroutines. Defaults to 1. timeout (int, optional): timeout for the API call. Defaults to 0(no timeout). timeinterval (int, optional): time interval between two API calls. Defaults to 0. @@ -161,8 +157,8 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str] Defaults to None. data2chat (Union[Callable, None], optional): function to convert data to Chat object. Defaults to None. - max_tokens (Union[Callable, int, None], optional): function to calculate the maximum - number of tokens for the API call. Defaults to None. + max_requests (int, optional): (Deprecated)maximum number of requests to make. Defaults to -1. + ncoroutines (int, optional): (Deprecated)number of coroutines. Defaults to 1. Returns: List[Dict]: list of responses @@ -184,20 +180,18 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str] chat_url = os.path.join(chattool.base_url, "v1/chat/completions") chat_url = chattool.request.normalize_url(chat_url) # run async process - assert ncoroutines > 0, "ncoroutines must be greater than 0!" - if isinstance(max_tokens, int): - max_tokens = lambda chat_log: max_tokens + assert nproc > 0, "nproc must be greater than 0!" + max_tries = max(max_tries, max_requests) args = { "chatlogs": chatlogs, "chkpoint": chkpoint, "api_key": api_key, "chat_url": chat_url, - "max_requests": max_requests, - "ncoroutines": nproc, + "max_tries": max_tries, + "nproc": nproc, "timeout": timeout, "timeinterval": timeinterval, "model": model, - "max_tokens": max_tokens, **options } if notrun: # when use in Jupyter Notebook diff --git a/chattool/chattool.py b/chattool/chattool.py index faeef3f..84cbc33 100644 --- a/chattool/chattool.py +++ b/chattool/chattool.py @@ -170,20 +170,22 @@ def print_log(self, sep: Union[str, None]=None): # Part2: response and async response def getresponse( self - , max_requests:int=1 + , max_tries:int = 1 , timeout:int = 0 , timeinterval:int = 0 , update:bool = True , stream:bool = False + , max_requests:int=-1 , **options)->Resp: """Get the API response Args: - max_requests (int, optional): maximum number of requests to make. Defaults to 1. + max_tries (int, optional): maximum number of requests to make. Defaults to 1. timeout (int, optional): timeout for the API call. Defaults to 0(no timeout). timeinterval (int, optional): time interval between two API calls. Defaults to 0. update (bool, optional): whether to update the chat log. Defaults to True. options (dict, optional): other options like `temperature`, `top_p`, etc. + max_requests (int, optional): (deprecated) maximum number of requests to make. Defaults to -1(no limit Returns: Resp: API response @@ -194,10 +196,11 @@ def getresponse( self func_call = options.get('function_call', self.function_call) if api_key is None: warnings.warn("API key is not set!") msg, resp, numoftries = self.chat_log, None, 0 + max_tries = max(max_tries, max_requests) if stream: # TODO: add the `usage` key to the response warnings.warn("stream mode is not supported yet! Use `async_stream_responses()` instead.") # make requests - while max_requests: + while max_tries: try: # make API Call if funcs is not None: options['functions'] = funcs @@ -209,7 +212,7 @@ def getresponse( self assert resp.is_valid(), resp.error_message break except Exception as e: - max_requests -= 1 + max_tries-= 1 numoftries += 1 time.sleep(random.random() * timeinterval) print(f"Try again ({numoftries}):{e}\n") diff --git a/tests/test_async.py b/tests/test_async.py index b820bb1..480d458 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -44,8 +44,8 @@ async def show_resp(chat): def test_async_process(): chkpoint = testpath + "test_async.jsonl" t = time.time() - resp = async_chat_completion(chatlogs[:1], chkpoint, clearfile=True, ncoroutines=3) - resp = async_chat_completion(chatlogs, chkpoint, clearfile=True, ncoroutines=3) + resp = async_chat_completion(chatlogs[:1], chkpoint, clearfile=True, nproc=3) + resp = async_chat_completion(chatlogs, chkpoint, clearfile=True, nproc=3) assert all(resp) print(f"Time elapsed: {time.time() - t:.2f}s") @@ -55,7 +55,7 @@ def test_failed_async(): chattool.api_key = "sk-invalid" chkpoint = testpath + "test_async_fail.jsonl" words = ["hello", "Can you help me?", "Do not translate this word", "I need help with my homework"] - resp = async_chat_completion(words, chkpoint, clearfile=True, ncoroutines=3) + resp = async_chat_completion(words, chkpoint, clearfile=True, nproc=3) chattool.api_key = api_key def test_async_process_withfunc(): @@ -66,15 +66,13 @@ def msg2log(msg): chat.system("translate the words from English to Chinese") chat.user(msg) return chat.chat_log - def max_tokens(chat_log): - return Chat(chat_log).prompt_token() - async_chat_completion(words, chkpoint, clearfile=True, ncoroutines=3, max_tokens=max_tokens, msg2log=msg2log) + async_chat_completion(words, chkpoint, clearfile=True, nproc=3, msg2log=msg2log) def test_normal_process(): chkpoint = testpath + "test_nomal.jsonl" def data2chat(data): chat = Chat(data) - chat.getresponse(max_requests=3) + chat.getresponse(max_tries=3) return chat t = time.time() process_chats(chatlogs, data2chat, chkpoint, clearfile=True) diff --git a/tests/test_function.py b/tests/test_function.py index 462cca8..a4513b5 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -33,7 +33,7 @@ def test_call_weather(): chat = Chat("What's the weather like in Boston?") - resp = chat.getresponse(functions=functions, function_call='auto', max_requests=3) + resp = chat.getresponse(functions=functions, function_call='auto', max_tries=3) # TODO: wrap the response if resp.finish_reason == 'function_call': # test response from chat api @@ -54,12 +54,12 @@ def test_auto_response(): chat = Chat("What's the weather like in Boston?") chat.functions, chat.function_call = functions, 'auto' chat.name2func = name2func - chat.autoresponse(max_requests=2) + chat.autoresponse(max_tries=2) chat.print_log() chat.clear() # response with nonempty content chat.user("what is the result of 1+1, and What's the weather like in Boston?") - chat.autoresponse(max_requests=2) + chat.autoresponse(max_tries=2) # generate docstring from functions def add(a: int, b: int) -> int: @@ -100,20 +100,20 @@ def test_add_and_mult(): chat.name2func = {'add': add} # dictionary of functions chat.function_call = 'auto' # auto decision # run until success: maxturns=-1 - chat.autoresponse(max_requests=3, display=True, timeinterval=2) + chat.autoresponse(max_tries=3, display=True, timeinterval=2) # response should be finished chat.simplify() chat.print_log() # use the setfuncs method chat = Chat("find the value of 124842 * 3423424") chat.setfuncs([add, mult]) # multi choice - chat.autoresponse(max_requests=3, timeinterval=2) + chat.autoresponse(max_tries=3, timeinterval=2) chat.simplify() # simplify the chat log chat.print_log() # test multichoice chat.clear() chat.user("find the value of 23723 + 12312, and 23723 * 12312") - chat.autoresponse(max_requests=3, timeinterval=2) + chat.autoresponse(max_tries=3, timeinterval=2) def test_mock_resp(): chat = Chat("find the sum of 1235 and 3423") @@ -122,12 +122,12 @@ def test_mock_resp(): para = {'name': 'add', 'arguments': '{\n "a": 1235,\n "b": 3423\n}'} chat.assistant(content=None, function_call=para) chat.callfunction() - chat.getresponse(max_requests=2) + chat.getresponse(max_tries=2) def test_use_exec_function(): chat = Chat("find the result of sqrt(121314)") chat.setfuncs([exec_python_code]) - chat.autoresponse(max_requests=2) + chat.autoresponse(max_tries=2) def test_find_permutation_group(): pass From c4a2656841be0fb963eb2dada382993011be5b29 Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Wed, 27 Dec 2023 20:29:53 +0800 Subject: [PATCH 2/2] update version --- chattool/__init__.py | 2 +- chattool/chattool.py | 2 +- setup.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/chattool/__init__.py b/chattool/__init__.py index 96fd3f1..0bb2d8d 100644 --- a/chattool/__init__.py +++ b/chattool/__init__.py @@ -2,7 +2,7 @@ __author__ = """Rex Wang""" __email__ = '1073853456@qq.com' -__version__ = '3.0.0' +__version__ = '3.0.1' import os, sys, requests from .chattool import Chat, Resp diff --git a/chattool/chattool.py b/chattool/chattool.py index 84cbc33..5667fae 100644 --- a/chattool/chattool.py +++ b/chattool/chattool.py @@ -212,7 +212,7 @@ def getresponse( self assert resp.is_valid(), resp.error_message break except Exception as e: - max_tries-= 1 + max_tries -= 1 numoftries += 1 time.sleep(random.random() * timeinterval) print(f"Try again ({numoftries}):{e}\n") diff --git a/setup.py b/setup.py index f9681d2..2e95bcd 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ with open('README.md') as readme_file: readme = readme_file.read() -VERSION = '3.0.0' +VERSION = '3.0.1' requirements = [ 'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',