From 397338eead15dff83c561c63a535a2a35a066fd9 Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Tue, 5 Mar 2024 23:39:04 +0800 Subject: [PATCH 1/3] fix async response --- README.md | 4 +++- chattool/__init__.py | 4 ++-- chattool/chattool.py | 26 +++++++++++++++----------- setup.py | 2 +- tests/__init__.py | 4 ++-- 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index acb0a30..6729624 100644 --- a/README.md +++ b/README.md @@ -26,12 +26,14 @@ pip install chattool --upgrade ```bash export OPENAI_API_KEY="sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" -export OPENAI_API_BASEL="https://api.example.com/v1" +export OPENAI_API_BASE="https://api.example.com/v1" export OPENAI_API_BASE_URL="https://api.example.com" # 可选 ``` Win 在系统中设置环境变量。 +注:环境变量中,`OPENAI_API_BASE` 优先于 `OPENAI_API_BASE_URL`,二者选其一即可。 + ### 示例 示例1,模拟多轮对话: diff --git a/chattool/__init__.py b/chattool/__init__.py index c80b1ef..a302704 100644 --- a/chattool/__init__.py +++ b/chattool/__init__.py @@ -2,7 +2,7 @@ __author__ = """Rex Wang""" __email__ = '1073853456@qq.com' -__version__ = '3.1.0' +__version__ = '3.1.1' import os, sys, requests from .chattool import Chat, Resp @@ -27,7 +27,7 @@ def load_envs(env:Union[None, str, dict]=None): # else: load from environment variables api_key = os.getenv('OPENAI_API_KEY') base_url = os.getenv('OPENAI_API_BASE_URL') or "https://api.openai.com" - api_base = os.getenv('OPENAI_API_BASE', os.path.join(base_url, 'v1')) + api_base = os.getenv('OPENAI_API_BASE') or os.path.join(base_url, 'v1') base_url = request.normalize_url(base_url) api_base = request.normalize_url(api_base) model = os.getenv('OPENAI_API_MODEL', "gpt-3.5-turbo") diff --git a/chattool/chattool.py b/chattool/chattool.py index 963161c..06afc9a 100644 --- a/chattool/chattool.py +++ b/chattool/chattool.py @@ -251,20 +251,24 @@ async def async_stream_responses(self, timeout:int=0, textonly:bool=False): if not line: break # strip the prefix of `data: {...}` strline = line.decode().lstrip('data:').strip() + if strline == '[DONE]': break # skip empty line if not strline: continue # read the json string - line = json.loads(strline) - # wrap the response - resp = Resp(line) - # stop if the response is finished - if resp.finish_reason == 'stop': break - # deal with the message - if 'content' not in resp.delta: continue - if textonly: - yield resp.delta_content - else: - yield resp + try: + # wrap the response + resp = Resp(json.loads(strline)) + # stop if the response is finished + if resp.finish_reason == 'stop': break + # deal with the message + if 'content' not in resp.delta: continue + if textonly: + yield resp.delta_content + else: + yield resp + except Exception as e: + print(f"Error: {e}, line: {strline}") + break # Part3: function call def iswaiting(self): diff --git a/setup.py b/setup.py index a0286c9..256cfb8 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ with open('README.md') as readme_file: readme = readme_file.read() -VERSION = '3.1.0' +VERSION = '3.1.1' requirements = [ 'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8', diff --git a/tests/__init__.py b/tests/__init__.py index 977c100..14aa6e9 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +1,6 @@ """Unit test package for chattool.""" -from chattool import Chat +from chattool import Chat, debug_log import os if not os.path.exists('tests'): @@ -13,7 +13,7 @@ def test_simple(): chat = Chat() chat.user("Hello!") chat.getresponse() - chat.print_log() + debug_log() assert chat.chat_log[0] == {"role": "user", "content": "Hello!"} assert len(chat.chat_log) == 2 \ No newline at end of file From 4ef16e9662a4cd06f93a0af47008e0b9731158ea Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Fri, 8 Mar 2024 17:16:19 +0800 Subject: [PATCH 2/3] remove tests for finetune --- chattool/__init__.py | 3 ++- tests/__init__.py | 10 ---------- tests/{test_finetune.py => _test_finetune.py} | 0 tests/test_async.py | 11 ++++++++++- 4 files changed, 12 insertions(+), 12 deletions(-) rename tests/{test_finetune.py => _test_finetune.py} (100%) diff --git a/chattool/__init__.py b/chattool/__init__.py index a302704..6ff6a56 100644 --- a/chattool/__init__.py +++ b/chattool/__init__.py @@ -30,7 +30,7 @@ def load_envs(env:Union[None, str, dict]=None): api_base = os.getenv('OPENAI_API_BASE') or os.path.join(base_url, 'v1') base_url = request.normalize_url(base_url) api_base = request.normalize_url(api_base) - model = os.getenv('OPENAI_API_MODEL', "gpt-3.5-turbo") + model = os.getenv('OPENAI_API_MODEL') or "gpt-3.5-turbo" return True def save_envs(env_file:str): @@ -91,6 +91,7 @@ def debug_log( net_url:str="https://www.baidu.com" Returns: bool: True if the debug is finished. """ + print("Current version:", __version__) # Network test try: requests.get(net_url, timeout=timeout) diff --git a/tests/__init__.py b/tests/__init__.py index 14aa6e9..dca7600 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -7,13 +7,3 @@ os.mkdir('tests') if not os.path.exists('tests/testfiles'): os.mkdir('tests/testfiles') - -def test_simple(): - # set api_key in the environment variable - chat = Chat() - chat.user("Hello!") - chat.getresponse() - debug_log() - assert chat.chat_log[0] == {"role": "user", "content": "Hello!"} - assert len(chat.chat_log) == 2 - \ No newline at end of file diff --git a/tests/test_finetune.py b/tests/_test_finetune.py similarity index 100% rename from tests/test_finetune.py rename to tests/_test_finetune.py diff --git a/tests/test_async.py b/tests/test_async.py index 163ff84..f26f663 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -1,5 +1,5 @@ import chattool, time, os -from chattool import Chat, process_chats +from chattool import Chat, process_chats, debug_log from chattool.asynctool import async_chat_completion import asyncio, pytest @@ -10,6 +10,15 @@ ] testpath = 'tests/testfiles/' +def test_simple(): + # set api_key in the environment variable + chat = Chat() + chat.user("Hello!") + chat.getresponse() + debug_log() + assert chat.chat_log[0] == {"role": "user", "content": "Hello!"} + assert len(chat.chat_log) == 2 + def test_apikey(): assert chattool.api_key.startswith("sk-") From dad777dfc32b8de304cb4ee28d7d0630c47ab411 Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Fri, 8 Mar 2024 17:45:25 +0800 Subject: [PATCH 3/3] fix valid model options --- chattool/asynctool.py | 7 ++++++- chattool/chattool.py | 13 ++++++++++++- chattool/request.py | 11 +++++------ tests/test_async.py | 4 ++-- tests/test_request.py | 12 ++++++++---- 5 files changed, 33 insertions(+), 14 deletions(-) diff --git a/chattool/asynctool.py b/chattool/asynctool.py index 5e1cdf8..c1daece 100644 --- a/chattool/asynctool.py +++ b/chattool/asynctool.py @@ -177,7 +177,12 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str] api_key = chattool.api_key assert api_key is not None, "API key is not provided!" if chat_url is None: - chat_url = os.path.join(chattool.base_url, "v1/chat/completions") + if chattool.api_base: + chat_url = os.path.join(chattool.api_base, "chat/completions") + elif chattool.base_url: + chat_url = os.path.join(chattool.base_url, "v1/chat/completions") + else: + raise Exception("chat_url is not provided!") chat_url = chattool.request.normalize_url(chat_url) # run async process assert nproc > 0, "nproc must be greater than 0!" diff --git a/chattool/chattool.py b/chattool/chattool.py index 06afc9a..fafea84 100644 --- a/chattool/chattool.py +++ b/chattool/chattool.py @@ -357,7 +357,8 @@ def get_valid_models(self, gpt_only:bool=True)->List[str]: Returns: List[str]: valid models """ - return valid_models(self.api_key, self.base_url, gpt_only=gpt_only) + model_url = os.path.join(self.api_base, 'models') + return valid_models(self.api_key, model_url, gpt_only=gpt_only) # Part5: properties and setters @property @@ -399,6 +400,11 @@ def base_url(self): """Get base url""" return self._base_url + @property + def api_base(self): + """Get base url""" + return self._api_base + @property def functions(self): """Get functions""" @@ -428,6 +434,11 @@ def chat_url(self, chat_url:str): def base_url(self, base_url:str): """Set base url""" self._base_url = base_url + + @api_base.setter + def api_base(self, api_base:str): + """Set base url""" + self._api_base = api_base @functions.setter def functions(self, functions:List[Dict]): diff --git a/chattool/request.py b/chattool/request.py index 76f8e9f..a080f7b 100644 --- a/chattool/request.py +++ b/chattool/request.py @@ -81,7 +81,7 @@ def chat_completion( api_key:str raise Exception(response.text) return response.json() -def valid_models(api_key:str, base_url:str, gpt_only:bool=True): +def valid_models(api_key:str, model_url:str, gpt_only:bool=True): """Get valid models Request url: https://api.openai.com/v1/models @@ -97,14 +97,13 @@ def valid_models(api_key:str, base_url:str, gpt_only:bool=True): "Authorization": "Bearer " + api_key, "Content-Type": "application/json" } - models_url = normalize_url(os.path.join(base_url, "v1/models")) - models_response = requests.get(models_url, headers=headers) - if models_response.status_code == 200: - data = models_response.json() + model_response = requests.get(normalize_url(model_url), headers=headers) + if model_response.status_code == 200: + data = model_response.json() model_list = [model.get("id") for model in data.get("data")] return [model for model in model_list if "gpt" in model] if gpt_only else model_list else: - raise Exception(models_response.text) + raise Exception(model_response.text) def loadfile(api_key:str, base_url:str, file:str, purpose:str='fine-tune'): """Upload a file that can be used across various endpoints/features. diff --git a/tests/test_async.py b/tests/test_async.py index f26f663..6867736 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -12,10 +12,10 @@ def test_simple(): # set api_key in the environment variable + debug_log() chat = Chat() chat.user("Hello!") chat.getresponse() - debug_log() assert chat.chat_log[0] == {"role": "user", "content": "Hello!"} assert len(chat.chat_log) == 2 @@ -54,7 +54,7 @@ def test_async_process(): chkpoint = testpath + "test_async.jsonl" t = time.time() resp = async_chat_completion(chatlogs[:1], chkpoint, clearfile=True, nproc=3) - resp = async_chat_completion(chatlogs, chkpoint, clearfile=True, nproc=3) + resp = async_chat_completion(chatlogs, chkpoint, nproc=3) assert all(resp) print(f"Time elapsed: {time.time() - t:.2f}s") diff --git a/tests/test_request.py b/tests/test_request.py index 08a1c87..7a14bbb 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -5,14 +5,18 @@ create_finetune_job, list_finetune_job, retrievejob, listevents, canceljob, deletemodel ) -import pytest, chattool -api_key, base_url = chattool.api_key, chattool.base_url +import pytest, chattool, os +api_key, base_url, api_base = chattool.api_key, chattool.base_url, chattool.api_base testpath = 'tests/testfiles/' def test_valid_models(): - models = valid_models(api_key, base_url, gpt_only=False) + if chattool.api_base: + model_url = os.path.join(chattool.api_base, 'models') + else: + model_url = os.path.join(chattool.base_url, 'v1/models') + models = valid_models(api_key, model_url, gpt_only=False) assert len(models) >= 1 - models = valid_models(api_key, base_url, gpt_only=True) + models = valid_models(api_key, model_url, gpt_only=True) assert len(models) >= 1 assert 'gpt-3.5-turbo' in models