feat: save and load brain (#3202)
# Description
- Save and load a brain to disk:
```python
import asyncio
import tempfile

from quivr_core import Brain


async def main():
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
        temp_file.write("Gold is a liquid of blue-like colour.")
        temp_file.flush()

        brain = await Brain.afrom_files(name="test_brain", file_paths=[temp_file.name])

        save_path = await brain.save("/home/amine/.local/quivr")

        brain_loaded = Brain.load(save_path)
        brain_loaded.print_info()


if __name__ == "__main__":
    asyncio.run(main())
```

# TODO:
- Loading all chat history
- Loading from other vector stores; PGVector, for example, would be great (see the sketch below)
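
A possible shape for that PGVector branch, reusing the `PGVectorConfig` model added in `serialization.py` (the `langchain_postgres` dependency and its `PGVector` constructor are assumptions here, not part of this commit):

```python
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

from quivr_core.brain.serialization import PGVectorConfig


def load_pgvector(config: PGVectorConfig, embedder: Embeddings) -> VectorStore:
    """Hypothetical counterpart to the FAISS branch in Brain.load()."""
    from langchain_postgres import PGVector  # assumed dependency

    return PGVector(
        embeddings=embedder,
        collection_name=config.table_name,
        # e.g. "postgresql+psycopg://user:password@localhost:5432/quivr"
        connection=config.pg_url,
    )
```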
AmineDiro committed Sep 13, 2024
1 parent 06f72eb commit eda619f
Showing 12 changed files with 310 additions and 62 deletions.
22 changes: 22 additions & 0 deletions backend/core/examples/save_load_brain.py (new file)
```python
import asyncio
import tempfile

from quivr_core import Brain


async def main():
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
        temp_file.write("Gold is a liquid of blue-like colour.")
        temp_file.flush()

        brain = await Brain.afrom_files(name="test_brain", file_paths=[temp_file.name])

        save_path = await brain.save("/home/amine/.local/quivr")

        brain_loaded = Brain.load(save_path)
        brain_loaded.print_info()


if __name__ == "__main__":
    # Run the main function in the existing event loop
    asyncio.run(main())
```
21 changes: 11 additions & 10 deletions backend/core/examples/simple_question.py
```diff
@@ -1,22 +1,23 @@
 import tempfile
 
 from quivr_core import Brain
-from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
 from quivr_core.quivr_rag import QuivrQARAG
+from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
 
 if __name__ == "__main__":
     with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
         temp_file.write("Gold is a liquid of blue-like colour.")
         temp_file.flush()
 
-        brain = Brain.from_files(name="test_brain",
-                                 file_paths=[temp_file.name],
-                                 )
+        brain = Brain.from_files(
+            name="test_brain",
+            file_paths=[temp_file.name],
+        )
 
-        answer = brain.ask("what is gold? asnwer in french",
-                           rag_pipeline=QuivrQARAGLangGraph)
+        answer = brain.ask(
+            "what is gold? asnwer in french", rag_pipeline=QuivrQARAGLangGraph
+        )
         print("answer QuivrQARAGLangGraph :", answer.answer)
 
-        answer = brain.ask("what is gold? asnwer in french",
-                           rag_pipeline=QuivrQARAG)
-        print("answer QuivrQARAG :", answer.answer)
+        answer = brain.ask("what is gold? asnwer in french", rag_pipeline=QuivrQARAG)
+        print("answer QuivrQARAG :", answer.answer)
```
23 changes: 14 additions & 9 deletions backend/core/examples/simple_question_streaming.py
```diff
@@ -1,29 +1,34 @@
-from dotenv import load_dotenv
-import tempfile
 import asyncio
+import tempfile
 
+from dotenv import load_dotenv
 from quivr_core import Brain
-from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
 from quivr_core.quivr_rag import QuivrQARAG
+from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
 
 
 async def main():
     dotenv_path = "/Users/jchevall/Coding/QuivrHQ/quivr/.env"
     load_dotenv(dotenv_path)
 
     with tempfile.NamedTemporaryFile(mode="w", suffix=".txt") as temp_file:
         temp_file.write("Gold is a liquid of blue-like colour.")
         temp_file.flush()
 
-        brain = await Brain.afrom_files(name="test_brain",
-                                        file_paths=[temp_file.name])
+        brain = await Brain.afrom_files(name="test_brain", file_paths=[temp_file.name])
+
+        await brain.save("~/.local/quivr")
 
         question = "what is gold? answer in french"
         async for chunk in brain.ask_streaming(question, rag_pipeline=QuivrQARAG):
-            print("answer QuivrQARAG:", chunk.answer)
+            print("answer QuivrQARAG:", chunk.answer)
 
-        async for chunk in brain.ask_streaming(question, rag_pipeline=QuivrQARAGLangGraph):
+        async for chunk in brain.ask_streaming(
+            question, rag_pipeline=QuivrQARAGLangGraph
+        ):
             print("answer QuivrQARAGLangGraph:", chunk.answer)
 
 
 if __name__ == "__main__":
     # Run the main function in the existing event loop
-    asyncio.run(main())
+    asyncio.run(main())
```
121 changes: 116 additions & 5 deletions backend/core/quivr_core/brain/brain.py
```diff
@@ -1,27 +1,36 @@
 import asyncio
 import logging
+import os
 from pathlib import Path
 from pprint import PrettyPrinter
-from typing import Any, AsyncGenerator, Callable, Dict, Self, Union, Type
+from typing import Any, AsyncGenerator, Callable, Dict, Self, Type, Union
 from uuid import UUID, uuid4
 
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.messages import AIMessage, HumanMessage
 from langchain_core.vectorstores import VectorStore
+from langchain_openai import OpenAIEmbeddings
 from rich.console import Console
 from rich.panel import Panel
 
 from quivr_core.brain.info import BrainInfo, ChatHistoryInfo
+from quivr_core.brain.serialization import (
+    BrainSerialized,
+    EmbedderConfig,
+    FAISSConfig,
+    LocalStorageConfig,
+    TransparentStorageConfig,
+)
 from quivr_core.chat import ChatHistory
 from quivr_core.config import RAGConfig
 from quivr_core.files.file import load_qfile
 from quivr_core.llm import LLMEndpoint
 from quivr_core.models import ParsedRAGChunkResponse, ParsedRAGResponse, SearchResult
 from quivr_core.processor.registry import get_processor_class
 from quivr_core.quivr_rag import QuivrQARAG
-from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
-from quivr_core.storage.local_storage import TransparentStorage
+from quivr_core.quivr_rag_langgraph import QuivrQARAGLangGraph
+from quivr_core.storage.local_storage import LocalStorage, TransparentStorage
 from quivr_core.storage.storage_base import StorageBase
 
 from .brain_defaults import build_default_vectordb, default_embedder, default_llm
@@ -90,6 +99,108 @@ def print_info(self):
         panel = Panel(tree, title="Brain Info", expand=False, border_style="bold")
         console.print(panel)
 
+    @classmethod
+    def load(cls, folder_path: str | Path) -> Self:
+        if isinstance(folder_path, str):
+            folder_path = Path(folder_path)
+        if not folder_path.exists():
+            raise ValueError(f"path {folder_path} doesn't exist")
+
+        # Load brainserialized
+        with open(os.path.join(folder_path, "config.json"), "r") as f:
+            bserialized = BrainSerialized.model_validate_json(f.read())
+
+        # Loading storage
+        if bserialized.storage_config.storage_type == "transparent_storage":
+            storage: StorageBase = TransparentStorage.load(bserialized.storage_config)
+        elif bserialized.storage_config.storage_type == "local_storage":
+            storage: StorageBase = LocalStorage.load(bserialized.storage_config)
+        else:
+            raise ValueError("unknown storage")
+
+        # Load Embedder
+        if bserialized.embedding_config.embedder_type == "openai_embedding":
+            from langchain_openai import OpenAIEmbeddings
+
+            embedder = OpenAIEmbeddings(**bserialized.embedding_config.config)
+        else:
+            raise ValueError("unknown embedder")
+
+        # Load vector db
+        if bserialized.vectordb_config.vectordb_type == "faiss":
+            from langchain_community.vectorstores import FAISS
+
+            vector_db = FAISS.load_local(
+                folder_path=bserialized.vectordb_config.vectordb_folder_path,
+                embeddings=embedder,
+                allow_dangerous_deserialization=True,
+            )
+        else:
+            raise ValueError("Unsupported vectordb")
+
+        return cls(
+            id=bserialized.id,
+            name=bserialized.name,
+            embedder=embedder,
+            llm=LLMEndpoint.from_config(bserialized.llm_config),
+            storage=storage,
+            vector_db=vector_db,
+        )
+
+    async def save(self, folder_path: str | Path):
+        if isinstance(folder_path, str):
+            folder_path = Path(folder_path)
+
+        brain_path = os.path.join(folder_path, f"brain_{self.id}")
+        os.makedirs(brain_path, exist_ok=True)
+
+        from langchain_community.vectorstores import FAISS
+
+        if isinstance(self.vector_db, FAISS):
+            vectordb_path = os.path.join(brain_path, "vector_store")
+            os.makedirs(vectordb_path, exist_ok=True)
+            self.vector_db.save_local(folder_path=vectordb_path)
+            vector_store = FAISSConfig(vectordb_folder_path=vectordb_path)
+        else:
+            raise Exception("can't serialize other vector stores for now")
+
+        if isinstance(self.embedder, OpenAIEmbeddings):
+            embedder_config = EmbedderConfig(
+                config=self.embedder.dict(exclude={"openai_api_key"})
+            )
+        else:
+            raise Exception("can't serialize embedder other than openai for now")
+
+        # TODO : each instance should know how to serialize/deserialize itself
+        if isinstance(self.storage, LocalStorage):
+            serialized_files = {
+                f.id: f.serialize() for f in await self.storage.get_files()
+            }
+            storage_config = LocalStorageConfig(
+                storage_path=self.storage.dir_path, files=serialized_files
+            )
+        elif isinstance(self.storage, TransparentStorage):
+            serialized_files = {
+                f.id: f.serialize() for f in await self.storage.get_files()
+            }
+            storage_config = TransparentStorageConfig(files=serialized_files)
+        else:
+            raise Exception("can't serialize storage. not supported for now")
+
+        bserialized = BrainSerialized(
+            id=self.id,
+            name=self.name,
+            chat_history=self.chat_history.get_chat_history(),
+            llm_config=self.llm.get_config(),
+            vectordb_config=vector_store,
+            embedding_config=embedder_config,
+            storage_config=storage_config,
+        )
+
+        with open(os.path.join(brain_path, "config.json"), "w") as f:
+            f.write(bserialized.model_dump_json())
+        return brain_path
+
     def info(self) -> BrainInfo:
         # TODO: dim of embedding
         # "embedder": {},
@@ -177,7 +288,7 @@ def from_files(
         storage: StorageBase = TransparentStorage(),
         llm: LLMEndpoint | None = None,
         embedder: Embeddings | None = None,
-        skip_file_error: bool = False
+        skip_file_error: bool = False,
     ) -> Self:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
@@ -223,7 +334,7 @@ async def afrom_langchain_documents(
             storage=storage,
             llm=llm,
             embedder=embedder,
-            vector_db=vector_db
+            vector_db=vector_db,
         )
 
     async def asearch(
```
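
For reference, the on-disk layout `save()` produces, as derived from the code above (the `index.faiss`/`index.pkl` names are FAISS `save_local` defaults):

```
<folder_path>/brain_<id>/
├── config.json      # BrainSerialized: llm/embedder/storage/vectordb configs + chat history
└── vector_store/
    ├── index.faiss
    └── index.pkl
```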
1 change: 0 additions & 1 deletion backend/core/quivr_core/brain/info.py
```diff
@@ -33,7 +33,6 @@ def add_to_tree(self, llm_tree: Tree):
         llm_tree.add(f"Base URL: [underline]{self.llm_base_url}[/underline]")
         llm_tree.add(f"Temperature: [bold]{self.temperature}[/bold]")
         llm_tree.add(f"Max Tokens: [bold]{self.max_tokens}[/bold]")
-
         func_call_color = "green" if self.supports_function_calling else "red"
         llm_tree.add(
             f"Supports Function Calling: [bold {func_call_color}]{self.supports_function_calling}[/bold {func_call_color}]"
```
55 changes: 55 additions & 0 deletions backend/core/quivr_core/brain/serialization.py (new file)
```python
from pathlib import Path
from typing import Any, Dict, Literal, Union
from uuid import UUID

from pydantic import BaseModel, Field, SecretStr

from quivr_core.config import LLMEndpointConfig
from quivr_core.files.file import QuivrFileSerialized
from quivr_core.models import ChatMessage


class EmbedderConfig(BaseModel):
    embedder_type: Literal["openai_embedding"] = "openai_embedding"
    # TODO: type this correctly
    config: Dict[str, Any]


class PGVectorConfig(BaseModel):
    vectordb_type: Literal["pgvector"] = "pgvector"
    pg_url: str
    pg_user: str
    pg_psswd: SecretStr
    table_name: str
    vector_dim: int


class FAISSConfig(BaseModel):
    vectordb_type: Literal["faiss"] = "faiss"
    vectordb_folder_path: str


class LocalStorageConfig(BaseModel):
    storage_type: Literal["local_storage"] = "local_storage"
    storage_path: Path
    files: dict[UUID, QuivrFileSerialized]


class TransparentStorageConfig(BaseModel):
    storage_type: Literal["transparent_storage"] = "transparent_storage"
    files: dict[UUID, QuivrFileSerialized]


class BrainSerialized(BaseModel):
    id: UUID
    name: str
    chat_history: list[ChatMessage]
    vectordb_config: Union[FAISSConfig, PGVectorConfig] = Field(
        ..., discriminator="vectordb_type"
    )
    storage_config: Union[TransparentStorageConfig, LocalStorageConfig] = Field(
        ..., discriminator="storage_type"
    )

    llm_config: LLMEndpointConfig
    embedding_config: EmbedderConfig
```
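
`vectordb_config` and `storage_config` are pydantic discriminated unions, so `Brain.load()` gets the right config class back from `model_validate_json` based on the tag field alone. A minimal, self-contained sketch of the same pattern (illustrative pydantic v2 code, not part of quivr-core):

```python
from typing import Literal, Union

from pydantic import BaseModel, Field


class FAISSCfg(BaseModel):
    vectordb_type: Literal["faiss"] = "faiss"
    vectordb_folder_path: str


class PGVectorCfg(BaseModel):
    vectordb_type: Literal["pgvector"] = "pgvector"
    pg_url: str


class Serialized(BaseModel):
    # pydantic dispatches on the `vectordb_type` tag when validating
    vectordb_config: Union[FAISSCfg, PGVectorCfg] = Field(
        ..., discriminator="vectordb_type"
    )


s = Serialized.model_validate_json(
    '{"vectordb_config": {"vectordb_type": "faiss", "vectordb_folder_path": "/tmp/vs"}}'
)
assert isinstance(s.vectordb_config, FAISSCfg)
```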
39 changes: 38 additions & 1 deletion backend/core/quivr_core/files/file.py
```diff
@@ -5,10 +5,22 @@
 from contextlib import asynccontextmanager
 from enum import Enum
 from pathlib import Path
-from typing import Any, AsyncGenerator, AsyncIterable
+from typing import Any, AsyncGenerator, AsyncIterable, Self
 from uuid import UUID, uuid4
 
 import aiofiles
+from openai import BaseModel
+
+
+class QuivrFileSerialized(BaseModel):
+    id: UUID
+    brain_id: UUID
+    path: Path
+    original_filename: str
+    file_size: int | None
+    file_extension: str
+    file_sha1: str
+    additional_metadata: dict[str, Any]
 
 
 class FileExtension(str, Enum):
@@ -137,3 +149,28 @@ def metadata(self) -> dict[str, Any]:
             "file_size": self.file_size,
             **self.additional_metadata,
         }
+
+    def serialize(self) -> QuivrFileSerialized:
+        return QuivrFileSerialized(
+            id=self.id,
+            brain_id=self.brain_id,
+            path=self.path.absolute(),
+            original_filename=self.original_filename,
+            file_size=self.file_size,
+            file_extension=self.file_extension,
+            file_sha1=self.file_sha1,
+            additional_metadata=self.additional_metadata,
+        )
+
+    @classmethod
+    def deserialize(cls, serialized: QuivrFileSerialized) -> Self:
+        return cls(
+            id=serialized.id,
+            brain_id=serialized.brain_id,
+            path=serialized.path,
+            original_filename=serialized.original_filename,
+            file_size=serialized.file_size,
+            file_extension=serialized.file_extension,
+            file_sha1=serialized.file_sha1,
+            metadata=serialized.additional_metadata,
+        )
```
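
A round-trip sketch for the new `serialize`/`deserialize` pair, assuming `QuivrFile.__init__` accepts the keyword arguments that `deserialize()` passes above (the constructor itself is not shown in this diff):

```python
from pathlib import Path
from uuid import uuid4

from quivr_core.files.file import QuivrFile, QuivrFileSerialized

qf = QuivrFile(
    id=uuid4(),
    brain_id=uuid4(),
    path=Path("/tmp/gold.txt"),  # absolute, since serialize() calls .absolute()
    original_filename="gold.txt",
    file_size=37,
    file_extension=".txt",
    file_sha1="0" * 40,  # placeholder digest
    metadata={},
)

payload = qf.serialize().model_dump_json()  # what gets embedded in config.json
restored = QuivrFile.deserialize(QuivrFileSerialized.model_validate_json(payload))
assert restored.id == qf.id and restored.path == qf.path
```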
