From 735550f88af7cc2e1d86f1eea24ba6e685a6d53e Mon Sep 17 00:00:00 2001 From: rex <1073853456@qq.com> Date: Sun, 9 Jun 2024 00:24:02 +0800 Subject: [PATCH] simplify tool calls --- chattool/__init__.py | 2 +- chattool/chattype.py | 61 ++++++++++++++++++------------------------ setup.py | 2 +- tests/test_function.py | 15 +++++++++-- tests/test_tools.py | 6 ++--- 5 files changed, 44 insertions(+), 42 deletions(-) diff --git a/chattool/__init__.py b/chattool/__init__.py index a543156..17a27f9 100644 --- a/chattool/__init__.py +++ b/chattool/__init__.py @@ -2,7 +2,7 @@ __author__ = """Rex Wang""" __email__ = '1073853456@qq.com' -__version__ = '3.1.7' +__version__ = '3.2.0' import os, sys, requests from .chattype import Chat, Resp diff --git a/chattool/chattype.py b/chattool/chattype.py index fdcde0a..3d5200c 100644 --- a/chattool/chattype.py +++ b/chattool/chattype.py @@ -20,7 +20,6 @@ def __init__( self , model:Union[None, str]=None , tools:Union[None, List[Dict]]=None , tool_choice:Union[None, str]=None - , name2tool:Union[None, Dict]=None , functions:Union[None, List[Dict]]=None , function_call:Union[None, str]=None , name2func:Union[None, Dict]=None): @@ -35,10 +34,10 @@ def __init__( self model (Union[None, str], optional): model to use. Defaults to None. tools (Union[None, List[Dict]], optional): tools to use, each tool is a JSON Schema. Defaults to None. tool_choice (Union[None, str], optional): method to choose the tool. Defaults to None. Choices: ['auto', '$NameOfTheTool', 'none'] - name2tool (Union[None, Dict], optional): name to tool mapping. Defaults to None. + name2func (Union[None, Dict], optional): name to function mapping. Defaults to None. functions (Union[None, List[Dict]], optional): Decrpcated. functions to use, each function is a JSON Schema. Defaults to None. function_call (str, optional): Decrpcated. method to call the function. Defaults to None. Choices: ['auto', '$NameOfTheFunction', 'none'] - name2func (Union[None, Dict], optional): Decrpcated. name to function mapping. Defaults to None. + Raises: ValueError: msg should be a list of dict, a string or None @@ -74,9 +73,9 @@ def __init__( self assert isinstance(functions, list), "functions should be a list of dict" if tools is not None: assert isinstance(tools, list), "tools should be a list of dict" - self._functions, self._function_call = functions, function_call - self._tools, self._tool_choice = tools, tool_choice - self._name2func, self._resp, self._name2tool = name2func, None, name2tool + self.functions, self.tools = functions or [], tools or [] + self._function_call, self._tool_choice = function_call, tool_choice + self._name2func, self._resp = name2func, None # Part1: basic operation of the chat object def add(self, role:str, **kwargs): @@ -127,7 +126,6 @@ def deepcopy(self): , tools=self.tools , tool_choice=self.tool_choice , name2func=self.name2func - , name2tool=self.name2tool , api_base=self.api_base , base_url=self.base_url) @@ -209,7 +207,11 @@ def getresponse( self , timeinterval:int = 0 , update:bool = True , stream:bool = False + , tools:Union[None, List[Dict]]=None + , tool_choice:Union[None, str]=None , max_requests:int=-1 + , functions:Union[None, List[Dict]]=None + , function_call:Union[None, str]=None , **options)->Resp: """Get the API response @@ -227,10 +229,13 @@ def getresponse( self # initialize data api_key, chat_url = self.api_key, self.chat_url if 'model' not in options: options['model'] = self.model - funcs = options.get('functions', self.functions) - func_call = options.get('function_call', self.function_call) - tools = options.get('tools', self._tools) - tool_choice = options.get('tool_choice', self._tool_choice) + # function call & tool call + tool_choice, tools = tool_choice or self.tool_choice, tools or self.tools + function_call, functions = function_call or self.function_call, functions or self.functions + if tool_choice is not None: + options['tool_choice'], options['tools'] = tool_choice, tools + elif function_call is not None: + options['function_call'], options['functions'] = function_call, functions # if api_key is None: warnings.warn("API key is not set!") msg, resp, numoftries = self.chat_log, None, 0 max_tries = max(max_tries, max_requests) @@ -240,10 +245,6 @@ def getresponse( self while max_tries: try: # make API Call - if funcs is not None: options['functions'] = funcs - if func_call is not None: options['function_call'] = func_call - if tools is not None: options['tools'] = tools - if tool_choice is not None: options['tool_choice'] = tool_choice response = chat_completion( api_key=api_key, messages=msg, chat_url=chat_url, timeout=timeout, **options) @@ -311,11 +312,9 @@ def setfuncs(self, funcs:List): def settools(self, tools:List): """Initialize tools for tool calls""" - self.tools = [{ - 'type':'function', - 'function': generate_json_schema(tool)} for tool in tools] + self._functions =[generate_json_schema(func) for func in tools] self.tool_choice = 'auto' - self.name2tool = {tool.__name__:tool for tool in tools} + self.name2func = {tool.__name__:tool for tool in tools} return True def calltools(self, display:bool=False): @@ -335,14 +334,14 @@ def calltool(self, tool): """Call the tool""" tool_call_id = tool['id'] tool_name, tool_para = tool['function']['name'], tool['function']['arguments'] - if tool_name not in self.name2tool: + if tool_name not in self.name2func: return f"Tool {tool_name} not found.", tool_name, tool_call_id, False try: tool_args = json.loads(tool_para) except Exception as e: return f"Argument parsing failed with error: {e}", tool_name, tool_call_id, False try: - result = self.name2tool[tool_name](**tool_args) + result = self.name2func[tool_name](**tool_args) except Exception as e: return f"Tool {tool_name} failed with error: {e}", tool_name, tool_call_id, False # succeed finally! @@ -382,9 +381,9 @@ def autoresponse( self bool: whether the response is finished """ if use_tool: - options['tools'], options['tool_choice'] = self.tools, self.tool_choice + options['tools'], options['tool_choice'] = self.tools, self.tool_choice or 'auto' else: - options['functions'], options['function_call'] = self.functions, self.function_call + options['functions'], options['function_call'] = self.functions, self.function_call or 'auto' show = lambda msg: print(self.display_role_content(msg)) resp = self.getresponse(**options) if display: show(resp.message) @@ -476,7 +475,9 @@ def function_call(self): @property def tools(self): """Get tools""" - return self._tools + if self.functions is not None: + return [{'type':'function', 'function': func} for func in self.functions] + return None @property def tool_choice(self): @@ -487,11 +488,6 @@ def tool_choice(self): def name2func(self): """Get name to function mapping""" return self._name2func - - @property - def name2tool(self): - """Get name to tool mapping""" - return self._name2tool @api_key.setter def api_key(self, api_key:str): @@ -538,7 +534,7 @@ def function_call(self, function_call:str): def tools(self, tools:List[Dict]): """Set tools""" assert isinstance(tools, list), "tools should be a list of dict" - self._tools = tools + self._functions = [tool['function'] for tool in tools] @tool_choice.setter def tool_choice(self, tool_choice:str): @@ -557,11 +553,6 @@ def name2func(self, name2func:Dict): assert isinstance(name2func, dict), "name2func should be a dict" self._name2func = name2func - @name2tool.setter - def name2tool(self, name2tool:Dict): - """Set name to tool mapping""" - assert isinstance(name2tool, dict), "name2tool should be a dict" - self._name2tool = name2tool @property def chat_log(self): diff --git a/setup.py b/setup.py index 3bdfc3c..90e6b14 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ with open('README.md') as readme_file: readme = readme_file.read() -VERSION = '3.1.7' +VERSION = '3.2.0' requirements = [ 'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8', diff --git a/tests/test_function.py b/tests/test_function.py index 3e8c3e3..34908fe 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -89,6 +89,17 @@ def mult(a:int, b:int) -> int: """ return a * b +def test_mix_function_tool(): + chat = Chat("find the sum of 784359345 and 345345345") + chat.setfuncs([add]) + chat.autoresponse(max_tries=3, display=True, timeinterval=2) + chat.clear() + chat.user("find the sum of 784359345 and 345345345") + chat.autoresponse(use_tool=False) + newchat = Chat("find the product of 123124 and 399090") + newchat.settools([mult]) + newchat.autoresponse() + def test_add_and_mult(): functions = [generate_json_schema(add)] chat = Chat("find the sum of 784359345 and 345345345") @@ -118,8 +129,8 @@ def test_add_and_mult(): def test_use_exec_function(): chat = Chat("find the result of sqrt(121314)") chat.setfuncs([exec_python_code]) - chat.autoresponse(max_tries=2, display=True, use_tool=False) - + # chat.autoresponse(max_tries=2, display=True, use_tool=False) + def test_find_permutation_group(): pass diff --git a/tests/test_tools.py b/tests/test_tools.py index b6dead5..caf76fa 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -30,7 +30,7 @@ "forecast": ["sunny", "windy"], "unit":"celsius" } -name2tool = { +name2func = { 'get_current_weather': lambda *kargs, **kwargs: weatherinfo } @@ -55,7 +55,7 @@ def test_call_weather(): def test_auto_response(): chat = Chat("What's the weather like in Boston?") chat.tools, chat.tool_choice = tools, 'auto' - chat.name2tool = name2tool + chat.name2func = name2func chat.autoresponse(max_tries=2, display=True) chat.print_log() newchat = chat.deepcopy() @@ -103,7 +103,7 @@ def test_add_and_mult(): chat.tool_choice = {'name':'add'} chat.tool_choice = 'add' # specify the function(convert to dict) chat.tools = tools - chat.name2tool = {'add': add} # dictionary of functions + chat.name2func = {'add': add} # dictionary of functions chat.tool_choice = 'auto' # auto decision # run until success: maxturns=-1 chat.autoresponse(max_tries=3, display=True, timeinterval=2)