Skip to content

Commit

Permalink
simplify tool calls
Browse files Browse the repository at this point in the history
  • Loading branch information
RexWzh committed Jun 8, 2024
1 parent a6ed9ab commit 735550f
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 42 deletions.
2 changes: 1 addition & 1 deletion chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '[email protected]'
__version__ = '3.1.7'
__version__ = '3.2.0'

import os, sys, requests
from .chattype import Chat, Resp
Expand Down
61 changes: 26 additions & 35 deletions chattool/chattype.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def __init__( self
, model:Union[None, str]=None
, tools:Union[None, List[Dict]]=None
, tool_choice:Union[None, str]=None
, name2tool:Union[None, Dict]=None
, functions:Union[None, List[Dict]]=None
, function_call:Union[None, str]=None
, name2func:Union[None, Dict]=None):
Expand All @@ -35,10 +34,10 @@ def __init__( self
model (Union[None, str], optional): model to use. Defaults to None.
tools (Union[None, List[Dict]], optional): tools to use, each tool is a JSON Schema. Defaults to None.
tool_choice (Union[None, str], optional): method to choose the tool. Defaults to None. Choices: ['auto', '$NameOfTheTool', 'none']
name2tool (Union[None, Dict], optional): name to tool mapping. Defaults to None.
name2func (Union[None, Dict], optional): name to function mapping. Defaults to None.
functions (Union[None, List[Dict]], optional): Decrpcated. functions to use, each function is a JSON Schema. Defaults to None.
function_call (str, optional): Decrpcated. method to call the function. Defaults to None. Choices: ['auto', '$NameOfTheFunction', 'none']
name2func (Union[None, Dict], optional): Decrpcated. name to function mapping. Defaults to None.
Raises:
ValueError: msg should be a list of dict, a string or None
Expand Down Expand Up @@ -74,9 +73,9 @@ def __init__( self
assert isinstance(functions, list), "functions should be a list of dict"
if tools is not None:
assert isinstance(tools, list), "tools should be a list of dict"
self._functions, self._function_call = functions, function_call
self._tools, self._tool_choice = tools, tool_choice
self._name2func, self._resp, self._name2tool = name2func, None, name2tool
self.functions, self.tools = functions or [], tools or []
self._function_call, self._tool_choice = function_call, tool_choice
self._name2func, self._resp = name2func, None

# Part1: basic operation of the chat object
def add(self, role:str, **kwargs):
Expand Down Expand Up @@ -127,7 +126,6 @@ def deepcopy(self):
, tools=self.tools
, tool_choice=self.tool_choice
, name2func=self.name2func
, name2tool=self.name2tool
, api_base=self.api_base
, base_url=self.base_url)

Expand Down Expand Up @@ -209,7 +207,11 @@ def getresponse( self
, timeinterval:int = 0
, update:bool = True
, stream:bool = False
, tools:Union[None, List[Dict]]=None
, tool_choice:Union[None, str]=None
, max_requests:int=-1
, functions:Union[None, List[Dict]]=None
, function_call:Union[None, str]=None
, **options)->Resp:
"""Get the API response
Expand All @@ -227,10 +229,13 @@ def getresponse( self
# initialize data
api_key, chat_url = self.api_key, self.chat_url
if 'model' not in options: options['model'] = self.model
funcs = options.get('functions', self.functions)
func_call = options.get('function_call', self.function_call)
tools = options.get('tools', self._tools)
tool_choice = options.get('tool_choice', self._tool_choice)
# function call & tool call
tool_choice, tools = tool_choice or self.tool_choice, tools or self.tools
function_call, functions = function_call or self.function_call, functions or self.functions
if tool_choice is not None:
options['tool_choice'], options['tools'] = tool_choice, tools
elif function_call is not None:
options['function_call'], options['functions'] = function_call, functions
# if api_key is None: warnings.warn("API key is not set!")
msg, resp, numoftries = self.chat_log, None, 0
max_tries = max(max_tries, max_requests)
Expand All @@ -240,10 +245,6 @@ def getresponse( self
while max_tries:
try:
# make API Call
if funcs is not None: options['functions'] = funcs
if func_call is not None: options['function_call'] = func_call
if tools is not None: options['tools'] = tools
if tool_choice is not None: options['tool_choice'] = tool_choice
response = chat_completion(
api_key=api_key, messages=msg,
chat_url=chat_url, timeout=timeout, **options)
Expand Down Expand Up @@ -311,11 +312,9 @@ def setfuncs(self, funcs:List):

def settools(self, tools:List):
"""Initialize tools for tool calls"""
self.tools = [{
'type':'function',
'function': generate_json_schema(tool)} for tool in tools]
self._functions =[generate_json_schema(func) for func in tools]
self.tool_choice = 'auto'
self.name2tool = {tool.__name__:tool for tool in tools}
self.name2func = {tool.__name__:tool for tool in tools}
return True

def calltools(self, display:bool=False):
Expand All @@ -335,14 +334,14 @@ def calltool(self, tool):
"""Call the tool"""
tool_call_id = tool['id']
tool_name, tool_para = tool['function']['name'], tool['function']['arguments']
if tool_name not in self.name2tool:
if tool_name not in self.name2func:
return f"Tool {tool_name} not found.", tool_name, tool_call_id, False
try:
tool_args = json.loads(tool_para)
except Exception as e:
return f"Argument parsing failed with error: {e}", tool_name, tool_call_id, False
try:
result = self.name2tool[tool_name](**tool_args)
result = self.name2func[tool_name](**tool_args)
except Exception as e:
return f"Tool {tool_name} failed with error: {e}", tool_name, tool_call_id, False
# succeed finally!
Expand Down Expand Up @@ -382,9 +381,9 @@ def autoresponse( self
bool: whether the response is finished
"""
if use_tool:
options['tools'], options['tool_choice'] = self.tools, self.tool_choice
options['tools'], options['tool_choice'] = self.tools, self.tool_choice or 'auto'
else:
options['functions'], options['function_call'] = self.functions, self.function_call
options['functions'], options['function_call'] = self.functions, self.function_call or 'auto'
show = lambda msg: print(self.display_role_content(msg))
resp = self.getresponse(**options)
if display: show(resp.message)
Expand Down Expand Up @@ -476,7 +475,9 @@ def function_call(self):
@property
def tools(self):
"""Get tools"""
return self._tools
if self.functions is not None:
return [{'type':'function', 'function': func} for func in self.functions]
return None

@property
def tool_choice(self):
Expand All @@ -487,11 +488,6 @@ def tool_choice(self):
def name2func(self):
"""Get name to function mapping"""
return self._name2func

@property
def name2tool(self):
"""Get name to tool mapping"""
return self._name2tool

@api_key.setter
def api_key(self, api_key:str):
Expand Down Expand Up @@ -538,7 +534,7 @@ def function_call(self, function_call:str):
def tools(self, tools:List[Dict]):
"""Set tools"""
assert isinstance(tools, list), "tools should be a list of dict"
self._tools = tools
self._functions = [tool['function'] for tool in tools]

@tool_choice.setter
def tool_choice(self, tool_choice:str):
Expand All @@ -557,11 +553,6 @@ def name2func(self, name2func:Dict):
assert isinstance(name2func, dict), "name2func should be a dict"
self._name2func = name2func

@name2tool.setter
def name2tool(self, name2tool:Dict):
"""Set name to tool mapping"""
assert isinstance(name2tool, dict), "name2tool should be a dict"
self._name2tool = name2tool

@property
def chat_log(self):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '3.1.7'
VERSION = '3.2.0'

requirements = [
'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',
Expand Down
15 changes: 13 additions & 2 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ def mult(a:int, b:int) -> int:
"""
return a * b

def test_mix_function_tool():
chat = Chat("find the sum of 784359345 and 345345345")
chat.setfuncs([add])
chat.autoresponse(max_tries=3, display=True, timeinterval=2)
chat.clear()
chat.user("find the sum of 784359345 and 345345345")
chat.autoresponse(use_tool=False)
newchat = Chat("find the product of 123124 and 399090")
newchat.settools([mult])
newchat.autoresponse()

def test_add_and_mult():
functions = [generate_json_schema(add)]
chat = Chat("find the sum of 784359345 and 345345345")
Expand Down Expand Up @@ -118,8 +129,8 @@ def test_add_and_mult():
def test_use_exec_function():
chat = Chat("find the result of sqrt(121314)")
chat.setfuncs([exec_python_code])
chat.autoresponse(max_tries=2, display=True, use_tool=False)
# chat.autoresponse(max_tries=2, display=True, use_tool=False)

def test_find_permutation_group():
pass

Expand Down
6 changes: 3 additions & 3 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"forecast": ["sunny", "windy"],
"unit":"celsius"
}
name2tool = {
name2func = {
'get_current_weather': lambda *kargs, **kwargs: weatherinfo
}

Expand All @@ -55,7 +55,7 @@ def test_call_weather():
def test_auto_response():
chat = Chat("What's the weather like in Boston?")
chat.tools, chat.tool_choice = tools, 'auto'
chat.name2tool = name2tool
chat.name2func = name2func
chat.autoresponse(max_tries=2, display=True)
chat.print_log()
newchat = chat.deepcopy()
Expand Down Expand Up @@ -103,7 +103,7 @@ def test_add_and_mult():
chat.tool_choice = {'name':'add'}
chat.tool_choice = 'add' # specify the function(convert to dict)
chat.tools = tools
chat.name2tool = {'add': add} # dictionary of functions
chat.name2func = {'add': add} # dictionary of functions
chat.tool_choice = 'auto' # auto decision
# run until success: maxturns=-1
chat.autoresponse(max_tries=3, display=True, timeinterval=2)
Expand Down

0 comments on commit 735550f

Please sign in to comment.