diff --git a/chattool/__init__.py b/chattool/__init__.py index ea4a91d..9f5efb7 100644 --- a/chattool/__init__.py +++ b/chattool/__init__.py @@ -2,7 +2,7 @@ __author__ = """Rex Wang""" __email__ = '1073853456@qq.com' -__version__ = '2.6.2' +__version__ = '3.0.0' import os, sys, requests from .chattool import Chat, Resp diff --git a/chattool/asynctool.py b/chattool/asynctool.py index 8c801f2..c474470 100644 --- a/chattool/asynctool.py +++ b/chattool/asynctool.py @@ -83,11 +83,11 @@ async def async_process_msgs( chatlogs:List[List[Dict]] sem = asyncio.Semaphore(ncoroutines) locker = asyncio.Lock() - async def chat_complete(ind, locker, chatlog, chkpoint, **options): - payload = {"messages": chatlog} + async def chat_complete(ind, locker, chat_log, chkpoint, **options): + payload = {"messages": chat_log} payload.update(options) if max_tokens is not None: - payload['max_tokens'] = max_tokens(chatlog) + payload['max_tokens'] = max_tokens(chat_log) data = json.dumps(payload) resp = await async_post( session=session , sem=sem @@ -99,22 +99,22 @@ async def chat_complete(ind, locker, chatlog, chkpoint, **options): , timeout=timeout) ## saving files if resp is None: return 0, 0 - chatlog.append(resp.message) - chat = Chat(chatlog) + chat_log.append(resp.message) + chat = Chat(chat_log) async with locker: # locker | not necessary for normal IO - chat.savewithid(chkpoint, chatid=ind) + chat.save(chkpoint, index=ind) return ind, resp.cost() async with sem, aiohttp.ClientSession() as session: tasks = [] - for ind, chatlog in enumerate(chatlogs): + for ind, chat_log in enumerate(chatlogs): if chats[ind] is not None: # skip completed chats continue tasks.append( asyncio.create_task( chat_complete( ind=ind , locker=locker - , chatlog=chatlog + , chat_log=chat_log , chkpoint=chkpoint , **options))) try: # for mac or linux @@ -186,7 +186,7 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str] # run async process assert ncoroutines > 0, "ncoroutines must be greater than 0!" if isinstance(max_tokens, int): - max_tokens = lambda chatlog: max_tokens + max_tokens = lambda chat_log: max_tokens args = { "chatlogs": chatlogs, "chkpoint": chkpoint, diff --git a/chattool/chattool.py b/chattool/chattool.py index da14c88..faeef3f 100644 --- a/chattool/chattool.py +++ b/chattool/chattool.py @@ -102,7 +102,7 @@ def deepcopy(self): , name2func=self.name2func , base_url=self.base_url) - def save(self, path:str, mode:str='a'): + def save(self, path:str, mode:str='a', index:int=0): """ Save the chat log to a file. Each line is a json string. @@ -115,24 +115,7 @@ def save(self, path:str, mode:str='a'): pathname = os.path.dirname(path).strip() if pathname != '': os.makedirs(pathname, exist_ok=True) - with open(path, mode, encoding='utf-8') as f: - f.write(json.dumps(self.chat_log, ensure_ascii=False) + '\n') - return - - def savewithid(self, path:str, chatid:int, mode:str='a'): - """Save the chat log with chat id. Each line is a json string. - - Args: - path (str): path to the file - chatid (int): chat id - mode (str, optional): mode to open the file. Defaults to 'a'. - """ - assert mode in ['a', 'w'], "saving mode should be 'a' or 'w'" - # make path if not exists - pathname = os.path.dirname(path).strip() - if pathname != '': - os.makedirs(pathname, exist_ok=True) - data = {"chatid": chatid, "chatlog": self.chat_log} + data = {"index": index, "chat_log": self.chat_log} with open(path, mode, encoding='utf-8') as f: f.write(json.dumps(data, ensure_ascii=False) + '\n') return @@ -163,8 +146,8 @@ def load(path:str): path (str): path to the file """ with open(path, 'r', encoding='utf-8') as f: - chatlog = json.loads(f.read()) - return Chat(chatlog) + chat_log = json.loads(f.read()) + return Chat(chat_log['chat_log']) @staticmethod def display_role_content(dic:dict, sep:Union[str, None]=None): diff --git a/chattool/checkpoint.py b/chattool/checkpoint.py index 7ae80be..0e17eee 100644 --- a/chattool/checkpoint.py +++ b/chattool/checkpoint.py @@ -3,13 +3,11 @@ from .chattool import Chat import tqdm -def load_chats( checkpoint:str - , withid:bool=False): +def load_chats( checkpoint:str): """Load chats from a checkpoint file Args: checkpoint (str): path to the checkpoint file - withid (bool, optional): Deprecated. It is not needed anymore. Defaults to False. Returns: list: chats @@ -23,25 +21,20 @@ def load_chats( checkpoint:str txts = f.read().strip().split('\n') ## empty file if len(txts) == 1 and txts[0] == '': return [] - # get the chatlogs logs = [json.loads(txt) for txt in txts] - ## chatlogs with chatid - if 'chatid' in logs[0]: - chat_size, chatlogs = 1, [None] - for log in logs: - idx = log['chatid'] - if idx >= chat_size: # extend chatlogs - chatlogs.extend([None] * (idx - chat_size + 1)) - chat_size = idx + 1 - chatlogs[idx] = log['chatlog'] - # check if there are missing chatlogs - if None in chatlogs: - warnings.warn(f"checkpoint file {checkpoint} has unfinished chats") - else: ## logs without chatid - chatlogs = logs + chat_size, chatlogs = 1, [None] + for log in logs: + idx = log['index'] + if idx >= chat_size: # extend chatlogs + chatlogs.extend([None] * (idx - chat_size + 1)) + chat_size = idx + 1 + chatlogs[idx] = log['chat_log'] + # check if there are missing chatlogs + if None in chatlogs: + warnings.warn(f"checkpoint file {checkpoint} has unfinished chats") # return Chat class - return [Chat(chatlog) if chatlog is not None else None for chatlog in chatlogs] + return [Chat(chat_log) if chat_log is not None else None for chat_log in chatlogs] def process_chats( data:List[Any] , data2chat:Callable[[Any], Chat] @@ -68,13 +61,12 @@ def process_chats( data:List[Any] if len(chats) > len(data): warnings.warn(f"checkpoint file {checkpoint} has more chats than the data to be processed") return chats[:len(data)] - chats.extend([None] * (len(data) - len(chats))) ## process chats tq = tqdm.tqdm if not isjupyter else tqdm.notebook.tqdm for i in tq(range(len(data))): if chats[i] is not None: continue chat = data2chat(data[i]) - chat.save(checkpoint, mode='a') + chat.save(checkpoint, mode='a', index=i) chats[i] = chat return chats \ No newline at end of file diff --git a/setup.py b/setup.py index 3a35c4e..4640092 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ with open('README.md') as readme_file: readme = readme_file.read() -VERSION = '2.6.2' +VERSION = '3.0.0' requirements = [ 'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8', @@ -35,7 +35,7 @@ }, install_requires=requirements, license="MIT license", - long_description=readme, + # long_description=readme, long_description_content_type='text/markdown', include_package_data=True, keywords='chattool', diff --git a/tests/test_async.py b/tests/test_async.py index 8001351..b820bb1 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -66,8 +66,8 @@ def msg2log(msg): chat.system("translate the words from English to Chinese") chat.user(msg) return chat.chat_log - def max_tokens(chatlog): - return Chat(chatlog).prompt_token() + def max_tokens(chat_log): + return Chat(chat_log).prompt_token() async_chat_completion(words, chkpoint, clearfile=True, ncoroutines=3, max_tokens=max_tokens, msg2log=msg2log) def test_normal_process(): diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index 548cf3c..cd81821 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -6,11 +6,11 @@ def test_with_checkpoint(): # save chats without chatid chat = Chat() checkpath = testpath + "tmp.jsonl" - chat.save(checkpath, mode="w") + chat.save(checkpath, mode="w", index=0) chat = Chat("hello!") - chat.save(checkpath) # append + chat.save(checkpath, index=1) # append chat.assistant("你好, how can I assist you today?") - chat.save(checkpath) # append + chat.save(checkpath, index=2) # append ## load chats chats = load_chats(checkpath) chat_logs = [ @@ -23,13 +23,13 @@ def test_with_checkpoint(): # save chats with chatid chat = Chat() checkpath = testpath + "tmp_withid.jsonl" - chat.savewithid(checkpath, mode="w", chatid=0) + chat.save(checkpath, mode="w", index=0) chat = Chat("hello!") - chat.savewithid(checkpath, chatid=3) + chat.save(checkpath, index=3) chat.assistant("你好, how can I assist you today?") - chat.savewithid(checkpath, chatid=2) + chat.save(checkpath, index=2) ## load chats - chats = load_chats(checkpath, withid=True) + chats = load_chats(checkpath) chat_logs = [ [], None, diff --git a/tests/testfiles/.keep b/tests/testfiles/.keep deleted file mode 100644 index e69de29..0000000