diff --git a/pyproject.toml b/pyproject.toml index af25d8ca..9b9397f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"] snowflake = ["snowflake-connector-python"] duckdb = ["duckdb"] google = ["google-generativeai", "google-cloud-aiplatform"] -all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres"] +all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres", "xinference-client"] test = ["tox"] chromadb = ["chromadb"] openai = ["openai"] @@ -56,3 +56,4 @@ azuresearch = ["azure-search-documents", "azure-identity", "azure-common", "fast pgvector = ["langchain-postgres>=0.0.12"] faiss-cpu = ["faiss-cpu"] faiss-gpu = ["faiss-gpu"] +xinference-client = ["xinference-client"] diff --git a/src/vanna/xinference/__init__.py b/src/vanna/xinference/__init__.py new file mode 100644 index 00000000..db693d09 --- /dev/null +++ b/src/vanna/xinference/__init__.py @@ -0,0 +1 @@ +from .xinference import Xinference diff --git a/src/vanna/xinference/xinference.py b/src/vanna/xinference/xinference.py new file mode 100644 index 00000000..13105b57 --- /dev/null +++ b/src/vanna/xinference/xinference.py @@ -0,0 +1,53 @@ +from xinference_client.client.restful.restful_client import ( + Client, + RESTfulChatModelHandle, +) + +from ..base import VannaBase + + +class Xinference(VannaBase): + def __init__(self, config=None): + VannaBase.__init__(self, config=config) + + if not config or "base_url" not in config: + raise ValueError("config must contain at least Xinference base_url") + + base_url = config["base_url"] + api_key = config.get("api_key", "not empty") + self.xinference_client = Client(base_url=base_url, api_key=api_key) + + def system_message(self, message: str) -> any: + return {"role": "system", "content": message} + + def user_message(self, message: str) -> any: + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + return {"role": "assistant", "content": message} + + def submit_prompt(self, prompt, **kwargs) -> str: + if prompt is None: + raise Exception("Prompt is None") + + if len(prompt) == 0: + raise Exception("Prompt is empty") + + num_tokens = 0 + for message in prompt: + num_tokens += len(message["content"]) / 4 + + model_uid = kwargs.get("model_uid") or self.config.get("model_uid", None) + if model_uid is None: + raise ValueError("model_uid is required") + + xinference_model = self.xinference_client.get_model(model_uid) + if isinstance(xinference_model, RESTfulChatModelHandle): + print( + f"Using model_uid {model_uid} for {num_tokens} tokens (approx)" + ) + + response = xinference_model.chat(prompt) + return response["choices"][0]["message"]["content"] + else: + raise NotImplementedError(f"Xinference model handle type {type(xinference_model)} is not supported, required RESTfulChatModelHandle") diff --git a/tests/test_imports.py b/tests/test_imports.py index dc682069..2c0fff5d 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -28,6 +28,7 @@ def test_regular_imports(): from vanna.remote import VannaDefault from vanna.vannadb.vannadb_vector import VannaDB_VectorStore from vanna.weaviate.weaviate_vector import WeaviateDatabase + from vanna.xinference.xinference import Xinference from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat from vanna.ZhipuAI.ZhipuAI_embeddings import ZhipuAI_Embeddings @@ -52,4 +53,5 @@ def test_shortcut_imports(): from vanna.vannadb import VannaDB_VectorStore from vanna.vllm import Vllm from vanna.weaviate import WeaviateDatabase + from vanna.xinference import Xinference from vanna.ZhipuAI import ZhipuAI_Chat, ZhipuAI_Embeddings