Skip to content

Commit

Permalink
ResearchContextRepository: new_vector_store dream-aim-deliver#89
Browse files Browse the repository at this point in the history
- Declared in core, created DTO
- Implemented in SQLAResearchContextRepository
- Tested
  • Loading branch information
alebg committed Nov 20, 2023
1 parent 6b534b0 commit 723c851
Show file tree
Hide file tree
Showing 5 changed files with 501 additions and 3 deletions.
13 changes: 12 additions & 1 deletion lib/core/dto/research_context_repository_dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List
from lib.core.entity.models import Conversation, ResearchContext, User
from lib.core.entity.models import Conversation, ResearchContext, User, VectorStore
from lib.core.sdk.dto import BaseDTO


Expand All @@ -13,6 +13,17 @@ class GetResearchContextDTO(BaseDTO[ResearchContext]):
data: ResearchContext | None = None


class NewVectorStoreDTO(BaseDTO[VectorStore]):
"""
A DTO for creating a new vector store
@param data: The vector store
@type data: VectorStore | None
"""

vector_store_id: int | None = None


class GetResearchContextUserDTO(BaseDTO[User]):
"""
A DTO for getting the user of a research context
Expand Down
25 changes: 25 additions & 0 deletions lib/core/ports/secondary/research_context_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
GetResearchContextUserDTO,
ListResearchContextConversationsDTO,
NewResearchContextConversationDTO,
NewVectorStoreDTO,
)
from lib.core.entity.models import ProtocolEnum


class ResearchContextRepositoryOutputPort(ABC):
Expand Down Expand Up @@ -36,6 +38,29 @@ def get_research_context(self, research_context_id: int) -> GetResearchContextDT
"""
raise NotImplementedError

@abstractmethod
def new_vector_store(
self,
research_context_id: int,
vector_store_lfn: str,
vector_store_name: str,
vector_store_protocol: ProtocolEnum,
embedding_model_id: int,
) -> NewVectorStoreDTO:
"""
Creates a new vector store in the research context.
@param research_context_id: The ID of the research context that is tied to the vector store.
@type research_context_id: int
@param vector_store_lfn: The LFN of the vector store.
@type vector_store_lfn: str
@param embedding_model_id: The ID of the embedding model to use for the vector store.
@type embedding_model_id: int
@return: A DTO containing the result of the operation.
@rtype: NewVectorStoreDTO
"""
raise NotImplementedError

@abstractmethod
def get_research_context_user(self, research_context_id: int) -> GetResearchContextUserDTO:
"""
Expand Down
150 changes: 148 additions & 2 deletions lib/infrastructure/repository/sqla/sqla_research_context_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,20 @@
GetResearchContextUserDTO,
ListResearchContextConversationsDTO,
NewResearchContextConversationDTO,
NewVectorStoreDTO,
)
from lib.core.entity.models import Conversation, ResearchContext
from lib.core.entity.models import Conversation, ProtocolEnum, ResearchContext
from lib.core.ports.secondary.research_context_repository import ResearchContextRepositoryOutputPort
from lib.infrastructure.repository.sqla.database import TDatabaseFactory
from sqlalchemy.orm import Session

from lib.infrastructure.repository.sqla.models import SQLAConversation, SQLAResearchContext, SQLAUser
from lib.infrastructure.repository.sqla.models import (
SQLAConversation,
SQLAEmbeddingModel,
SQLAResearchContext,
SQLAUser,
SQLAVectorStore,
)
from lib.infrastructure.repository.sqla.utils import (
convert_sqla_conversation_to_core_conversation,
convert_sqla_research_context_to_core_research_context,
Expand Down Expand Up @@ -69,6 +76,145 @@ def get_research_context(self, research_context_id: int) -> GetResearchContextDT

return GetResearchContextDTO(status=True, data=core_research_context)

def new_vector_store(
self,
research_context_id: int,
vector_store_lfn: str,
vector_store_name: str,
vector_store_protocol: ProtocolEnum,
embedding_model_id: int,
) -> NewVectorStoreDTO:
"""
Creates a new vector store in the research context.
@param research_context_id: The ID of the research context that is tied to the vector store.
@type research_context_id: int
@param vector_store_lfn: The LFN of the vector store.
@type vector_store_lfn: str
@param embedding_model_id: The ID of the embedding model to use for the vector store.
@type embedding_model_id: int
@return: A DTO containing the result of the operation.
@rtype: NewVectorStoreDTO
"""

if research_context_id is None:
self.logger.error(f"Research Context ID cannot be None")
errorDTO = NewVectorStoreDTO(
status=False,
errorCode=-1,
errorMessage="Research Context ID cannot be None",
errorName="Research Context ID not provided",
errorType="ResearchContextIdNotProvided",
)
self.logger.error(f"{errorDTO}")
return errorDTO

if vector_store_lfn is None:
self.logger.error(f"Vector store LFN cannot be None")
errorDTO = NewVectorStoreDTO(
status=False,
errorCode=-1,
errorMessage="Vector Store LFN cannot be None",
errorName="Vector Store LFN not provided",
errorType="VectorStoreLFNNotProvided",
)
self.logger.error(f"{errorDTO}")
return errorDTO

if vector_store_name is None:
self.logger.error(f"Vector store name cannot be None")
errorDTO = NewVectorStoreDTO(
status=False,
errorCode=-1,
errorMessage="Vector Store name cannot be None",
errorName="Vector Store name not provided",
errorType="VectorStoreNameNotProvided",
)
self.logger.error(f"{errorDTO}")
return errorDTO

if vector_store_protocol is None:
self.logger.error(f"Vector store protocol cannot be None")
errorDTO = NewVectorStoreDTO(
status=False,
errorCode=-1,
errorMessage="Vector Store protocol cannot be None",
errorName="Vector Store protocol not provided",
errorType="VectorStoreProtocolNotProvided",
)
self.logger.error(f"{errorDTO}")
return errorDTO

if embedding_model_id is None:
self.logger.error(f"Embedding model ID cannot be None")
errorDTO = NewVectorStoreDTO(
status=False,
errorCode=-1,
errorMessage="Embedding Model ID cannot be None",
errorName="Embedding Model ID not provided",
errorType="EmbeddingModelIdNotProvided",
)
self.logger.error(f"{errorDTO}")
return errorDTO

queried_sqla_research_context: SQLAResearchContext | None = self.session.get(
SQLAResearchContext, research_context_id
)

if queried_sqla_research_context is None:
self.logger.error(f"Research context with ID {research_context_id} not found.")
errorDTO = NewVectorStoreDTO(
status=False,
errorCode=-1,
errorMessage=f"Research Context with ID {research_context_id} not found in the database.",
errorName="Research Context not found",
errorType="ResearchContextNotFound",
)
self.logger.error(f"{errorDTO}")
return errorDTO

queried_sqla_embedding_model: SQLAEmbeddingModel | None = self.session.get(
SQLAEmbeddingModel, embedding_model_id
)

if queried_sqla_embedding_model is None:
self.logger.error(f"Embedding model with ID {embedding_model_id} not found.")
errorDTO = NewVectorStoreDTO(
status=False,
errorCode=-1,
errorMessage=f"Embedding Model with ID {embedding_model_id} not found in the database.",
errorName="Embedding Model not found",
errorType="EmbeddingModelNotFound",
)
self.logger.error(f"{errorDTO}")
return errorDTO

sqla_new_vector_store: SQLAVectorStore = SQLAVectorStore(
name=vector_store_name,
lfn=vector_store_lfn,
protocol=vector_store_protocol,
embedding_model_id=embedding_model_id,
research_context_id=research_context_id,
)

try:
sqla_new_vector_store.save(session=self.session)
self.session.commit()

return NewVectorStoreDTO(status=True, vector_store_id=sqla_new_vector_store.id)

except Exception as e:
self.logger.error(f"Error while creating new vector store: {e}")
errorDTO = NewVectorStoreDTO(
status=False,
errorCode=-1,
errorMessage=f"Error while creating new vector store: {e}",
errorName="Error while creating new vector store",
errorType="ErrorWhileCreatingNewVectorStore",
)
self.logger.error(f"{errorDTO}")
return errorDTO

def get_research_context_user(self, research_context_id: int) -> GetResearchContextUserDTO:
"""
Gets the user of a research context.
Expand Down
83 changes: 83 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
SQLALLM,
SQLACitation,
SQLAConversation,
SQLAEmbeddingModel,
SQLAKnowledgeSource,
SQLAMessageBase,
SQLAMessageQuery,
SQLAMessageResponse,
SQLAResearchContext,
SQLASourceData,
SQLAUser,
SQLAVectorStore,
)
from tests.fixtures.factory.sqla_model_factory import SQLATemporaryModelFactory

Expand Down Expand Up @@ -328,6 +330,42 @@ def fake_user_with_conversation() -> SQLAUser:
return user_with_conversation()


def create_lfn() -> str:
fake = Faker().unique

protocols = [
attr_name.__str__().lower() for attr_name in vars(ProtocolEnum) if not attr_name.__str__().startswith("_")
]
knowledge_sources = [
attr_name.__str__().lower()
for attr_name in vars(KnowledgeSourceEnum)
if not attr_name.__str__().startswith("_")
]

sd_protocol_choice: str = random.choice(protocols)
sd_protocol = ProtocolEnum(sd_protocol_choice)
sd_host = fake.domain_name()
sd_port = fake.port_number()
sd_minio_bucket = fake.name()
sd_tracer_id = random.randint(1, 1000000000)
sd_knowledge_source_choice: str = random.choice(knowledge_sources)
sd_job_id = random.randint(1, 1000000000)

sd_filename = fake.file_name()
sd_name = sd_filename.split(".")[0]
sd_type = sd_filename.split(".")[1]
sd_relative_path = fake.file_path(depth=3).split(".")[0] + "/" + sd_filename

sd_lfn = f"{sd_protocol_choice}://{sd_host}:{sd_port}/{sd_minio_bucket}/{sd_tracer_id}/{sd_knowledge_source_choice}/{sd_job_id}{sd_relative_path}"

return sd_lfn


@pytest.fixture(scope="function")
def fake_lfn_list() -> List[str]:
return [create_lfn() for _ in range(10)]


def source_data() -> SQLASourceData:
fake = Faker().unique

Expand Down Expand Up @@ -444,3 +482,48 @@ def llm() -> SQLALLM:
@pytest.fixture(scope="function")
def fake_llm() -> SQLALLM:
return llm()


def embedding_model() -> SQLAEmbeddingModel:
fake = Faker().unique

fake_name = fake.name()

return SQLAEmbeddingModel(
name=fake_name,
)


@pytest.fixture(scope="function")
def fake_embedding_model() -> SQLAEmbeddingModel:
return embedding_model()


def vector_store() -> SQLAVectorStore:
fake = Faker().unique

lfn = create_lfn()
name = lfn.split("/")[-1]
protocol_str = lfn.split("://")[0]
protocol = ProtocolEnum(protocol_str)

return SQLAVectorStore(
name=name,
lfn=lfn,
protocol=protocol,
)


def embedding_model_with_vector_store() -> SQLAEmbeddingModel:
em = embedding_model()

vector_store_list = [vector_store() for _ in range(10)]

em.vector_stores = vector_store_list

return em


@pytest.fixture(scope="function")
def fake_embedding_model_with_vector_store() -> SQLAEmbeddingModel:
return embedding_model_with_vector_store()
Loading

0 comments on commit 723c851

Please sign in to comment.