class NewVectorStoreDTO(BaseDTO[VectorStore]):
    """
    A DTO carrying the result of creating a new vector store

    @param vector_store_id: The ID of the newly created vector store; defaults to None when not supplied
    @type vector_store_id: int | None
    """

    vector_store_id: int | None = None
@abstractmethod
def new_vector_store(
    self,
    research_context_id: int,
    vector_store_lfn: str,
    vector_store_name: str,
    vector_store_protocol: ProtocolEnum,
    embedding_model_id: int,
) -> NewVectorStoreDTO:
    """
    Creates a new vector store in the research context.

    @param research_context_id: The ID of the research context that is tied to the vector store.
    @type research_context_id: int
    @param vector_store_lfn: The LFN of the vector store.
    @type vector_store_lfn: str
    @param vector_store_name: The name of the vector store.
    @type vector_store_name: str
    @param vector_store_protocol: The protocol of the vector store.
    @type vector_store_protocol: ProtocolEnum
    @param embedding_model_id: The ID of the embedding model to use for the vector store.
    @type embedding_model_id: int
    @return: A DTO containing the result of the operation.
    @rtype: NewVectorStoreDTO
    """
    raise NotImplementedError
def new_vector_store(
    self,
    research_context_id: int,
    vector_store_lfn: str,
    vector_store_name: str,
    vector_store_protocol: ProtocolEnum,
    embedding_model_id: int,
) -> NewVectorStoreDTO:
    """
    Creates a new vector store in the research context.

    Failures are reported through the returned DTO (status=False, errorCode=-1): a missing
    argument, an unknown research context or embedding model, or an error while saving.

    @param research_context_id: The ID of the research context that is tied to the vector store.
    @type research_context_id: int
    @param vector_store_lfn: The LFN of the vector store.
    @type vector_store_lfn: str
    @param vector_store_name: The name of the vector store.
    @type vector_store_name: str
    @param vector_store_protocol: The protocol of the vector store.
    @type vector_store_protocol: ProtocolEnum
    @param embedding_model_id: The ID of the embedding model to use for the vector store.
    @type embedding_model_id: int
    @return: A DTO containing the result of the operation; on success it carries the new vector store ID.
    @rtype: NewVectorStoreDTO
    """

    def fail(message: str, name: str, error_type: str) -> NewVectorStoreDTO:
        # Single place that builds, logs and returns the failure DTO used by every error path.
        error_dto = NewVectorStoreDTO(
            status=False,
            errorCode=-1,
            errorMessage=message,
            errorName=name,
            errorType=error_type,
        )
        self.logger.error(message)
        self.logger.error(f"{error_dto}")
        return error_dto

    # Checked in this order; the first missing argument decides which error is returned.
    required_arguments = (
        (
            research_context_id,
            "Research Context ID cannot be None",
            "Research Context ID not provided",
            "ResearchContextIdNotProvided",
        ),
        (
            vector_store_lfn,
            "Vector Store LFN cannot be None",
            "Vector Store LFN not provided",
            "VectorStoreLFNNotProvided",
        ),
        (
            vector_store_name,
            "Vector Store name cannot be None",
            "Vector Store name not provided",
            "VectorStoreNameNotProvided",
        ),
        (
            vector_store_protocol,
            "Vector Store protocol cannot be None",
            "Vector Store protocol not provided",
            "VectorStoreProtocolNotProvided",
        ),
        (
            embedding_model_id,
            "Embedding Model ID cannot be None",
            "Embedding Model ID not provided",
            "EmbeddingModelIdNotProvided",
        ),
    )
    for value, message, name, error_type in required_arguments:
        if value is None:
            return fail(message, name, error_type)

    # Both referenced rows must exist before the vector store is created.
    queried_sqla_research_context: SQLAResearchContext | None = self.session.get(
        SQLAResearchContext, research_context_id
    )
    if queried_sqla_research_context is None:
        return fail(
            f"Research Context with ID {research_context_id} not found in the database.",
            "Research Context not found",
            "ResearchContextNotFound",
        )

    queried_sqla_embedding_model: SQLAEmbeddingModel | None = self.session.get(
        SQLAEmbeddingModel, embedding_model_id
    )
    if queried_sqla_embedding_model is None:
        return fail(
            f"Embedding Model with ID {embedding_model_id} not found in the database.",
            "Embedding Model not found",
            "EmbeddingModelNotFound",
        )

    sqla_new_vector_store: SQLAVectorStore = SQLAVectorStore(
        name=vector_store_name,
        lfn=vector_store_lfn,
        protocol=vector_store_protocol,
        embedding_model_id=embedding_model_id,
        research_context_id=research_context_id,
    )

    try:
        sqla_new_vector_store.save(session=self.session)
        self.session.commit()

        return NewVectorStoreDTO(status=True, vector_store_id=sqla_new_vector_store.id)

    except Exception as e:
        # Without a rollback the session stays in its failed transaction and every
        # later operation on this repository's session would fail as well.
        self.session.rollback()
        return fail(
            f"Error while creating new vector store: {e}",
            "Error while creating new vector store",
            "ErrorWhileCreatingNewVectorStore",
        )
def create_lfn() -> str:
    """
    Builds a random logical file name (LFN) for test data.

    The result has the shape
    "<protocol>://<host>:<port>/<bucket>/<tracer_id>/<knowledge_source>/<job_id><relative_path>".

    @return: The generated LFN.
    @rtype: str
    """
    # NOTE(review): a new Faker instance is created per call, so `.unique` presumably only
    # guarantees uniqueness within this single call - confirm against the Faker documentation.
    fake = Faker().unique

    # Lower-cased enum member names. __members__ holds only real members, so methods or other
    # public attributes defined on the enum classes cannot end up among the choices.
    protocols = [member_name.lower() for member_name in ProtocolEnum.__members__]
    knowledge_sources = [member_name.lower() for member_name in KnowledgeSourceEnum.__members__]

    sd_protocol_choice: str = random.choice(protocols)
    # Fail fast (ValueError) if the lower-cased member name is not a valid ProtocolEnum value.
    ProtocolEnum(sd_protocol_choice)
    sd_host = fake.domain_name()
    sd_port = fake.port_number()
    sd_minio_bucket = fake.name()
    sd_tracer_id = random.randint(1, 1000000000)
    sd_knowledge_source_choice: str = random.choice(knowledge_sources)
    sd_job_id = random.randint(1, 1000000000)

    sd_filename = fake.file_name()
    # Directory part of a fake path (extension cut off) followed by the generated file name.
    sd_relative_path = fake.file_path(depth=3).split(".")[0] + "/" + sd_filename

    sd_lfn = f"{sd_protocol_choice}://{sd_host}:{sd_port}/{sd_minio_bucket}/{sd_tracer_id}/{sd_knowledge_source_choice}/{sd_job_id}{sd_relative_path}"

    return sd_lfn


@pytest.fixture(scope="function")
def fake_lfn_list() -> List[str]:
    """Provides a list of ten freshly generated LFNs for each test."""
    return [create_lfn() for _ in range(10)]
def embedding_model() -> SQLAEmbeddingModel:
    """Builds an unsaved SQLAEmbeddingModel with a fake name."""
    fake = Faker().unique

    fake_name = fake.name()

    return SQLAEmbeddingModel(
        name=fake_name,
    )


@pytest.fixture(scope="function")
def fake_embedding_model() -> SQLAEmbeddingModel:
    """Provides a fresh, unsaved embedding model for each test."""
    return embedding_model()


def vector_store() -> SQLAVectorStore:
    """
    Builds an unsaved SQLAVectorStore from a generated LFN.

    The name is the last path segment of the LFN and the protocol is the LFN's scheme.
    """
    lfn = create_lfn()
    name = lfn.split("/")[-1]
    protocol_str = lfn.split("://")[0]
    protocol = ProtocolEnum(protocol_str)

    return SQLAVectorStore(
        name=name,
        lfn=lfn,
        protocol=protocol,
    )


def embedding_model_with_vector_store() -> SQLAEmbeddingModel:
    """Builds an unsaved embedding model that owns ten generated vector stores."""
    em = embedding_model()

    vector_store_list = [vector_store() for _ in range(10)]

    em.vector_stores = vector_store_list

    return em


@pytest.fixture(scope="function")
def fake_embedding_model_with_vector_store() -> SQLAEmbeddingModel:
    """Provides a fresh, unsaved embedding model with ten vector stores for each test."""
    return embedding_model_with_vector_store()
def _assert_error_dto(dto: NewVectorStoreDTO, *, message: str, name: str, error_type: str) -> None:
    """Asserts that *dto* reports a failure with error code -1 and the given error fields."""
    assert not dto.status
    assert dto.errorCode == -1
    assert dto.errorMessage == message
    assert dto.errorName == name
    assert dto.errorType == error_type


def test_create_new_vector_store_for_research_context(
    app_initialization_container: ApplicationContainer,
    db_session: TDatabaseFactory,
    fake_user_with_conversation: SQLAUser,
    fake_llm: SQLALLM,
    fake_embedding_model_with_vector_store: SQLAEmbeddingModel,
    fake_lfn_list: List[str],
) -> None:
    """A vector store created for existing rows is persisted with exactly the given values."""
    research_context_repository = app_initialization_container.sqla_research_context_repository()

    user_with_conv = fake_user_with_conversation
    llm = fake_llm
    llm.research_contexts = user_with_conv.research_contexts

    embedding_model = fake_embedding_model_with_vector_store

    new_vector_store_lfn = random.choice(fake_lfn_list)
    new_vector_store_name = new_vector_store_lfn.split("/")[-1]
    new_vector_store_protocol_str = new_vector_store_lfn.split("://")[0]
    new_vector_store_protocol = ProtocolEnum(new_vector_store_protocol_str)

    with db_session() as session:
        session.add(user_with_conv)
        session.add(embedding_model)
        session.commit()
        # IDs are read after the commit, once the database has assigned them.
        research_context = random.choice(user_with_conv.research_contexts)
        research_context_id = research_context.id
        embedding_model_id = embedding_model.id

    with db_session():
        new_vs_dto: NewVectorStoreDTO = research_context_repository.new_vector_store(
            research_context_id=research_context_id,
            embedding_model_id=embedding_model_id,
            vector_store_name=new_vector_store_name,
            vector_store_lfn=new_vector_store_lfn,
            vector_store_protocol=new_vector_store_protocol,
        )

    assert new_vs_dto.status
    assert new_vs_dto.vector_store_id is not None

    with db_session() as session:
        queried_new_vector_store = session.get(SQLAVectorStore, new_vs_dto.vector_store_id)

        # Checked while the session is still open so the row's attributes stay loadable.
        assert queried_new_vector_store is not None
        assert queried_new_vector_store.name == new_vector_store_name
        assert queried_new_vector_store.lfn == new_vector_store_lfn
        assert queried_new_vector_store.protocol == new_vector_store_protocol
        assert queried_new_vector_store.embedding_model_id == embedding_model_id
        assert queried_new_vector_store.research_context_id == research_context_id


def test_error_new_vector_store_research_context_id_is_None(
    app_initialization_container: ApplicationContainer, db_session: TDatabaseFactory
) -> None:
    """A missing research context ID is reported as ResearchContextIdNotProvided."""
    sqla_research_context_repository = app_initialization_container.sqla_research_context_repository()

    new_vs_dto: NewVectorStoreDTO = sqla_research_context_repository.new_vector_store(
        research_context_id=None,  # type: ignore
        embedding_model_id=1,
        vector_store_name="test",
        vector_store_lfn="test",
        vector_store_protocol=ProtocolEnum.S3,
    )

    _assert_error_dto(
        new_vs_dto,
        message="Research Context ID cannot be None",
        name="Research Context ID not provided",
        error_type="ResearchContextIdNotProvided",
    )


def test_error_new_vector_store_vector_store_lfn_is_none(
    app_initialization_container: ApplicationContainer, db_session: TDatabaseFactory
) -> None:
    """A missing LFN is reported as VectorStoreLFNNotProvided."""
    sqla_research_context_repository = app_initialization_container.sqla_research_context_repository()

    new_vs_dto: NewVectorStoreDTO = sqla_research_context_repository.new_vector_store(
        research_context_id=1,
        embedding_model_id=1,
        vector_store_name="test",
        vector_store_lfn=None,  # type: ignore
        vector_store_protocol=ProtocolEnum.S3,
    )

    _assert_error_dto(
        new_vs_dto,
        message="Vector Store LFN cannot be None",
        name="Vector Store LFN not provided",
        error_type="VectorStoreLFNNotProvided",
    )


def test_error_new_vector_store_vector_store_name_is_none(
    app_initialization_container: ApplicationContainer, db_session: TDatabaseFactory
) -> None:
    """A missing name is reported as VectorStoreNameNotProvided."""
    sqla_research_context_repository = app_initialization_container.sqla_research_context_repository()

    new_vs_dto: NewVectorStoreDTO = sqla_research_context_repository.new_vector_store(
        research_context_id=1,
        embedding_model_id=1,
        vector_store_name=None,  # type: ignore
        vector_store_lfn="test",
        vector_store_protocol=ProtocolEnum.S3,
    )

    _assert_error_dto(
        new_vs_dto,
        message="Vector Store name cannot be None",
        name="Vector Store name not provided",
        error_type="VectorStoreNameNotProvided",
    )


def test_error_new_vector_store_vector_store_protocol_is_none(
    app_initialization_container: ApplicationContainer, db_session: TDatabaseFactory
) -> None:
    """A missing protocol is reported as VectorStoreProtocolNotProvided."""
    sqla_research_context_repository = app_initialization_container.sqla_research_context_repository()

    new_vs_dto: NewVectorStoreDTO = sqla_research_context_repository.new_vector_store(
        research_context_id=1,
        embedding_model_id=1,
        vector_store_name="test",
        vector_store_lfn="test",
        vector_store_protocol=None,  # type: ignore
    )

    _assert_error_dto(
        new_vs_dto,
        message="Vector Store protocol cannot be None",
        name="Vector Store protocol not provided",
        error_type="VectorStoreProtocolNotProvided",
    )


def test_error_new_vector_store_embedding_model_id_is_none(
    app_initialization_container: ApplicationContainer, db_session: TDatabaseFactory
) -> None:
    """A missing embedding model ID is reported as EmbeddingModelIdNotProvided."""
    sqla_research_context_repository = app_initialization_container.sqla_research_context_repository()

    new_vs_dto: NewVectorStoreDTO = sqla_research_context_repository.new_vector_store(
        research_context_id=1,
        embedding_model_id=None,  # type: ignore
        vector_store_name="test",
        vector_store_lfn="test",
        vector_store_protocol=ProtocolEnum.S3,
    )

    _assert_error_dto(
        new_vs_dto,
        message="Embedding Model ID cannot be None",
        name="Embedding Model ID not provided",
        error_type="EmbeddingModelIdNotProvided",
    )


def test_error_new_vector_store_sqla_research_context_not_found(
    app_initialization_container: ApplicationContainer,
    db_session: TDatabaseFactory,
    fake_embedding_model_with_vector_store: SQLAEmbeddingModel,
) -> None:
    """An unknown research context ID is reported as ResearchContextNotFound."""
    sqla_research_context_repository = app_initialization_container.sqla_research_context_repository()

    embedding_model = fake_embedding_model_with_vector_store

    with db_session() as session:
        session.add(embedding_model)
        session.commit()
        embedding_model_id = embedding_model.id

    # An ID that no stored research context is expected to have.
    irrealistic_ID = 99999999

    with db_session():
        new_vs_dto: NewVectorStoreDTO = sqla_research_context_repository.new_vector_store(
            research_context_id=irrealistic_ID,
            embedding_model_id=embedding_model_id,
            vector_store_name="test",
            vector_store_lfn="test",
            vector_store_protocol=ProtocolEnum.S3,
        )

    _assert_error_dto(
        new_vs_dto,
        message=f"Research Context with ID {irrealistic_ID} not found in the database.",
        name="Research Context not found",
        error_type="ResearchContextNotFound",
    )


def test_error_new_vector_store_sqla_embedding_model_not_found(
    app_initialization_container: ApplicationContainer,
    db_session: TDatabaseFactory,
    fake_user_with_conversation: SQLAUser,
    fake_llm: SQLALLM,
) -> None:
    """An unknown embedding model ID is reported as EmbeddingModelNotFound."""
    sqla_research_context_repository = app_initialization_container.sqla_research_context_repository()

    user_with_conv = fake_user_with_conversation
    llm = fake_llm
    llm.research_contexts = user_with_conv.research_contexts

    with db_session() as session:
        session.add(user_with_conv)
        session.commit()
        research_context = user_with_conv.research_contexts[0]
        research_context_id = research_context.id

    # An ID that no stored embedding model is expected to have.
    irrealistic_ID = 99999999

    with db_session():
        new_vs_dto: NewVectorStoreDTO = sqla_research_context_repository.new_vector_store(
            research_context_id=research_context_id,
            embedding_model_id=irrealistic_ID,
            vector_store_name="test",
            vector_store_lfn="test",
            vector_store_protocol=ProtocolEnum.S3,
        )

    _assert_error_dto(
        new_vs_dto,
        message=f"Embedding Model with ID {irrealistic_ID} not found in the database.",
        name="Embedding Model not found",
        error_type="EmbeddingModelNotFound",
    )