Add Mistral support #1784

Open · wants to merge 3 commits into master
3 changes: 3 additions & 0 deletions bot/bot_factory.py
@@ -56,5 +56,8 @@ def create_bot(bot_type):
        from bot.zhipuai.zhipuai_bot import ZHIPUAIBot
        return ZHIPUAIBot()

    elif bot_type == const.MISTRAL:
        from bot.mistral.mistralai_bot import MistralAIBot
        return MistralAIBot()

    raise RuntimeError
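For context, a minimal sketch of how this branch would be exercised (const.MISTRAL and create_bot come from this diff; the standalone usage itself is assumed):

# Assumed usage: resolving the new bot type through the factory
from common import const
from bot.bot_factory import create_bot

bot = create_bot(const.MISTRAL)  # imports bot.mistral.mistralai_bot and returns a MistralAIBot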
112 changes: 112 additions & 0 deletions bot/mistral/mistralai_bot.py
@@ -0,0 +1,112 @@
# encoding:utf-8

import time

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

from bot.bot import Bot
from bot.mistral.mistralai_session import MistralAISession
from bot.session_manager import SessionManager
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from common.log import logger
from config import conf

user_session = dict()


# Mistral AI chat model API
class MistralAIBot(Bot):
    def __init__(self):
        super().__init__()
        api_key = conf().get("mistralai_api_key")
        self.client = MistralClient(api_key=api_key)
        self.system_prompt = conf().get("character_desc", "")
        self.sessions = SessionManager(MistralAISession, model=conf().get("model") or "mistral-large-latest")
        self.model = conf().get("model") or "mistral-large-latest"  # name of the chat model
        self.temperature = conf().get("temperature", 0.7)  # in [0, 1]; higher values make replies more random
        self.top_p = conf().get("top_p", 1)
        self.safe_prompt = True
        logger.info("[MISTRAL_AI] bot initialized.")

    def reply(self, query, context=None):
        # acquire reply content
        if context and context.type:
            if context.type == ContextType.TEXT:
                logger.info("[MISTRAL_AI] query={}".format(query))
                session_id = context["session_id"]
                reply = None
                if query == "#清除记忆":  # "clear memory" command
                    self.sessions.clear_session(session_id)
                    reply = Reply(ReplyType.INFO, "记忆已清除")  # "memory cleared"
                elif query == "#清除所有":  # "clear all" command
                    self.sessions.clear_all_session()
                    reply = Reply(ReplyType.INFO, "所有人记忆已清除")  # "all users' memory cleared"
                else:
                    session = self.sessions.session_query(query, session_id)
                    result = self.reply_text(session)
                    total_tokens, completion_tokens, reply_content = (
                        result["total_tokens"],
                        result["completion_tokens"],
                        result["content"],
                    )
                    logger.debug(
                        "[MISTRAL_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format(str(session), session_id, reply_content, completion_tokens)
                    )

                    if total_tokens == 0:
                        reply = Reply(ReplyType.ERROR, reply_content)
                    else:
                        self.sessions.session_reply(reply_content, session_id, total_tokens)
                        reply = Reply(ReplyType.TEXT, reply_content)
                return reply
            else:
                logger.info("[MISTRAL_AI] context={}".format(context))

    def reply_text(self, session: MistralAISession):
        try:
            messages = self._convert_to_mistral_messages(self._filter_messages(session.messages))
            response = self.client.chat(messages, temperature=self.temperature, model=self.model,
                                        top_p=self.top_p, safe_prompt=self.safe_prompt)
            res_content = response.choices[0].message.content
            total_tokens = response.usage.total_tokens
            completion_tokens = response.usage.completion_tokens
            logger.info("[MISTRAL_AI] reply={}".format(res_content))
            return {
                "total_tokens": total_tokens,
                "completion_tokens": completion_tokens,
                "content": res_content,
            }
        except Exception as e:
            result = {"total_tokens": 0, "completion_tokens": 0, "content": "我刚刚开小差了,请稍后再试一下"}  # fallback reply: "sorry, I zoned out; please try again later"
            logger.warn("[MISTRAL_AI] Exception: {}".format(e))
            return result

    def _convert_to_mistral_messages(self, messages: list):
        res = []
        res.append(ChatMessage(role="system", content=self.system_prompt))
        for msg in messages:
            if msg.get("role") == "user":
                role = "user"
            elif msg.get("role") == "assistant":
                role = "assistant"  # Mistral's chat API expects "assistant", not "model"
            else:
                continue
            res.append(
                ChatMessage(role=role, content=msg.get("content")))
        return res

    def _filter_messages(self, messages: list):
        res = []
        turn = "user"
        for i in range(len(messages) - 1, -1, -1):
            message = messages[i]
            if message.get("role") != turn:
                continue
            res.insert(0, message)
            if turn == "user":
                turn = "assistant"
            elif turn == "assistant":
                turn = "user"
        return res
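A note on _filter_messages: it walks the history backwards and keeps only strictly alternating user/assistant turns, so the request ends on the latest user message with cleanly alternating turns before it. A standalone replica on hypothetical data (illustration only, not part of the PR):

# Standalone replica of the alternating-turn filter above, for illustration
def filter_alternating(messages, turn="user"):
    res = []
    for message in reversed(messages):
        if message.get("role") != turn:
            continue
        res.insert(0, message)
        turn = "assistant" if turn == "user" else "user"
    return res


history = [
    {"role": "user", "content": "hi"},  # dropped: breaks the alternation
    {"role": "user", "content": "are you there?"},
    {"role": "assistant", "content": "yes"},
    {"role": "user", "content": "what is Mistral?"},
]
print(filter_alternating(history))  # keeps only the last three messages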
76 changes: 76 additions & 0 deletions bot/mistral/mistralai_session.py
@@ -0,0 +1,76 @@
from bot.session_manager import Session
from common.log import logger


class MistralAISession(Session):
    def __init__(self, session_id, system_prompt=None, model="mistral-large-latest"):
        super().__init__(session_id, system_prompt)
        self.model = model
        self.reset()

    def __str__(self):
        # build the text prompt fed to completion-style models
        """
        e.g. Q: xxx
        A: xxx
        Q: xxx
        """
        prompt = ""
        for item in self.messages:
            if item["role"] == "system":
                prompt += item["content"] + "<|endoftext|>\n\n\n"
            elif item["role"] == "user":
                prompt += "Q: " + item["content"] + "\n"
            elif item["role"] == "assistant":
                prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"

        if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
            prompt += "A: "
        return prompt

    def discard_exceeding(self, max_tokens, cur_tokens=None):
        precise = True
        try:
            cur_tokens = self.calc_tokens()
        except Exception as e:
            precise = False
            if cur_tokens is None:
                raise e
            logger.debug("Exception when counting tokens precisely for query: {}".format(e))
        while cur_tokens > max_tokens:
            if len(self.messages) > 1:
                self.messages.pop(0)
            elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
                self.messages.pop(0)
                if precise:
                    cur_tokens = self.calc_tokens()
                else:
                    cur_tokens = len(str(self))
                break
            elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
                logger.warn("user question exceeds max_tokens. total_tokens={}".format(cur_tokens))
                break
            else:
                logger.debug("max_tokens={}, total_tokens={}, len(conversation)={}".format(max_tokens, cur_tokens, len(self.messages)))
                break
            if precise:
                cur_tokens = self.calc_tokens()
            else:
                cur_tokens = len(str(self))
        return cur_tokens

    def calc_tokens(self):
        return num_tokens_from_string(str(self), self.model)


# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_string(string: str, model: str) -> int:
    """Returns the number of tokens in a text string."""
    import tiktoken
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.warn("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    num_tokens = len(encoding.encode(string, disallowed_special=()))
    return num_tokens
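One caveat: tiktoken ships no Mistral vocabularies, so encoding_for_model raises KeyError for every model name added in this PR and the function always falls back to cl100k_base; the counts are therefore approximations, used only for trimming history. A quick sketch:

# cl100k_base is the encoding actually reached for Mistral model names
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")
text = "Q: hello\n\nA: hi<|endoftext|>\n"
# disallowed_special=() lets the literal "<|endoftext|>" marker encode as plain text
print(len(encoding.encode(text, disallowed_special=())))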
3 changes: 3 additions & 0 deletions bridge/bridge.py
@@ -33,6 +33,9 @@ def __init__(self):
            self.btype["chat"] = const.GEMINI
        if model_type in [const.ZHIPU_AI]:
            self.btype["chat"] = const.ZHIPU_AI
        if model_type in [const.MODEL_MISTRAL_LARGE, const.MODEL_MISTRAL_MEDIUM, const.MODEL_MISTRAL_SMALL,
                          const.MODEL_MISTRAL_OPEN_7B, const.MODEL_MISTRAL_OPEN_8X7B]:
            self.btype["chat"] = const.MISTRAL

        if conf().get("use_linkai") and conf().get("linkai_api_key"):
            self.btype["chat"] = const.LINKAI
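In effect, any of the five Mistral model names in the "model" setting now routes chat traffic to the Mistral bot; a hedged sketch of the equivalent membership check (constants from common/const.py in this diff):

from common import const

MISTRAL_MODELS = {const.MODEL_MISTRAL_LARGE, const.MODEL_MISTRAL_MEDIUM, const.MODEL_MISTRAL_SMALL,
                  const.MODEL_MISTRAL_OPEN_7B, const.MODEL_MISTRAL_OPEN_8X7B}
print("open-mistral-7b" in MISTRAL_MODELS)  # True, so btype["chat"] becomes const.MISTRAL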
11 changes: 8 additions & 3 deletions common/const.py
@@ -9,7 +9,7 @@
QWEN = "qwen"
GEMINI = "gemini"
ZHIPU_AI = "glm-4"

MISTRAL = "mistral"

# model
GPT35 = "gpt-3.5-turbo"
@@ -19,10 +19,15 @@
WHISPER_1 = "whisper-1"
TTS_1 = "tts-1"
TTS_1_HD = "tts-1-hd"
MODEL_MISTRAL_OPEN_7B = "open-mistral-7b"
MODEL_MISTRAL_OPEN_8X7B = "open-mixtral-8x7b"
MODEL_MISTRAL_SMALL = "mistral-small-latest"
MODEL_MISTRAL_MEDIUM = "mistral-medium-latest"
MODEL_MISTRAL_LARGE = "mistral-large-latest"

MODEL_LIST = ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "wenxin", "wenxin-4", "xunfei", "claude", "gpt-4-turbo",
"gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI]
"gpt-4-turbo-preview", "gpt-4-1106-preview", GPT4_TURBO_PREVIEW, QWEN, GEMINI, ZHIPU_AI, MISTRAL]

# channel
FEISHU = "feishu"
DINGTALK = "dingtalk"
2 changes: 2 additions & 0 deletions config.py
@@ -75,6 +75,8 @@
    "qwen_node_id": "",  # id used by the flow-orchestration model; keep it as an empty string if qwen_node_id is unused
    # Google Gemini Api Key
    "gemini_api_key": "",
    # Mistral AI API Key
    "mistralai_api_key": "",
    # general wework (WeCom) settings
    "wework_smart": True,  # whether wework uses the already-logged-in WeCom client; False allows multiple instances
    # voice settings
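A minimal sketch of the settings a user would need to enable the new bot (key names from this diff; the API key value is a placeholder):

# Hypothetical minimal configuration for the Mistral bot
config = {
    "model": "mistral-large-latest",  # any of the five MODEL_MISTRAL_* names works
    "mistralai_api_key": "YOUR_MISTRAL_API_KEY",  # placeholder
}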
3 changes: 3 additions & 0 deletions requirements-optional.txt
@@ -40,3 +40,6 @@ dingtalk_stream

# zhipuai
zhipuai>=2.0.1

# mistralai
mistralai
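Since the package lands in requirements-optional.txt rather than the core requirements, users presumably install it on demand:

pip install mistralai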