Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrated Milvus with MetaGPT #1457

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions metagpt/document_store/milvus_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from dataclasses import dataclass
from typing import List, Dict, Any, Optional
from pymilvus import MilvusClient, DataType

from metagpt.document_store.base_store import BaseStore

@dataclass
class MilvusConnection:
"""
Args:
uri: milvus url
token: milvus token
"""

uri: str = None
token: str = None


class MilvusStore(BaseStore):
def __init__(self, connect: MilvusConnection):
if not connect.uri:
raise Exception("please check MilvusConnection, uri must be set.")
self.client = MilvusClient(
uri=connect.uri,
token=connect.token
)

def create_collection(
self,
collection_name: str,
dim: int,
enable_dynamic_schema: bool = True
):
if self.client.has_collection(collection_name=collection_name):
self.client.drop_collection(collection_name=collection_name)

schema = self.client.create_schema(
auto_id=False,
enable_dynamic_field=False,
)
schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=36)
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim)

index_params = self.client.prepare_index_params()
index_params.add_index(
field_name="vector",
index_type="AUTOINDEX",
metric_type="COSINE"
)

self.client.create_collection(
collection_name=collection_name,
schema=schema,
index_params=index_params,
enable_dynamic_schema=enable_dynamic_schema
)

@staticmethod
def build_filter(key, value) -> str:
if isinstance(value, str):
filter_expression = f'{key} == "{value}"'
else:
if isinstance(value, list):
filter_expression = f'{key} in {value}'
else:
filter_expression = f'{key} == {value}'

return filter_expression

def search(
self,
collection_name: str,
query: List[float],
filter: Dict[str, str | int | list[int]] = None,
limit: int = 10,
output_fields: Optional[List[str]] = None,
) -> List[dict]:
filter_expression = ''

for key, value in filter.items():
filter_expression += f'{self.build_filter(key, value)} and '
print(filter_expression)

res = self.client.search(
collection_name=collection_name,
data=[query],
filter=filter_expression,
limit=limit,
output_fields=output_fields,
)[0]

return res

def add(
self,
collection_name: str,
_ids: List[str],
vector: List[List[float]],
metadata: List[Dict[str, Any]]
):
data = dict()

for i, id in enumerate(_ids):
data['id'] = id
data['vector'] = vector[i]
data['metadata'] = metadata[i]

self.client.upsert(
collection_name=collection_name,
data=data
)

def delete(
self,
collection_name: str,
_ids: List[str]
):
self.client.delete(
collection_name=collection_name,
ids=_ids
)

def write(self, *args, **kwargs):
pass
8 changes: 8 additions & 0 deletions metagpt/rag/factories/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.vector_stores.milvus import MilvusVectorStore

from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import (
Expand All @@ -17,6 +18,7 @@
ElasticsearchIndexConfig,
ElasticsearchKeywordIndexConfig,
FAISSIndexConfig,
MilvusIndexConfig,
)


Expand All @@ -28,6 +30,7 @@ def __init__(self):
BM25IndexConfig: self._create_bm25,
ElasticsearchIndexConfig: self._create_es,
ElasticsearchKeywordIndexConfig: self._create_es,
MilvusIndexConfig: self._create_milvus
}
super().__init__(creators)

Expand All @@ -46,6 +49,11 @@ def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:

return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)

def _create_milvus(self, config: MilvusIndexConfig, **kwargs) -> VectorStoreIndex:
vector_store = MilvusVectorStore(collection_name=config.collection_name, uri=config.uri, token=config.token)

return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)

def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
db = chromadb.PersistentClient(str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
Expand Down
15 changes: 15 additions & 0 deletions metagpt/rag/factories/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.vector_stores.milvus import MilvusVectorStore

from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.retrievers.base import RAGRetriever
Expand All @@ -20,13 +21,15 @@
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.retrievers.milvus_retriever import MilvusRetriever
from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
ChromaRetrieverConfig,
ElasticsearchKeywordRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
MilvusRetrieverConfig,
)


Expand Down Expand Up @@ -56,6 +59,7 @@ def __init__(self):
ChromaRetrieverConfig: self._create_chroma_retriever,
ElasticsearchRetrieverConfig: self._create_es_retriever,
ElasticsearchKeywordRetrieverConfig: self._create_es_retriever,
MilvusRetrieverConfig: self._create_milvus_retriever,
}
super().__init__(creators)

Expand All @@ -76,6 +80,11 @@ def _create_default(self, **kwargs) -> RAGRetriever:

return index.as_retriever()

def _create_milvus_retriever(self, config: MilvusRetrieverConfig, **kwargs) -> MilvusRetriever:
config.index = self._build_milvus_index(config, **kwargs)

return MilvusRetriever(**config.model_dump())

def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
config.index = self._build_faiss_index(config, **kwargs)

Expand Down Expand Up @@ -128,6 +137,12 @@ def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> Vector

return self._build_index_from_vector_store(config, vector_store, **kwargs)

@get_or_build_index
def _build_milvus_index(self, config: MilvusRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = MilvusVectorStore(uri=config.uri, collection_name=config.collection_name, token=config.token, dim=config.dimensions)

return self._build_index_from_vector_store(config, vector_store, **kwargs)

@get_or_build_index
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = ElasticsearchStore(**config.store_config.model_dump())
Expand Down
17 changes: 17 additions & 0 deletions metagpt/rag/retrievers/milvus_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Milvus retriever."""

from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode


class MilvusRetriever(VectorIndexRetriever):
"""Milvus retriever."""

def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Support add nodes."""
self._index.insert_nodes(nodes, **kwargs)

def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist.

Milvus automatically saves, so there is no need to implement."""
42 changes: 41 additions & 1 deletion metagpt/rag/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator, validator

from metagpt.config2 import config
from metagpt.configs.embedding_config import EmbeddingType
Expand Down Expand Up @@ -62,6 +62,36 @@ class BM25RetrieverConfig(IndexRetrieverConfig):
_no_embedding: bool = PrivateAttr(default=True)


class MilvusRetrieverConfig(IndexRetrieverConfig):
"""Config for Milvus-based retrievers."""

uri: str = Field(default="./milvus_local.db", description="The directory to save data.")
collection_name: str = Field(default="metagpt", description="The name of the collection.")
token: str = Field(default=None, description="The token for Milvus")
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
dimensions: int = Field(default=0, description="Dimensionality of the vectors for Milvus index construction.")

_embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
EmbeddingType.GEMINI: 768,
EmbeddingType.OLLAMA: 4096,
}

@model_validator(mode="after")
def check_dimensions(self):
if self.dimensions == 0:
self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get(
config.embedding.api_type, 1536
)
if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions:
logger.warning(
f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536"
)

return self


class ChromaRetrieverConfig(IndexRetrieverConfig):
"""Config for Chroma-based retrievers."""

Expand Down Expand Up @@ -169,6 +199,16 @@ class ChromaIndexConfig(VectorIndexConfig):
default=None, description="Optional metadata to associate with the collection"
)

class MilvusIndexConfig(VectorIndexConfig):
"""Config for milvus-based index."""

collection_name: str = Field(default="metagpt", description="The name of the collection.")
uri: str = Field(default="./milvus_local.db", description="The uri of the index.")
token: Optional[str] = Field(default=None, description="The token of the index.")
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)


class BM25IndexConfig(BaseIndexConfig):
"""Config for bm25-based index."""
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,4 @@ gymnasium==0.29.1
boto3~=1.34.69
spark_ai_python~=0.3.30
agentops
pymilvus==2.4.5
48 changes: 48 additions & 0 deletions tests/metagpt/document_store/test_milvus_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import random
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore

seed_value = 42
random.seed(seed_value)

vectors = [[random.random() for _ in range(8)] for _ in range(10)]
ids = [f"doc_{i}" for i in range(10)]
metadata = [{"color": "red", "rand_number": i % 10} for i in range(10)]


def assert_almost_equal(actual, expected):
delta = 1e-10
if isinstance(expected, list):
assert len(actual) == len(expected)
for ac, exp in zip(actual, expected):
assert abs(ac - exp) <= delta, f"{ac} is not within {delta} of {exp}"
else:
assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}"


def test_milvus_store():
milvus_connection = MilvusConnection(uri="./milvus_local.db")
milvus_store = MilvusStore(milvus_connection)

collection_name = "TestCollection"
milvus_store.create_collection(collection_name, dim=8)

milvus_store.add(collection_name, ids, vectors, metadata)

search_results = milvus_store.search(collection_name, query=[1.0] * 8)
assert len(search_results) > 0
first_result = search_results[0]
assert first_result["id"] == "doc_0"

search_results_with_filter = milvus_store.search(
collection_name,
query=[1.0] * 8,
filter={"rand_number": 1}
)
assert len(search_results_with_filter) > 0
assert search_results_with_filter[0]["id"] == "doc_1"

milvus_store.delete(collection_name, _ids=["doc_0"])
deleted_results = milvus_store.search(collection_name, query=[1.0] * 8, limit=1)
assert deleted_results[0]["id"] != "doc_0"

milvus_store.client.drop_collection(collection_name)
16 changes: 15 additions & 1 deletion tests/metagpt/rag/factories/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ChromaIndexConfig,
ElasticsearchIndexConfig,
ElasticsearchStoreConfig,
FAISSIndexConfig,
FAISSIndexConfig, MilvusIndexConfig,
)


Expand All @@ -20,6 +20,10 @@ def setup(self):
def faiss_config(self):
return FAISSIndexConfig(persist_path="")

@pytest.fixture
def milvus_config(self):
return MilvusIndexConfig(uri="", collection_name="")

@pytest.fixture
def chroma_config(self):
return ChromaIndexConfig(persist_path="", collection_name="")
Expand Down Expand Up @@ -65,6 +69,16 @@ def test_create_bm25_index(
):
self.index_factory.get_index(bm25_config, embed_model=mock_embedding)

def test_create_milvus_index(self, mocker, milvus_config, mock_from_vector_store, mock_embedding):
# Mock
mock_milvus_store = mocker.patch("metagpt.rag.factories.index.MilvusVectorStore")

# Exec
self.index_factory.get_index(milvus_config, embed_model=mock_embedding)

# Assert
mock_milvus_store.assert_called_once()

def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding):
# Mock
mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient")
Expand Down
Loading
Loading