Commit
Merge pull request #590 from dusens/main
<feat><QianwenAI>add QianwenAI call function
zainhoda authored Aug 21, 2024
2 parents f854616 + 2962cb5 commit b32ad49
Showing 3 changed files with 181 additions and 0 deletions.
133 changes: 133 additions & 0 deletions src/vanna/qianwen/QianwenAI_chat.py
@@ -0,0 +1,133 @@
import os

from openai import OpenAI

from ..base import VannaBase


class QianWenAI_Chat(VannaBase):
    def __init__(self, client=None, config=None):
        VannaBase.__init__(self, config=config)

        # default parameters - can be overridden using config
        self.temperature = 0.7

        if "temperature" in config:
            self.temperature = config["temperature"]

        if "api_type" in config:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )

        if "api_base" in config:
            raise Exception(
                "Passing api_base is now deprecated. Please pass an OpenAI client instead."
            )

        if "api_version" in config:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        if client is not None:
            self.client = client
            return

        if config is None and client is None:
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            return

        if "api_key" in config:
            if "base_url" not in config:
                self.client = OpenAI(
                    api_key=config["api_key"],
                    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
                )
            else:
                self.client = OpenAI(
                    api_key=config["api_key"],
                    base_url=config["base_url"],
                )

    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Count the number of tokens in the message log
        # Use 4 as an approximation for the number of characters per token
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        if kwargs.get("model", None) is not None:
            model = kwargs.get("model", None)
            print(f"Using model {model} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )
        elif kwargs.get("engine", None) is not None:
            engine = kwargs.get("engine", None)
            print(f"Using model {engine} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                engine=engine,
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "engine" in self.config:
            print(
                f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                engine=self.config["engine"],
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "model" in self.config:
            print(
                f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
            )
            response = self.client.chat.completions.create(
                model=self.config["model"],
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )
        else:
            if num_tokens > 3500:
                model = "qwen-long"
            else:
                model = "qwen-plus"

            print(f"Using model {model} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                stop=None,
                temperature=self.temperature,
            )

        # Find the first response from the chatbot that has text in it (some responses may not have text)
        for choice in response.choices:
            if "text" in choice:
                return choice.text

        # If no response with text is found, return the first response's content (which may be empty)
        return response.choices[0].message.content
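
For orientation, here is a minimal usage sketch of the chat class added above. It is not part of this commit: the API key placeholder, model choice, and example messages are illustrative assumptions; the config keys simply follow the branches handled in __init__ and submit_prompt.

# Illustrative usage (not part of this commit). The API key and model name are
# placeholders; with "api_key" set and no "base_url", __init__ points the client
# at DashScope's OpenAI-compatible endpoint, and submit_prompt uses config["model"].
from vanna.qianwen import QianWenAI_Chat

vn = QianWenAI_Chat(config={
    "api_key": "sk-...",    # assumed DashScope API key placeholder
    "model": "qwen-plus",   # handled by the config["model"] branch of submit_prompt
})

prompt = [
    vn.system_message("You are a helpful SQL assistant."),
    vn.user_message("Write a query that counts customers per region."),
]
print(vn.submit_prompt(prompt))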
46 changes: 46 additions & 0 deletions src/vanna/qianwen/QianwenAI_embeddings.py
@@ -0,0 +1,46 @@
from openai import OpenAI

from ..base import VannaBase


class QianWenAI_Embeddings(VannaBase):
    def __init__(self, client=None, config=None):
        VannaBase.__init__(self, config=config)

        if client is not None:
            self.client = client
            return

        if self.client is not None:
            return

        self.client = OpenAI()

        if config is None:
            return

        if "api_type" in config:
            self.client.api_type = config["api_type"]

        if "api_base" in config:
            self.client.api_base = config["api_base"]

        if "api_version" in config:
            self.client.api_version = config["api_version"]

        if "api_key" in config:
            self.client.api_key = config["api_key"]

    def generate_embedding(self, data: str, **kwargs) -> list[float]:
        if self.config is not None and "engine" in self.config:
            embedding = self.client.embeddings.create(
                engine=self.config["engine"],
                input=data,
            )
        else:
            embedding = self.client.embeddings.create(
                model="bge-large-zh",
                input=data,
            )

        return embedding.get("data")[0]["embedding"]
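
A sketch of how the embeddings class might be used, again for orientation only. Nothing below is part of the commit: passing a pre-configured OpenAI-compatible client is an assumption (it exercises the "client is not None" branch above), and the call then falls back to the default bge-large-zh branch of generate_embedding, which assumes the configured endpoint serves that model.

# Illustrative usage (not part of this commit). The client is an assumed
# DashScope-compatible OpenAI client; with no "engine" in config,
# generate_embedding uses its default "bge-large-zh" branch.
from openai import OpenAI
from vanna.qianwen import QianWenAI_Embeddings

client = OpenAI(
    api_key="sk-...",  # assumed API key placeholder
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)

emb = QianWenAI_Embeddings(client=client, config={})
vector = emb.generate_embedding("How many customers are there?")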
2 changes: 2 additions & 0 deletions src/vanna/qianwen/__init__.py
@@ -0,0 +1,2 @@
from .QianwenAI_chat import QianWenAI_Chat
from .QianwenAI_embeddings import QianWenAI_Embeddings
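
These package exports make the new classes usable in the usual Vanna composition pattern. The sketch below is not part of this commit; it assumes ChromaDB_VectorStore is available from the installed vanna package and uses placeholder config values.

# Illustrative composition (not part of this commit): Vanna LLM classes are
# typically combined with a vector store via multiple inheritance.
from vanna.chromadb import ChromaDB_VectorStore  # assumed available in the installed package
from vanna.qianwen import QianWenAI_Chat

class MyVanna(ChromaDB_VectorStore, QianWenAI_Chat):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        QianWenAI_Chat.__init__(self, config=config)

vn = MyVanna(config={"api_key": "sk-...", "model": "qwen-plus"})  # placeholder values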
