Skip to content

Commit

Permalink
adapt to webchatter
Browse files Browse the repository at this point in the history
  • Loading branch information
RexWzh committed Dec 20, 2023
1 parent c0370d6 commit e33f776
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 63 deletions.
2 changes: 1 addition & 1 deletion chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '[email protected]'
__version__ = '2.6.2'
__version__ = '3.0.0'

import os, sys, requests
from .chattool import Chat, Resp
Expand Down
18 changes: 9 additions & 9 deletions chattool/asynctool.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ async def async_process_msgs( chatlogs:List[List[Dict]]
sem = asyncio.Semaphore(ncoroutines)
locker = asyncio.Lock()

async def chat_complete(ind, locker, chatlog, chkpoint, **options):
payload = {"messages": chatlog}
async def chat_complete(ind, locker, chat_log, chkpoint, **options):
payload = {"messages": chat_log}
payload.update(options)
if max_tokens is not None:
payload['max_tokens'] = max_tokens(chatlog)
payload['max_tokens'] = max_tokens(chat_log)
data = json.dumps(payload)
resp = await async_post( session=session
, sem=sem
Expand All @@ -99,22 +99,22 @@ async def chat_complete(ind, locker, chatlog, chkpoint, **options):
, timeout=timeout)
## saving files
if resp is None: return 0, 0
chatlog.append(resp.message)
chat = Chat(chatlog)
chat_log.append(resp.message)
chat = Chat(chat_log)
async with locker: # locker | not necessary for normal IO
chat.savewithid(chkpoint, chatid=ind)
chat.save(chkpoint, index=ind)
return ind, resp.cost()

async with sem, aiohttp.ClientSession() as session:
tasks = []
for ind, chatlog in enumerate(chatlogs):
for ind, chat_log in enumerate(chatlogs):
if chats[ind] is not None: # skip completed chats
continue
tasks.append(
asyncio.create_task(
chat_complete( ind=ind
, locker=locker
, chatlog=chatlog
, chat_log=chat_log
, chkpoint=chkpoint
, **options)))
try: # for mac or linux
Expand Down Expand Up @@ -186,7 +186,7 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
# run async process
assert ncoroutines > 0, "ncoroutines must be greater than 0!"
if isinstance(max_tokens, int):
max_tokens = lambda chatlog: max_tokens
max_tokens = lambda chat_log: max_tokens
args = {
"chatlogs": chatlogs,
"chkpoint": chkpoint,
Expand Down
25 changes: 4 additions & 21 deletions chattool/chattool.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def deepcopy(self):
, name2func=self.name2func
, base_url=self.base_url)

def save(self, path:str, mode:str='a'):
def save(self, path:str, mode:str='a', index:int=0):
"""
Save the chat log to a file. Each line is a json string.
Expand All @@ -115,24 +115,7 @@ def save(self, path:str, mode:str='a'):
pathname = os.path.dirname(path).strip()
if pathname != '':
os.makedirs(pathname, exist_ok=True)
with open(path, mode, encoding='utf-8') as f:
f.write(json.dumps(self.chat_log, ensure_ascii=False) + '\n')
return

def savewithid(self, path:str, chatid:int, mode:str='a'):
"""Save the chat log with chat id. Each line is a json string.
Args:
path (str): path to the file
chatid (int): chat id
mode (str, optional): mode to open the file. Defaults to 'a'.
"""
assert mode in ['a', 'w'], "saving mode should be 'a' or 'w'"
# make path if not exists
pathname = os.path.dirname(path).strip()
if pathname != '':
os.makedirs(pathname, exist_ok=True)
data = {"chatid": chatid, "chatlog": self.chat_log}
data = {"index": index, "chat_log": self.chat_log}
with open(path, mode, encoding='utf-8') as f:
f.write(json.dumps(data, ensure_ascii=False) + '\n')
return
Expand Down Expand Up @@ -163,8 +146,8 @@ def load(path:str):
path (str): path to the file
"""
with open(path, 'r', encoding='utf-8') as f:
chatlog = json.loads(f.read())
return Chat(chatlog)
chat_log = json.loads(f.read())
return Chat(chat_log['chat_log'])

@staticmethod
def display_role_content(dic:dict, sep:Union[str, None]=None):
Expand Down
34 changes: 13 additions & 21 deletions chattool/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@
from .chattool import Chat
import tqdm

def load_chats( checkpoint:str
, withid:bool=False):
def load_chats( checkpoint:str):
"""Load chats from a checkpoint file
Args:
checkpoint (str): path to the checkpoint file
withid (bool, optional): Deprecated. It is not needed anymore. Defaults to False.
Returns:
list: chats
Expand All @@ -23,25 +21,20 @@ def load_chats( checkpoint:str
txts = f.read().strip().split('\n')
## empty file
if len(txts) == 1 and txts[0] == '': return []

# get the chatlogs
logs = [json.loads(txt) for txt in txts]
## chatlogs with chatid
if 'chatid' in logs[0]:
chat_size, chatlogs = 1, [None]
for log in logs:
idx = log['chatid']
if idx >= chat_size: # extend chatlogs
chatlogs.extend([None] * (idx - chat_size + 1))
chat_size = idx + 1
chatlogs[idx] = log['chatlog']
# check if there are missing chatlogs
if None in chatlogs:
warnings.warn(f"checkpoint file {checkpoint} has unfinished chats")
else: ## logs without chatid
chatlogs = logs
chat_size, chatlogs = 1, [None]
for log in logs:
idx = log['index']
if idx >= chat_size: # extend chatlogs
chatlogs.extend([None] * (idx - chat_size + 1))
chat_size = idx + 1
chatlogs[idx] = log['chat_log']
# check if there are missing chatlogs
if None in chatlogs:
warnings.warn(f"checkpoint file {checkpoint} has unfinished chats")
# return Chat class
return [Chat(chatlog) if chatlog is not None else None for chatlog in chatlogs]
return [Chat(chat_log) if chat_log is not None else None for chat_log in chatlogs]

def process_chats( data:List[Any]
, data2chat:Callable[[Any], Chat]
Expand All @@ -68,13 +61,12 @@ def process_chats( data:List[Any]
if len(chats) > len(data):
warnings.warn(f"checkpoint file {checkpoint} has more chats than the data to be processed")
return chats[:len(data)]

chats.extend([None] * (len(data) - len(chats)))
## process chats
tq = tqdm.tqdm if not isjupyter else tqdm.notebook.tqdm
for i in tq(range(len(data))):
if chats[i] is not None: continue
chat = data2chat(data[i])
chat.save(checkpoint, mode='a')
chat.save(checkpoint, mode='a', index=i)
chats[i] = chat
return chats
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '2.6.2'
VERSION = '3.0.0'

requirements = [
'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',
Expand Down Expand Up @@ -35,7 +35,7 @@
},
install_requires=requirements,
license="MIT license",
long_description=readme,
# long_description=readme,
long_description_content_type='text/markdown',
include_package_data=True,
keywords='chattool',
Expand Down
4 changes: 2 additions & 2 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def msg2log(msg):
chat.system("translate the words from English to Chinese")
chat.user(msg)
return chat.chat_log
def max_tokens(chatlog):
return Chat(chatlog).prompt_token()
def max_tokens(chat_log):
return Chat(chat_log).prompt_token()
async_chat_completion(words, chkpoint, clearfile=True, ncoroutines=3, max_tokens=max_tokens, msg2log=msg2log)

def test_normal_process():
Expand Down
14 changes: 7 additions & 7 deletions tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ def test_with_checkpoint():
# save chats without chatid
chat = Chat()
checkpath = testpath + "tmp.jsonl"
chat.save(checkpath, mode="w")
chat.save(checkpath, mode="w", index=0)
chat = Chat("hello!")
chat.save(checkpath) # append
chat.save(checkpath, index=1) # append
chat.assistant("你好, how can I assist you today?")
chat.save(checkpath) # append
chat.save(checkpath, index=2) # append
## load chats
chats = load_chats(checkpath)
chat_logs = [
Expand All @@ -23,13 +23,13 @@ def test_with_checkpoint():
# save chats with chatid
chat = Chat()
checkpath = testpath + "tmp_withid.jsonl"
chat.savewithid(checkpath, mode="w", chatid=0)
chat.save(checkpath, mode="w", index=0)
chat = Chat("hello!")
chat.savewithid(checkpath, chatid=3)
chat.save(checkpath, index=3)
chat.assistant("你好, how can I assist you today?")
chat.savewithid(checkpath, chatid=2)
chat.save(checkpath, index=2)
## load chats
chats = load_chats(checkpath, withid=True)
chats = load_chats(checkpath)
chat_logs = [
[],
None,
Expand Down
Empty file removed tests/testfiles/.keep
Empty file.

0 comments on commit e33f776

Please sign in to comment.