diff --git a/src/vanna/qianwen/QianwenAI_chat.py b/src/vanna/qianwen/QianwenAI_chat.py new file mode 100644 index 00000000..6966882d --- /dev/null +++ b/src/vanna/qianwen/QianwenAI_chat.py @@ -0,0 +1,133 @@ +import os + +from openai import OpenAI + +from ..base import VannaBase + + +class QianWenAI_Chat(VannaBase): + def __init__(self, client=None, config=None): + VannaBase.__init__(self, config=config) + + # default parameters - can be overrided using config + self.temperature = 0.7 + + if "temperature" in config: + self.temperature = config["temperature"] + + if "api_type" in config: + raise Exception( + "Passing api_type is now deprecated. Please pass an OpenAI client instead." + ) + + if "api_base" in config: + raise Exception( + "Passing api_base is now deprecated. Please pass an OpenAI client instead." + ) + + if "api_version" in config: + raise Exception( + "Passing api_version is now deprecated. Please pass an OpenAI client instead." + ) + + if client is not None: + self.client = client + return + + if config is None and client is None: + self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + return + + if "api_key" in config: + if "base_url" not in config: + self.client = OpenAI(api_key=config["api_key"], + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1") + else: + self.client = OpenAI(api_key=config["api_key"], + base_url=config["base_url"]) + + def system_message(self, message: str) -> any: + return {"role": "system", "content": message} + + def user_message(self, message: str) -> any: + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + return {"role": "assistant", "content": message} + + def submit_prompt(self, prompt, **kwargs) -> str: + if prompt is None: + raise Exception("Prompt is None") + + if len(prompt) == 0: + raise Exception("Prompt is empty") + + # Count the number of tokens in the message log + # Use 4 as an approximation for the number of characters per token + num_tokens = 0 + for message in prompt: + num_tokens += len(message["content"]) / 4 + + if kwargs.get("model", None) is not None: + model = kwargs.get("model", None) + print( + f"Using model {model} for {num_tokens} tokens (approx)" + ) + response = self.client.chat.completions.create( + model=model, + messages=prompt, + stop=None, + temperature=self.temperature, + ) + elif kwargs.get("engine", None) is not None: + engine = kwargs.get("engine", None) + print( + f"Using model {engine} for {num_tokens} tokens (approx)" + ) + response = self.client.chat.completions.create( + engine=engine, + messages=prompt, + stop=None, + temperature=self.temperature, + ) + elif self.config is not None and "engine" in self.config: + print( + f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)" + ) + response = self.client.chat.completions.create( + engine=self.config["engine"], + messages=prompt, + stop=None, + temperature=self.temperature, + ) + elif self.config is not None and "model" in self.config: + print( + f"Using model {self.config['model']} for {num_tokens} tokens (approx)" + ) + response = self.client.chat.completions.create( + model=self.config["model"], + messages=prompt, + stop=None, + temperature=self.temperature, + ) + else: + if num_tokens > 3500: + model = "qwen-long" + else: + model = "qwen-plus" + + print(f"Using model {model} for {num_tokens} tokens (approx)") + response = self.client.chat.completions.create( + model=model, + messages=prompt, + stop=None, + temperature=self.temperature, + ) + + # Find the first response from the chatbot that has text in it (some responses may not have text) + for choice in response.choices: + if "text" in choice: + return choice.text + + # If no response with text is found, return the first response's content (which may be empty) + return response.choices[0].message.content diff --git a/src/vanna/qianwen/QianwenAI_embeddings.py b/src/vanna/qianwen/QianwenAI_embeddings.py new file mode 100644 index 00000000..17047281 --- /dev/null +++ b/src/vanna/qianwen/QianwenAI_embeddings.py @@ -0,0 +1,46 @@ +from openai import OpenAI + +from ..base import VannaBase + + +class QianWenAI_Embeddings(VannaBase): + def __init__(self, client=None, config=None): + VannaBase.__init__(self, config=config) + + if client is not None: + self.client = client + return + + if self.client is not None: + return + + self.client = OpenAI() + + if config is None: + return + + if "api_type" in config: + self.client.api_type = config["api_type"] + + if "api_base" in config: + self.client.api_base = config["api_base"] + + if "api_version" in config: + self.client.api_version = config["api_version"] + + if "api_key" in config: + self.client.api_key = config["api_key"] + + def generate_embedding(self, data: str, **kwargs) -> list[float]: + if self.config is not None and "engine" in self.config: + embedding = self.client.embeddings.create( + engine=self.config["engine"], + input=data, + ) + else: + embedding = self.client.embeddings.create( + model="bge-large-zh", + input=data, + ) + + return embedding.get("data")[0]["embedding"] diff --git a/src/vanna/qianwen/__init__.py b/src/vanna/qianwen/__init__.py new file mode 100644 index 00000000..acc1bae5 --- /dev/null +++ b/src/vanna/qianwen/__init__.py @@ -0,0 +1,2 @@ +from .QianwenAI_chat import QianWenAI_Chat +from .QianwenAI_embeddings import QianWenAI_Embeddings