Skip to content

Commit

Permalink
Merge pull request #81 from cubenlp/rex/specify-tool-type
Browse files Browse the repository at this point in the history
specify tool type
  • Loading branch information
RexWzh authored Jun 15, 2024
2 parents 168ccbe + 30aed09 commit 865ebda
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 35 deletions.
24 changes: 18 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,20 +129,32 @@ def add(a: int, b: int) -> int:
int: The sum of the two numbers.
"""
return a + b

def mult(a:int, b:int) -> int:
"""This function multiplies two numbers.
It is a useful calculator!
Args:
a (int): The first number.
b (int): The second number.
Returns:
int: The product of the two numbers.
"""
return a * b
# 传输函数
chat = Chat()
chat.setfuncs([add]) # 传入函数列表,可以是多个函数
chat.user("请计算 1 + 2")
# 自动调用工具
chat.autoresponse(display=True)
chat = Chat("find the value of (23723 * 1322312 ) + 12312") # 传入函数列表,可以是多个函数
# 自动调用工具,默认使用 tool_choice
chat.autoresponse(display=True, tool_choice='tool_choice') # 或者用 function_call
```

## 开源协议

这个项目使用 MIT 协议开源。
使用 MIT 协议开源。

## 更新日志

- 当前版本 `3.2.1`,简化异步处理和串行处理的接口,更新子模块名称,避免冲突
- 版本 `2.3.0`,支持调用外部工具,异步处理数据,以及模型微调功能
- 版本 `2.0.0` 开始,更名为 `chattool`
- 版本 `1.0.0` 开始,支持异步处理数据
Expand Down
28 changes: 21 additions & 7 deletions chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '[email protected]'
__version__ = '3.2.0'
__version__ = '3.2.1'

import os, sys, requests
from .chattype import Chat, Resp
Expand All @@ -13,7 +13,8 @@
from .asynctool import async_chat_completion
from .functioncall import generate_json_schema, exec_python_code
from typing import Union
import dotenv
import dotenv
import loguru

raw_env_text = f"""# Description: Env file for ChatTool.
# Current version: {__version__}
Expand All @@ -33,11 +34,24 @@
"""

def load_envs(env:Union[None, str, dict]=None):
"""Read the environment variables for the API call"""
"""Load the environment variables for the API call
Args:
env (Union[None, str, dict], optional): The environment file or the environment variables. Defaults to None.
Returns:
bool: True if the environment variables are loaded successfully.
Examples:
load_envs("envfile.env")
load_envs({"OPENAI_API_KEY":"your_api_key"})
load_envs() # load from the environment variables
"""
global api_key, base_url, api_base, model
# update the environment variables
if isinstance(env, str):
dotenv.load_dotenv(env, override=True)
if isinstance(env, str) and not dotenv.load_dotenv(env, override=True):
loguru.logger.warning(f"Failed to load the environment file: {env}")
return False
elif isinstance(env, dict):
for key, value in env.items():
os.environ[key] = value
Expand Down Expand Up @@ -157,8 +171,8 @@ def debug_log( net_url:str="https://www.baidu.com"

# Get model list
if test_model:
print("\nThe model list(contains gpt):")
print(Chat().get_valid_models())
print("\nThe model list:")
print(Chat().get_valid_models(gpt_only=False))

# Test hello world
if test_response:
Expand Down
40 changes: 23 additions & 17 deletions chattool/chattype.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
from .functioncall import generate_json_schema, delete_dialogue_assist
from pprint import pformat
from loguru import logger

class Chat():
def __init__( self
Expand Down Expand Up @@ -113,11 +114,11 @@ def clear(self):

def copy(self):
"""Copy the chat log"""
return Chat(self._chat_log)
return Chat(self.chat_log)

def deepcopy(self):
"""Deep copy the Chat object"""
return Chat( self._chat_log
return Chat( self.chat_log
, api_key=self.api_key
, chat_url=self.chat_url
, model=self.model
Expand Down Expand Up @@ -210,6 +211,7 @@ def getresponse( self
, tools:Union[None, List[Dict]]=None
, tool_choice:Union[None, str]=None
, max_requests:int=-1
, tool_type:str='tool_choice'
, functions:Union[None, List[Dict]]=None
, function_call:Union[None, str]=None
, **options)->Resp:
Expand All @@ -221,7 +223,8 @@ def getresponse( self
timeinterval (int, optional): time interval between two API calls. Defaults to 0.
update (bool, optional): whether to update the chat log. Defaults to True.
options (dict, optional): other options like `temperature`, `top_p`, etc.
max_requests (int, optional): (deprecated) maximum number of requests to make. Defaults to -1(no limit
max_requests (int, optional): (deprecated) maximum number of requests to make. Defaults to -1(no limit)
tool_type (str, optional): type of the tool. Defaults to 'tool_choice'.
Returns:
Resp: API response
Expand All @@ -232,10 +235,14 @@ def getresponse( self
# function call & tool call
tool_choice, tools = tool_choice or self.tool_choice, tools or self.tools
function_call, functions = function_call or self.function_call, functions or self.functions
if tool_choice is not None:
options['tool_choice'], options['tools'] = tool_choice, tools
elif function_call is not None:
options['function_call'], options['functions'] = function_call, functions
if tool_type == 'function_call':
if function_call is not None:
options['function_call'], options['functions'] = function_call, functions
else:
if tool_choice is not None:
if tool_type != 'tool_choice':
logger.warning(f"Unknown tool type {tool_type}, use 'tool_choice' by default.")
options['tool_choice'], options['tools'] = tool_choice, tools
# if api_key is None: warnings.warn("API key is not set!")
msg, resp, numoftries = self.chat_log, None, 0
max_tries = max(max_tries, max_requests)
Expand Down Expand Up @@ -312,8 +319,8 @@ def setfuncs(self, funcs:List):

def settools(self, tools:List):
"""Initialize tools for tool calls"""
self._functions =[generate_json_schema(func) for func in tools]
self.tool_choice = 'auto'
self.functions =[generate_json_schema(func) for func in tools]
self.tool_choice = 'auto' # the only difference from setfuncs
self.name2func = {tool.__name__:tool for tool in tools}
return True

Expand Down Expand Up @@ -368,7 +375,7 @@ def callfunction(self):
def autoresponse( self
, display:bool=False
, maxturns:int=3
, use_tool:bool=True
, tool_type:str='tool_choice'
, **options):
"""Get the response automatically
Expand All @@ -380,22 +387,22 @@ def autoresponse( self
Returns:
bool: whether the response is finished
"""
if use_tool:
options['tools'], options['tool_choice'] = self.tools, self.tool_choice or 'auto'
else:
if tool_type == 'function_call':
options['functions'], options['function_call'] = self.functions, self.function_call or 'auto'
else:
options['tools'], options['tool_choice'] = self.tools, self.tool_choice or 'auto'
show = lambda msg: print(self.display_role_content(msg))
resp = self.getresponse(**options)
resp = self.getresponse(tool_type=tool_type, **options)
if display: show(resp.message)
while self.iswaiting() and maxturns != 0:
# call api and update the result
if use_tool:
if tool_type != 'function_call':
self.calltools(display=display)
else:
result, name, _ = self.callfunction()
self.function(result, name)
if display: show(self[-1])
resp = self.getresponse(**options)
resp = self.getresponse(tool_type=tool_type, **options)
if display: show(resp.message)
maxturns -= 1
return True
Expand Down Expand Up @@ -553,7 +560,6 @@ def name2func(self, name2func:Dict):
assert isinstance(name2func, dict), "name2func should be a dict"
self._name2func = name2func


@property
def chat_log(self):
"""Chat history"""
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '3.2.0'
VERSION = '3.2.1'

requirements = [
'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',
'tqdm>=4.60', 'docstring_parser>=0.10', "python-dotenv>=0.17.0"]
'tqdm>=4.60', 'docstring_parser>=0.10', "python-dotenv>=0.17.0",
'loguru>=0.7']
test_requirements = ['pytest>=3', 'unittest']

setup(
Expand Down
1 change: 0 additions & 1 deletion tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def add(a: int, b: int) -> int:
"""
return a + b

# with optional parameters
def mult(a:int, b:int) -> int:
"""This function multiplies two numbers.
It is a useful calculator!
Expand Down
17 changes: 15 additions & 2 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def add(a: int, b: int) -> int:
"""
return a + b

# with optional parameters
def mult(a:int, b:int) -> int:
"""This function multiplies two numbers.
It is a useful calculator!
Expand All @@ -92,6 +91,21 @@ def mult(a:int, b:int) -> int:
"""
return a * b

def test_func_and_tool():
chat = Chat("find the value of 124842 * 3423424 + 121312")
chat.settools([add, mult]) # multi choice
chat1 = chat.deepcopy()
chat1.autoresponse(tool_type='tool_choice')
chat2 = chat.deepcopy()
chat2.autoresponse(tool_type='function_call')
# setfuncs
chat.setfuncs([add, mult])
chat1 = chat.deepcopy()
chat1.autoresponse(tool_type='tool_choice')
chat2 = chat.deepcopy()
chat2.autoresponse(tool_type='function_call')


def test_add_and_mult():
tools = [{
'type':'function',
Expand Down Expand Up @@ -127,7 +141,6 @@ def test_add_and_mult():
chat4.user("find the value of (23723 * 1322312 ) + 12312")
chat4.autoresponse(max_tries=3, display=True, timeinterval=2)


def test_use_exec_function():
chat = Chat("find the result of sqrt(121314)")
chat.settools([exec_python_code])
Expand Down

0 comments on commit 865ebda

Please sign in to comment.