diff --git a/README.md b/README.md index b4b8cf7..dfe1c3b 100644 --- a/README.md +++ b/README.md @@ -131,7 +131,7 @@ def add(a: int, b: int) -> int: return a + b # 传输函数 chat = Chat() -chat.setfuncs([add]) # 传入函数列表,可以是多个函数 +chat.settools([add]) # 传入函数列表,可以是多个函数 chat.user("请计算 1 + 2") # 自动调用工具 chat.autoresponse(display=True) @@ -139,10 +139,11 @@ chat.autoresponse(display=True) ## 开源协议 -这个项目使用 MIT 协议开源。 +使用 MIT 协议开源。 ## 更新日志 +- 当前版本 `3.2.1`,简化异步处理和串行处理的接口,更新子模块名称,避免冲突 - 版本 `2.3.0`,支持调用外部工具,异步处理数据,以及模型微调功能 - 版本 `2.0.0` 开始,更名为 `chattool` - 版本 `1.0.0` 开始,支持异步处理数据 diff --git a/chattool/__init__.py b/chattool/__init__.py index 17a27f9..e72e45e 100644 --- a/chattool/__init__.py +++ b/chattool/__init__.py @@ -2,7 +2,7 @@ __author__ = """Rex Wang""" __email__ = '1073853456@qq.com' -__version__ = '3.2.0' +__version__ = '3.2.1' import os, sys, requests from .chattype import Chat, Resp @@ -13,7 +13,8 @@ from .asynctool import async_chat_completion from .functioncall import generate_json_schema, exec_python_code from typing import Union -import dotenv +import dotenv +import loguru raw_env_text = f"""# Description: Env file for ChatTool. # Current version: {__version__} @@ -33,11 +34,24 @@ """ def load_envs(env:Union[None, str, dict]=None): - """Read the environment variables for the API call""" + """Load the environment variables for the API call + + Args: + env (Union[None, str, dict], optional): The environment file or the environment variables. Defaults to None. + + Returns: + bool: True if the environment variables are loaded successfully. + + Example: + load_envs("envfile.env") + load_envs({"OPENAI_API_KEY":"your_api_key"}) + load_envs() # load from the environment variables + """ global api_key, base_url, api_base, model # update the environment variables - if isinstance(env, str): - dotenv.load_dotenv(env, override=True) + if isinstance(env, str) and not dotenv.load_dotenv(env, override=True): + loguru.logger.warning(f"Failed to load the environment file: {env}") + return False elif isinstance(env, dict): for key, value in env.items(): os.environ[key] = value @@ -157,8 +171,8 @@ def debug_log( net_url:str="https://www.baidu.com" # Get model list if test_model: - print("\nThe model list(contains gpt):") - print(Chat().get_valid_models()) + print("\nThe model list:") + print(Chat().get_valid_models(gpt_only=False)) # Test hello world if test_response: diff --git a/chattool/chattype.py b/chattool/chattype.py index 3d5200c..89dc9b8 100644 --- a/chattool/chattype.py +++ b/chattool/chattype.py @@ -9,6 +9,7 @@ import os from .functioncall import generate_json_schema, delete_dialogue_assist from pprint import pformat +from loguru import logger class Chat(): def __init__( self @@ -113,11 +114,11 @@ def clear(self): def copy(self): """Copy the chat log""" - return Chat(self._chat_log) + return Chat(self.chat_log) def deepcopy(self): """Deep copy the Chat object""" - return Chat( self._chat_log + return Chat( self.chat_log , api_key=self.api_key , chat_url=self.chat_url , model=self.model @@ -210,6 +211,7 @@ def getresponse( self , tools:Union[None, List[Dict]]=None , tool_choice:Union[None, str]=None , max_requests:int=-1 + , tool_type:str='tool_choice' , functions:Union[None, List[Dict]]=None , function_call:Union[None, str]=None , **options)->Resp: @@ -221,7 +223,8 @@ def getresponse( self timeinterval (int, optional): time interval between two API calls. Defaults to 0. update (bool, optional): whether to update the chat log. Defaults to True. options (dict, optional): other options like `temperature`, `top_p`, etc. - max_requests (int, optional): (deprecated) maximum number of requests to make. Defaults to -1(no limit + max_requests (int, optional): (deprecated) maximum number of requests to make. Defaults to -1(no limit) + tool_type (str, optional): type of the tool. Defaults to 'tool_choice'. Returns: Resp: API response @@ -232,10 +235,14 @@ def getresponse( self # function call & tool call tool_choice, tools = tool_choice or self.tool_choice, tools or self.tools function_call, functions = function_call or self.function_call, functions or self.functions - if tool_choice is not None: - options['tool_choice'], options['tools'] = tool_choice, tools - elif function_call is not None: - options['function_call'], options['functions'] = function_call, functions + if tool_type == 'function_call': + if function_call is not None: + options['function_call'], options['functions'] = function_call, functions + else: + if tool_choice is not None: + if tool_type != 'tool_choice': + logger.warning(f"Unknown tool type {tool_type}, use 'tool_choice' by default.") + options['tool_choice'], options['tools'] = tool_choice, tools # if api_key is None: warnings.warn("API key is not set!") msg, resp, numoftries = self.chat_log, None, 0 max_tries = max(max_tries, max_requests) @@ -312,8 +319,8 @@ def setfuncs(self, funcs:List): def settools(self, tools:List): """Initialize tools for tool calls""" - self._functions =[generate_json_schema(func) for func in tools] - self.tool_choice = 'auto' + self.functions =[generate_json_schema(func) for func in tools] + self.tool_choice = 'auto' # the only difference from setfuncs self.name2func = {tool.__name__:tool for tool in tools} return True @@ -368,7 +375,7 @@ def callfunction(self): def autoresponse( self , display:bool=False , maxturns:int=3 - , use_tool:bool=True + , tool_type:str='tool_choice' , **options): """Get the response automatically @@ -380,22 +387,22 @@ def autoresponse( self Returns: bool: whether the response is finished """ - if use_tool: - options['tools'], options['tool_choice'] = self.tools, self.tool_choice or 'auto' - else: + if tool_type == 'function_call': options['functions'], options['function_call'] = self.functions, self.function_call or 'auto' + else: + options['tools'], options['tool_choice'] = self.tools, self.tool_choice or 'auto' show = lambda msg: print(self.display_role_content(msg)) - resp = self.getresponse(**options) + resp = self.getresponse(tool_type=tool_type, **options) if display: show(resp.message) while self.iswaiting() and maxturns != 0: # call api and update the result - if use_tool: + if tool_type != 'function_call': self.calltools(display=display) else: result, name, _ = self.callfunction() self.function(result, name) if display: show(self[-1]) - resp = self.getresponse(**options) + resp = self.getresponse(tool_type=tool_type, **options) if display: show(resp.message) maxturns -= 1 return True @@ -553,7 +560,6 @@ def name2func(self, name2func:Dict): assert isinstance(name2func, dict), "name2func should be a dict" self._name2func = name2func - @property def chat_log(self): """Chat history""" diff --git a/setup.py b/setup.py index 90e6b14..170c0b3 100644 --- a/setup.py +++ b/setup.py @@ -7,11 +7,12 @@ with open('README.md') as readme_file: readme = readme_file.read() -VERSION = '3.2.0' +VERSION = '3.2.1' requirements = [ 'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8', - 'tqdm>=4.60', 'docstring_parser>=0.10', "python-dotenv>=0.17.0"] + 'tqdm>=4.60', 'docstring_parser>=0.10', "python-dotenv>=0.17.0", + 'loguru>=0.7'] test_requirements = ['pytest>=3', 'unittest'] setup( diff --git a/tests/test_tools.py b/tests/test_tools.py index caf76fa..7efac37 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -92,6 +92,21 @@ def mult(a:int, b:int) -> int: """ return a * b +def test_func_and_tool(): + chat = Chat("find the value of 124842 * 3423424 + 121312") + chat.settools([add, mult]) # multi choice + chat1 = chat.deepcopy() + chat1.autoresponse(tool_type='tool_choice') + chat2 = chat.deepcopy() + chat2.autoresponse(tool_type='function_call') + # setfuncs + chat.setfuncs([add, mult]) + chat1 = chat.deepcopy() + chat1.autoresponse(tool_type='tool_choice') + chat2 = chat.deepcopy() + chat2.autoresponse(tool_type='function_call') + + def test_add_and_mult(): tools = [{ 'type':'function', @@ -127,7 +142,6 @@ def test_add_and_mult(): chat4.user("find the value of (23723 * 1322312 ) + 12312") chat4.autoresponse(max_tries=3, display=True, timeinterval=2) - def test_use_exec_function(): chat = Chat("find the result of sqrt(121314)") chat.settools([exec_python_code])