Rex/fixbugs (#78)
* fix async response for some cases

* mask keys in debug log
RexWzh authored May 14, 2024
1 parent 06e8f79 commit fce4f40
Showing 4 changed files with 38 additions and 25 deletions.
30 changes: 13 additions & 17 deletions README.md
@@ -19,7 +19,7 @@
[English](README_en.md) | [简体中文](README.md)
</div>

A simple wrapper around the API, supporting multi-turn conversations, proxies, asynchronous data processing, and more.
Built on the OpenAI API `Chat` object, supporting multi-turn conversations, proxies, asynchronous data processing, and more.

## Installation

@@ -39,21 +39,19 @@ export OPENAI_API_BASE="https://api.example.com/v1"
export OPENAI_API_BASE_URL="https://api.example.com" # optional
```

On Windows, set the environment variables in the system settings.

They can also be set in code:
Or set them in code:

```py
import chattool
chattool.api_key = "sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
chattool.api_base = "https://api.example.com/v1"
```

Note: `OPENAI_API_BASE` takes precedence over `OPENAI_API_BASE_URL`; setting either one is enough.
Note: the environment variable `OPENAI_API_BASE` takes precedence over `OPENAI_API_BASE_URL`; setting either one is enough.
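
A minimal sketch of that precedence (illustrative only; the package's actual resolution logic and the `/v1` suffix handling are assumptions, not taken from this commit):

```python
import os

# Assumed precedence: OPENAI_API_BASE wins when both variables are set.
api_base = os.getenv("OPENAI_API_BASE")
if api_base is None:
    base_url = os.getenv("OPENAI_API_BASE_URL")
    if base_url is not None:
        # Assumption: derive the full API base by appending the /v1 path.
        api_base = base_url.rstrip("/") + "/v1"
print(api_base)
```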

### Examples

Example 1: simulating a multi-turn conversation
Example 1: multi-turn conversation

```python
# initial conversation
@@ -75,7 +73,7 @@ chat.save("chat.json", mode="w") # defaults to "a"
chat.print_log()
```

Example 2: batch data processing (serial), using a `checkpoint` cache file
Example 2: batch data processing (serial), using the cache file `chat.jsonl`

```python
# define the processing function
@@ -109,7 +107,7 @@ async_chat_completion(langs, chkpoint="async_chat.jsonl", nproc=2, data2chat=dat
chats = load_chats("async_chat.jsonl")
```

When running in a Jupyter Notebook, the `await` keyword and the `wait=True` parameter are required:
When running in a Jupyter Notebook, because of its [special event-loop handling](https://stackoverflow.com/questions/47518874/how-do-i-run-python-asyncio-code-in-a-jupyter-notebook), use the `await` keyword and the `wait=True` parameter:

```python
await async_chat_completion(langs, chkpoint="async_chat.jsonl", nproc=2, data2chat=data2chat, wait=True)
@@ -145,15 +143,13 @@ chat.autoresponse(display=True)

## Changelog

The current version is `2.3.0`, which supports calling external tools, asynchronous data processing, and model fine-tuning.

### Beta versions
- Version `0.2.0` switched to the `Chat` type as the central interaction object
- Version `2.3.0` supports calling external tools, asynchronous data processing, and model fine-tuning
- Since version `2.0.0`, renamed to `chattool`
- Since version `1.0.0`, supports asynchronous data processing
- Since version `0.6.0`, supports the [function call](https://platform.openai.com/docs/guides/gpt/function-calling) feature
- Since version `0.5.0`, supports processing data with `process_chats`, aided by the `msg2chat` function and a `checkpoint` file
- Since version `0.4.0`, maintenance moved to the [CubeNLP](https://github.com/cubenlp) organization account
- Since version `0.3.0`, no longer depends on the `openai.py` module; requests are sent directly via `requests`
- Supports a different API key for each `Chat`
- Supports proxy URLs
- Since version `0.4.0`, maintenance moved to the [CubeNLP](https://github.com/cubenlp) organization account
- Since version `0.5.0`, supports processing data with `process_chats`, aided by the `msg2chat` function and a `checkpoint` file
- Since version `0.6.0`, supports the [function call](https://platform.openai.com/docs/guides/gpt/function-calling) feature
- Since version `1.0.0`, supports asynchronous data processing
- Since version `2.0.0`, the module was renamed to `chattool`
- Version `0.2.0` switched to the `Chat` type as the central interaction object
25 changes: 21 additions & 4 deletions chattool/__init__.py
@@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '[email protected]'
__version__ = '3.1.5'
__version__ = '3.1.6'

import os, sys, requests
from .chattype import Chat, Resp
@@ -64,6 +64,7 @@ def save_envs(env_file:str):
load_envs()

# get the platform
# tqdm.asyncio.tqdm.gather differs on different platforms
platform = sys.platform
if platform.startswith("win"):
    platform = "windows"
@@ -91,6 +92,24 @@ def get_valid_models(api_key:str=api_key, base_url:str=base_url):
"""
return request.valid_models(api_key, base_url)

def print_secure_api_key(api_key):
    if api_key:
        length = len(api_key)
        if length == 1 or length == 2:
            masked_key = '*' * length
        elif 3 <= length <= 6:
            masked_key = api_key[0] + '*' * (length - 2) + api_key[-1]
        elif 7 <= length <= 14:
            masked_key = api_key[:2] + '*' * (length - 4) + api_key[-2:]
        elif 15 <= length <= 30:
            masked_key = api_key[:4] + '*' * (length - 8) + api_key[-4:]
        else:
            masked_key = api_key[:8] + '*' * (length - 12) + api_key[-8:]
        print("\nPlease verify your API key:")
        print(masked_key)
    else:
        print("No API key provided.")

def debug_log( net_url:str="https://www.baidu.com"
             , timeout:int=5
             , message:str="hello world! 你好!"
@@ -134,8 +153,7 @@ def debug_log( net_url:str="https://www.baidu.com"

    ## Please check your API key
    if test_apikey:
        print("\nPlease verify your API key:")
        print(api_key)
        print_secure_api_key(api_key)

    # Get model list
    if test_model:
@@ -151,4 +169,3 @@

print("\nDebug is finished.")
return True

6 changes: 3 additions & 3 deletions chattool/chattype.py
@@ -76,7 +76,7 @@ def add(self, role:str, **kwargs):
        self._chat_log.append({'role':role, **kwargs})
        return self

    def user(self, content:str):
    def user(self, content: Union[List, str]):
        """User message"""
        return self.add('user', content=content)

@@ -511,11 +511,11 @@ async def _async_stream_responses( api_key:str
            try:
                # wrap the response
                resp = Resp(json.loads(strline))
                # stop if the response is finished
                if resp.finish_reason == 'stop': break
                # deal with the message
                if 'content' not in resp.delta: continue
                yield resp
                # stop if the response is finished
                if resp.finish_reason == 'stop': break
            except Exception as e:
                print(f"Error: {e}, line: {strline}")
                break
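
A self-contained toy sketch of why the reordering above matters (the chunk layout is an assumed scenario, not the library's actual stream format): if the generator breaks on the stop marker before yielding, a chunk that carries both content and `finish_reason == 'stop'` is silently dropped.

```python
import asyncio

# Toy stream chunks; the final chunk carries both content and the stop marker.
CHUNKS = [
    {"delta": {"content": "Hello"}, "finish_reason": None},
    {"delta": {"content": " world"}, "finish_reason": "stop"},
]

async def stream(check_stop_first: bool):
    for chunk in CHUNKS:
        if check_stop_first and chunk["finish_reason"] == "stop":
            break  # old order: break before yielding, final content is lost
        if "content" in chunk["delta"]:
            yield chunk["delta"]["content"]
        if not check_stop_first and chunk["finish_reason"] == "stop":
            break  # new order: yield first, then stop

async def collect(check_stop_first: bool) -> str:
    return "".join([piece async for piece in stream(check_stop_first)])

print(asyncio.run(collect(True)))   # "Hello"        -- pre-fix behavior
print(asyncio.run(collect(False)))  # "Hello world"  -- post-fix behavior
```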
2 changes: 1 addition & 1 deletion setup.py
@@ -7,7 +7,7 @@
with open('README.md') as readme_file:
    readme = readme_file.read()

VERSION = '3.1.5'
VERSION = '3.1.6'

requirements = [
    'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',
