Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

specify tool type #81

Merged
merged 1 commit into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,20 +129,32 @@ def add(a: int, b: int) -> int:
int: The sum of the two numbers.
"""
return a + b

def mult(a:int, b:int) -> int:
"""This function multiplies two numbers.
It is a useful calculator!

Args:
a (int): The first number.
b (int): The second number.

Returns:
int: The product of the two numbers.
"""
return a * b
# 传输函数
chat = Chat()
chat.setfuncs([add]) # 传入函数列表,可以是多个函数
chat.user("请计算 1 + 2")
# 自动调用工具
chat.autoresponse(display=True)
chat = Chat("find the value of (23723 * 1322312 ) + 12312") # 传入函数列表,可以是多个函数
# 自动调用工具,默认使用 tool_choice
chat.autoresponse(display=True, tool_choice='tool_choice') # 或者用 function_call
```

## 开源协议

这个项目使用 MIT 协议开源。
使用 MIT 协议开源。

## 更新日志

- 当前版本 `3.2.1`,简化异步处理和串行处理的接口,更新子模块名称,避免冲突
- 版本 `2.3.0`,支持调用外部工具,异步处理数据,以及模型微调功能
- 版本 `2.0.0` 开始,更名为 `chattool`
- 版本 `1.0.0` 开始,支持异步处理数据
Expand Down
28 changes: 21 additions & 7 deletions chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '[email protected]'
__version__ = '3.2.0'
__version__ = '3.2.1'

import os, sys, requests
from .chattype import Chat, Resp
Expand All @@ -13,7 +13,8 @@
from .asynctool import async_chat_completion
from .functioncall import generate_json_schema, exec_python_code
from typing import Union
import dotenv
import dotenv
import loguru

raw_env_text = f"""# Description: Env file for ChatTool.
# Current version: {__version__}
Expand All @@ -33,11 +34,24 @@
"""

def load_envs(env:Union[None, str, dict]=None):
"""Read the environment variables for the API call"""
"""Load the environment variables for the API call
Args:
env (Union[None, str, dict], optional): The environment file or the environment variables. Defaults to None.
Returns:
bool: True if the environment variables are loaded successfully.
Examples:
load_envs("envfile.env")
load_envs({"OPENAI_API_KEY":"your_api_key"})
load_envs() # load from the environment variables
"""
global api_key, base_url, api_base, model
# update the environment variables
if isinstance(env, str):
dotenv.load_dotenv(env, override=True)
if isinstance(env, str) and not dotenv.load_dotenv(env, override=True):
loguru.logger.warning(f"Failed to load the environment file: {env}")
return False
elif isinstance(env, dict):
for key, value in env.items():
os.environ[key] = value
Expand Down Expand Up @@ -157,8 +171,8 @@ def debug_log( net_url:str="https://www.baidu.com"

# Get model list
if test_model:
print("\nThe model list(contains gpt):")
print(Chat().get_valid_models())
print("\nThe model list:")
print(Chat().get_valid_models(gpt_only=False))

# Test hello world
if test_response:
Expand Down
40 changes: 23 additions & 17 deletions chattool/chattype.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
from .functioncall import generate_json_schema, delete_dialogue_assist
from pprint import pformat
from loguru import logger

class Chat():
def __init__( self
Expand Down Expand Up @@ -113,11 +114,11 @@ def clear(self):

def copy(self):
"""Copy the chat log"""
return Chat(self._chat_log)
return Chat(self.chat_log)

def deepcopy(self):
"""Deep copy the Chat object"""
return Chat( self._chat_log
return Chat( self.chat_log
, api_key=self.api_key
, chat_url=self.chat_url
, model=self.model
Expand Down Expand Up @@ -210,6 +211,7 @@ def getresponse( self
, tools:Union[None, List[Dict]]=None
, tool_choice:Union[None, str]=None
, max_requests:int=-1
, tool_type:str='tool_choice'
, functions:Union[None, List[Dict]]=None
, function_call:Union[None, str]=None
, **options)->Resp:
Expand All @@ -221,7 +223,8 @@ def getresponse( self
timeinterval (int, optional): time interval between two API calls. Defaults to 0.
update (bool, optional): whether to update the chat log. Defaults to True.
options (dict, optional): other options like `temperature`, `top_p`, etc.
max_requests (int, optional): (deprecated) maximum number of requests to make. Defaults to -1(no limit
max_requests (int, optional): (deprecated) maximum number of requests to make. Defaults to -1(no limit)
tool_type (str, optional): type of the tool. Defaults to 'tool_choice'.
Returns:
Resp: API response
Expand All @@ -232,10 +235,14 @@ def getresponse( self
# function call & tool call
tool_choice, tools = tool_choice or self.tool_choice, tools or self.tools
function_call, functions = function_call or self.function_call, functions or self.functions
if tool_choice is not None:
options['tool_choice'], options['tools'] = tool_choice, tools
elif function_call is not None:
options['function_call'], options['functions'] = function_call, functions
if tool_type == 'function_call':
if function_call is not None:
options['function_call'], options['functions'] = function_call, functions
else:
if tool_choice is not None:
if tool_type != 'tool_choice':
logger.warning(f"Unknown tool type {tool_type}, use 'tool_choice' by default.")
options['tool_choice'], options['tools'] = tool_choice, tools
# if api_key is None: warnings.warn("API key is not set!")
msg, resp, numoftries = self.chat_log, None, 0
max_tries = max(max_tries, max_requests)
Expand Down Expand Up @@ -312,8 +319,8 @@ def setfuncs(self, funcs:List):

def settools(self, tools:List):
"""Initialize tools for tool calls"""
self._functions =[generate_json_schema(func) for func in tools]
self.tool_choice = 'auto'
self.functions =[generate_json_schema(func) for func in tools]
self.tool_choice = 'auto' # the only difference from setfuncs
self.name2func = {tool.__name__:tool for tool in tools}
return True

Expand Down Expand Up @@ -368,7 +375,7 @@ def callfunction(self):
def autoresponse( self
, display:bool=False
, maxturns:int=3
, use_tool:bool=True
, tool_type:str='tool_choice'
, **options):
"""Get the response automatically
Expand All @@ -380,22 +387,22 @@ def autoresponse( self
Returns:
bool: whether the response is finished
"""
if use_tool:
options['tools'], options['tool_choice'] = self.tools, self.tool_choice or 'auto'
else:
if tool_type == 'function_call':
options['functions'], options['function_call'] = self.functions, self.function_call or 'auto'
else:
options['tools'], options['tool_choice'] = self.tools, self.tool_choice or 'auto'
show = lambda msg: print(self.display_role_content(msg))
resp = self.getresponse(**options)
resp = self.getresponse(tool_type=tool_type, **options)
if display: show(resp.message)
while self.iswaiting() and maxturns != 0:
# call api and update the result
if use_tool:
if tool_type != 'function_call':
self.calltools(display=display)
else:
result, name, _ = self.callfunction()
self.function(result, name)
if display: show(self[-1])
resp = self.getresponse(**options)
resp = self.getresponse(tool_type=tool_type, **options)
if display: show(resp.message)
maxturns -= 1
return True
Expand Down Expand Up @@ -553,7 +560,6 @@ def name2func(self, name2func:Dict):
assert isinstance(name2func, dict), "name2func should be a dict"
self._name2func = name2func


@property
def chat_log(self):
"""Chat history"""
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '3.2.0'
VERSION = '3.2.1'

requirements = [
'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',
'tqdm>=4.60', 'docstring_parser>=0.10', "python-dotenv>=0.17.0"]
'tqdm>=4.60', 'docstring_parser>=0.10', "python-dotenv>=0.17.0",
'loguru>=0.7']
test_requirements = ['pytest>=3', 'unittest']

setup(
Expand Down
1 change: 0 additions & 1 deletion tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def add(a: int, b: int) -> int:
"""
return a + b

# with optional parameters
def mult(a:int, b:int) -> int:
"""This function multiplies two numbers.
It is a useful calculator!
Expand Down
17 changes: 15 additions & 2 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def add(a: int, b: int) -> int:
"""
return a + b

# with optional parameters
def mult(a:int, b:int) -> int:
"""This function multiplies two numbers.
It is a useful calculator!
Expand All @@ -92,6 +91,21 @@ def mult(a:int, b:int) -> int:
"""
return a * b

def test_func_and_tool():
chat = Chat("find the value of 124842 * 3423424 + 121312")
chat.settools([add, mult]) # multi choice
chat1 = chat.deepcopy()
chat1.autoresponse(tool_type='tool_choice')
chat2 = chat.deepcopy()
chat2.autoresponse(tool_type='function_call')
# setfuncs
chat.setfuncs([add, mult])
chat1 = chat.deepcopy()
chat1.autoresponse(tool_type='tool_choice')
chat2 = chat.deepcopy()
chat2.autoresponse(tool_type='function_call')


def test_add_and_mult():
tools = [{
'type':'function',
Expand Down Expand Up @@ -127,7 +141,6 @@ def test_add_and_mult():
chat4.user("find the value of (23723 * 1322312 ) + 12312")
chat4.autoresponse(max_tries=3, display=True, timeinterval=2)


def test_use_exec_function():
chat = Chat("find the result of sqrt(121314)")
chat.settools([exec_python_code])
Expand Down
Loading