Skip to content

Commit

Permalink
refactor: get rid of langchain fully
Browse files Browse the repository at this point in the history
  • Loading branch information
umbertogriffo committed Sep 7, 2024
1 parent d10c036 commit 0763052
Show file tree
Hide file tree
Showing 15 changed files with 692 additions and 1,068 deletions.
2 changes: 1 addition & 1 deletion chatbot/bot/conversation/conversation_retrieval.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from asyncio import get_event_loop
from typing import Any, List, Tuple

from entities.document import Document
from helpers.log import get_logger
from langchain_core.documents import Document

from bot.client.lama_cpp_client import LamaCppClient
from bot.conversation.ctx_strategy import AsyncTreeSummarizationStrategy, BaseSynthesisStrategy
Expand Down
2 changes: 1 addition & 1 deletion chatbot/bot/conversation/ctx_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import Any, List, Union

import nest_asyncio
from entities.document import Document
from helpers.log import get_logger
from langchain_core.documents import Document

from bot.client.lama_cpp_client import LamaCppClient

Expand Down
82 changes: 73 additions & 9 deletions chatbot/bot/memory/embedder.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,80 @@
from abc import ABC
from abc import ABC, abstractmethod
from typing import Any

from langchain.embeddings import HuggingFaceEmbeddings


class Embedder(ABC):
embedder: Any
@abstractmethod
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Embed search docs."""

@abstractmethod
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""


class HuggingFaceEmbedder(Embedder):
"""HuggingFace sentence_transformers embedding models.
To use, you should have the ``sentence_transformers`` python package installed.
"""

client: Any #: :meta private:
model_name: str = "all-MiniLM-L6-v2"
"""Model name to use."""
cache_folder: str | None = None
"""Path to store models.
Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
model_kwargs: dict[str, Any] = {}
"""Keyword arguments to pass to the model."""
encode_kwargs: dict[str, Any] = {}
"""Keyword arguments to pass when calling the `encode` method of the model."""
multi_process: bool = False
"""Run encode() on multiple GPUs."""

def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
super().__init__(**kwargs)
try:
import sentence_transformers

except ImportError as exc:
raise ImportError(
"Could not import sentence_transformers python package. "
"Please install it with `pip install sentence-transformers`."
) from exc

self.client = sentence_transformers.SentenceTransformer(
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
)

def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Compute doc embeddings using a HuggingFace transformer model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
import sentence_transformers

texts = list(map(lambda x: x.replace("\n", " "), texts))
if self.multi_process:
pool = self.client.start_multi_process_pool()
embeddings = self.client.encode_multi_process(texts, pool)
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
else:
embeddings = self.client.encode(texts, **self.encode_kwargs)

return embeddings.tolist()

def get_embedding(self):
return self.embedder
def embed_query(self, text: str) -> list[float]:
"""Compute query embeddings using a HuggingFace transformer model.
Args:
text: The text to embed.
class EmbedderHuggingFace(Embedder):
def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
self.embedder = HuggingFaceEmbeddings(model_name=model_name)
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
7 changes: 3 additions & 4 deletions chatbot/bot/memory/vector_memory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Any, Dict, List, Tuple

from cleantext import clean
from entities.document import Document
from helpers.log import get_logger
from langchain.vectorstores import Chroma
from langchain_core.documents import Document
from vector_database.chroma import Chroma

logger = get_logger(__name__)

Expand Down Expand Up @@ -99,10 +99,9 @@ def similarity_search(
def create_memory_index(embedding: Any, chunks: List, vector_store_path: str):
texts = [clean(doc.page_content, no_emoji=True) for doc in chunks]
metadatas = [doc.metadata for doc in chunks]
memory_index = Chroma.from_texts(
Chroma.from_texts(
texts=texts,
embedding=embedding,
metadatas=metadatas,
persist_directory=vector_store_path,
)
memory_index.persist()
4 changes: 2 additions & 2 deletions chatbot/cli/rag_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from bot.client.lama_cpp_client import LamaCppClient
from bot.conversation.conversation_retrieval import ConversationRetrieval
from bot.conversation.ctx_strategy import get_ctx_synthesis_strategies, get_ctx_synthesis_strategy
from bot.memory.embedder import EmbedderHuggingFace
from bot.memory.embedder import HuggingFaceEmbedder
from bot.memory.vector_memory import VectorMemory
from bot.model.model_settings import get_model_setting, get_models
from helpers.log import get_logger
Expand Down Expand Up @@ -135,7 +135,7 @@ def main(parameters):

conversation = ConversationRetrieval(llm)

embedding = EmbedderHuggingFace().get_embedding()
embedding = HuggingFaceEmbedder()
index = VectorMemory(vector_store_path=str(vector_store_path), embedding=embedding)

loop(conversation, synthesis_strategy, index, parameters)
Expand Down
70 changes: 70 additions & 0 deletions chatbot/document_loader/format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from enum import Enum


class Format(Enum):
MARKDOWN = "markdown"
HTML = "html"


SUPPORTED_FORMATS = {
Format.MARKDOWN.value: [
# First, try to split along Markdown headings (starting with level 2)
"\n#{1,6} ",
# Note the alternative syntax for headings (below) is not handled here
# Heading level 2
# ---------------
# End of code block
"```\n",
# Horizontal lines
"\n\\*\\*\\*+\n",
"\n---+\n",
"\n___+\n",
# Note that this splitter doesn't handle horizontal lines defined
# by *three or more* of ***, ---, or ___, but this is not handled
"\n\n",
"\n",
" ",
"",
],
Format.HTML.value: [
# First, try to split along HTML tags
"<body",
"<div",
"<p",
"<br",
"<li",
"<h1",
"<h2",
"<h3",
"<h4",
"<h5",
"<h6",
"<span",
"<table",
"<tr",
"<td",
"<th",
"<ul",
"<ol",
"<header",
"<footer",
"<nav",
# Head
"<head",
"<style",
"<script",
"<meta",
"<title",
"",
],
}


def get_separators(format: str):
separators = SUPPORTED_FORMATS.get(format)

# validate input
if separators is None:
raise KeyError(format + " is a not supported format")

return separators
74 changes: 6 additions & 68 deletions chatbot/document_loader/text_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,74 +2,13 @@
import logging
import re
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, Iterable

from entities.document import Document

logger = logging.getLogger(__name__)


class Format(str, Enum):
MARKDOWN = "markdown"
HTML = "html"

from document_loader.format import get_separators

def __get_separators(format: Format) -> list[str]:
if format == Format.MARKDOWN:
return [
# First, try to split along Markdown headings (starting with level 2)
"\n#{1,6} ",
# Note the alternative syntax for headings (below) is not handled here
# Heading level 2
# ---------------
# End of code block
"```\n",
# Horizontal lines
"\n\\*\\*\\*+\n",
"\n---+\n",
"\n___+\n",
# Note that this splitter doesn't handle horizontal lines defined
# by *three or more* of ***, ---, or ___, but this is not handled
"\n\n",
"\n",
" ",
"",
]
elif format == Format.HTML:
return [
# First, try to split along HTML tags
"<body",
"<div",
"<p",
"<br",
"<li",
"<h1",
"<h2",
"<h3",
"<h4",
"<h5",
"<h6",
"<span",
"<table",
"<tr",
"<td",
"<th",
"<ul",
"<ol",
"<header",
"<footer",
"<nav",
# Head
"<head",
"<style",
"<script",
"<meta",
"<title",
"",
]
else:
raise ValueError(f"Language {format} is not supported! " f"Please choose from {list(Format)}")
logger = logging.getLogger(__name__)


class TextSplitter(ABC):
Expand Down Expand Up @@ -157,10 +96,9 @@ def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
separator_len = self._length_function(separator)

docs = []
current_doc: list[str] = []
docs, current_doc = [], []
total = 0

for d in splits:
_len = self._length_function(d)
if total + _len + (separator_len if len(current_doc) > 0 else 0) > self._chunk_size:
Expand Down Expand Up @@ -281,7 +219,7 @@ def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> l
return [s for s in splits if s != ""]


def create_recursive_text_splitter(format: Format, **kwargs: Any) -> RecursiveCharacterTextSplitter:
def create_recursive_text_splitter(format: str, **kwargs: Any) -> RecursiveCharacterTextSplitter:
"""
Factory function to create a RecursiveCharacterTextSplitter instance based on the specified format.
Expand All @@ -292,5 +230,5 @@ def create_recursive_text_splitter(format: Format, **kwargs: Any) -> RecursiveCh
Returns:
An instance of RecursiveCharacterTextSplitter configured with the appropriate separators.
"""
separators = __get_separators(format)
separators = get_separators(format)
return RecursiveCharacterTextSplitter(separators=separators, **kwargs)
4 changes: 2 additions & 2 deletions chatbot/memory_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from typing import List

from bot.memory.embedder import EmbedderHuggingFace
from bot.memory.embedder import HuggingFaceEmbedder
from bot.memory.vector_memory import VectorMemory
from document_loader.loader import DirectoryLoader
from document_loader.text_splitter import Format, create_recursive_text_splitter
Expand Down Expand Up @@ -62,7 +62,7 @@ def build_memory_index(docs_path: Path, vector_store_path: str, chunk_size: int,
logger.info(f"Number of generated chunks: {len(chunks)}")

logger.info("Creating memory index...")
embedding = EmbedderHuggingFace().get_embedding()
embedding = HuggingFaceEmbedder()
VectorMemory.create_memory_index(embedding, chunks, vector_store_path)
logger.info("Memory Index has been created successfully!")

Expand Down
4 changes: 2 additions & 2 deletions chatbot/rag_chatbot_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
get_ctx_synthesis_strategies,
get_ctx_synthesis_strategy,
)
from bot.memory.embedder import EmbedderHuggingFace
from bot.memory.embedder import HuggingFaceEmbedder
from bot.memory.vector_memory import VectorMemory
from bot.model.model_settings import get_model_setting, get_models
from helpers.log import get_logger
Expand Down Expand Up @@ -51,7 +51,7 @@ def load_index(vector_store_path: Path) -> VectorMemory:
Returns:
VectorMemory: An instance of the VectorMemory class with the loaded index.
"""
embedding = EmbedderHuggingFace().get_embedding()
embedding = HuggingFaceEmbedder()
index = VectorMemory(vector_store_path=str(vector_store_path), embedding=embedding)

return index
Expand Down
Loading

0 comments on commit 0763052

Please sign in to comment.