Skip to content

Commit

Permalink
Merge pull request #685 from euxx/feat/xinference
Browse files Browse the repository at this point in the history
feat: add Xinference LLM support
  • Loading branch information
zainhoda authored Oct 25, 2024
2 parents ac1a841 + 2bd2717 commit 0ee8185
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
google = ["google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres", "xinference-client"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
Expand All @@ -56,3 +56,4 @@ azuresearch = ["azure-search-documents", "azure-identity", "azure-common", "fast
pgvector = ["langchain-postgres>=0.0.12"]
faiss-cpu = ["faiss-cpu"]
faiss-gpu = ["faiss-gpu"]
xinference-client = ["xinference-client"]
1 change: 1 addition & 0 deletions src/vanna/xinference/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .xinference import Xinference
53 changes: 53 additions & 0 deletions src/vanna/xinference/xinference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from xinference_client.client.restful.restful_client import (
Client,
RESTfulChatModelHandle,
)

from ..base import VannaBase


class Xinference(VannaBase):
    """Chat backend that sends prompts to a model served by Xinference.

    Configuration keys read by this class:
        base_url: Address of the Xinference server (required).
        api_key: Key handed to the client; the placeholder string
            "not empty" is used when the key is omitted.
        model_uid: Model used by ``submit_prompt`` when no ``model_uid``
            keyword argument is supplied.
    """

    def __init__(self, config=None):
        """Create the Xinference REST client from ``config``.

        Raises:
            ValueError: If ``config`` is empty or has no ``base_url`` entry.
        """
        VannaBase.__init__(self, config=config)

        if not config or "base_url" not in config:
            raise ValueError("config must contain at least Xinference base_url")

        self.xinference_client = Client(
            base_url=config["base_url"],
            api_key=config.get("api_key", "not empty"),
        )

    def system_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with the "system" role."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with the "user" role."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with the "assistant" role."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send a list of chat messages to the model and return its reply.

        Args:
            prompt: Sequence of message dicts, each with a "content" entry
                (as produced by the ``*_message`` helpers).
            **kwargs: ``model_uid`` may be given here to override the value
                stored in the configuration.

        Returns:
            The content of the first choice in the chat response.

        Raises:
            Exception: If ``prompt`` is None or empty.
            ValueError: If no ``model_uid`` is available.
            NotImplementedError: If the model handle returned by the client
                is not a ``RESTfulChatModelHandle``.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if not prompt:
            raise Exception("Prompt is empty")

        # Rough estimate used only for the progress message below:
        # about four characters per token.
        approx_tokens = sum(len(message["content"]) for message in prompt) / 4

        # A keyword argument takes precedence over the configured default.
        model_uid = kwargs.get("model_uid") or self.config.get("model_uid", None)
        if model_uid is None:
            raise ValueError("model_uid is required")

        xinference_model = self.xinference_client.get_model(model_uid)
        if not isinstance(xinference_model, RESTfulChatModelHandle):
            raise NotImplementedError(
                f"Xinference model handle type {type(xinference_model)} is not supported, required RESTfulChatModelHandle"
            )

        print(f"Using model_uid {model_uid} for {approx_tokens} tokens (approx)")

        response = xinference_model.chat(prompt)
        return response["choices"][0]["message"]["content"]
2 changes: 2 additions & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_regular_imports():
from vanna.remote import VannaDefault
from vanna.vannadb.vannadb_vector import VannaDB_VectorStore
from vanna.weaviate.weaviate_vector import WeaviateDatabase
from vanna.xinference.xinference import Xinference
from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat
from vanna.ZhipuAI.ZhipuAI_embeddings import ZhipuAI_Embeddings

Expand All @@ -52,4 +53,5 @@ def test_shortcut_imports():
from vanna.vannadb import VannaDB_VectorStore
from vanna.vllm import Vllm
from vanna.weaviate import WeaviateDatabase
from vanna.xinference import Xinference
from vanna.ZhipuAI import ZhipuAI_Chat, ZhipuAI_Embeddings

0 comments on commit 0ee8185

Please sign in to comment.