Commit 0e11a7f: Merge remote-tracking branch 'upstream/main'
Athe-kunal committed Sep 28, 2024 (2 parents: c547153 + b26cf34)
Showing 10 changed files with 315 additions and 76 deletions.
README.md (4 changes: 2 additions & 2 deletions)
@@ -60,7 +60,7 @@ Loading a model with `byaldi` is extremely straightforward:
```python3
from byaldi import RAGMultiModalModel
# Optionally, you can specify an `index_root`, which is where it'll save the index. It defaults to ".byaldi/".
-RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
+RAG = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")
```

If you've already got an index, and wish to load it along with the model necessary to query it, you can do so just as easily:
@@ -77,7 +77,7 @@ Creating an index with `byaldi` is simple and flexible. **You can index a single
```python3
from byaldi import RAGMultiModalModel
# Optionally, you can specify an `index_root`, which is where it'll save the index. It defaults to ".byaldi/".
-RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
+RAG = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")
RAG.index(
input_path="docs/", # The path to your documents
index_name=index_name, # The name you want to give to your index. It'll be saved at `index_root/index_name/`.
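Once the model is loaded and the index is built as above, querying follows the same pattern. A minimal sketch based on the `search` API that appears later in this diff (the query string and `k` are illustrative, and the `Result` fields shown are assumptions based on how results are used elsewhere in byaldi):

```python3
# Hedged usage sketch: query the index built in the snippet above.
results = RAG.search("How does ColPali embed document pages?", k=3)
for result in results:
    # Each Result is assumed to carry the matched doc/page ids and a score.
    print(result.doc_id, result.page_num, result.score)
```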
byaldi/RAGModel.py (19 changes: 16 additions & 3 deletions)
@@ -1,11 +1,18 @@
from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

from PIL import Image

from byaldi.colpali import ColPaliModel

from byaldi.objects import Result

+# Optional langchain integration
+try:
+    from byaldi.integrations import ByaldiLangChainRetriever
+except ImportError:
+    pass


class RAGMultiModalModel:
"""
@@ -19,7 +26,7 @@ class RAGMultiModalModel:
```python
from byaldi import RAGMultiModalModel
-RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")
+RAG = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2")
```
Both methods will load a fully initialised instance of ColPali, which you can use to build and query indexes.
@@ -50,7 +57,10 @@ def from_pretrained(
"""
instance = cls()
instance.model = ColPaliModel.from_pretrained(
-            pretrained_model_name_or_path, index_root=index_root, device=device, verbose=verbose
+            pretrained_model_name_or_path,
+            index_root=index_root,
+            device=device,
+            verbose=verbose,
)
return instance

@@ -166,3 +176,6 @@ def search(

def get_doc_ids_to_file_names(self):
return self.model.get_doc_ids_to_file_names()

+    def as_langchain_retriever(self, **kwargs: Any):
+        return ByaldiLangChainRetriever(model=self, kwargs=kwargs)
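A sketch of how this new hook could be wired up (hedged: assumes `langchain-core` is installed, and uses the `from_index` loader described in the README; the index name and query are illustrative):

```python3
from byaldi import RAGMultiModalModel

# Load an existing index, then expose it as a LangChain retriever.
RAG = RAGMultiModalModel.from_index("attention_index")
retriever = RAG.as_langchain_retriever(k=3)  # kwargs are forwarded to search()
docs = retriever.invoke("What does the scaled dot-product attention figure show?")
```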
byaldi/colpali.py (106 changes: 53 additions & 53 deletions)
@@ -1,33 +1,22 @@
import os
import shutil

-# Import version directly from the package metadata
+import tempfile
from importlib.metadata import version
from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, cast

import srsly
import torch
-from colpali_engine.models.paligemma_colbert_architecture import ColPali
-from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
-from colpali_engine.utils.colpali_processing_utils import (
-    process_images,
-    process_queries,
-)
+from colpali_engine.models import ColPali, ColPaliProcessor
from pdf2image import convert_from_path
from PIL import Image
-from transformers import AutoProcessor

from byaldi.objects import Result

-from .utils import capture_print
-
+# Import version directly from the package metadata
VERSION = version("Byaldi")


-MOCK_IMAGE = Image.new("RGB", (448, 448), (255, 255, 255))


class ColPaliModel:
def __init__(
self,
@@ -40,6 +29,9 @@ def __init__(
device: Optional[Union[str, torch.device]] = None,
**kwargs,
):
+        if isinstance(pretrained_model_name_or_path, Path):
+            pretrained_model_name_or_path = str(pretrained_model_name_or_path)

if "colpali" not in pretrained_model_name_or_path.lower():
raise ValueError(
"This pre-release version of Byaldi only supports ColPali for now. Incorrect model name specified."
@@ -72,35 +64,27 @@ def __init__(
self.doc_ids_to_file_names = {}
self.doc_ids = set()

-        # self.model = ColPali.from_pretrained(
-        #     "vidore/colpaligemma-3b-pt-448-base",
-        #     torch_dtype=torch.bfloat16,
-        #     device_map="cuda"
-        #     if device == "cuda"
-        #     or (isinstance(device, torch.device) and device.type == "cuda")
-        #     else None,
-        #     token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
-        # )
-
-        # if verbose > 0:
-        #     print("Loading adapter...")
-        #     print("Adapter name: ", self.pretrained_model_name_or_path)
-        # self.model.load_adapter(self.pretrained_model_name_or_path)
-
        self.model = ColPali.from_pretrained(
            self.pretrained_model_name_or_path,
            torch_dtype=torch.bfloat16,
-            device_map="cuda"
-            if device == "cuda"
-            or (isinstance(device, torch.device) and device.type == "cuda")
-            else None,
+            device_map=(
+                "cuda"
+                if device == "cuda"
+                or (isinstance(device, torch.device) and device.type == "cuda")
+                else None
+            ),
            token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
        )
        self.model = self.model.eval()
-        self.processor = AutoProcessor.from_pretrained(
-            self.pretrained_model_name_or_path,
-            token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
+
+        self.processor = cast(
+            ColPaliProcessor,
+            ColPaliProcessor.from_pretrained(
+                self.pretrained_model_name_or_path,
+                token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
+            ),
        )

self.device = device
if device != "cuda" and not (
isinstance(device, torch.device) and device.type == "cuda"
@@ -111,14 +95,18 @@ def __init__(
self.full_document_collection = False
self.highest_doc_id = -1
else:
-            index_path = Path(index_root) / Path(index_name)
+            if self.index_name is None:
+                raise ValueError("No index name specified. Cannot load from index.")
+
+            index_path = Path(index_root) / Path(self.index_name)
index_config = srsly.read_gzip_json(index_path / "index_config.json.gz")
self.full_document_collection = index_config.get(
"full_document_collection", False
)
self.resize_stored_images = index_config.get("resize_stored_images", False)
self.max_image_width = index_config.get("max_image_width", None)
self.max_image_height = index_config.get("max_image_height", None)

if self.full_document_collection:
collection_path = index_path / "collection"
json_files = sorted(
@@ -473,15 +461,22 @@ def _process_and_add_to_index(
"""TODO: THERE ARE TOO MANY FUNCTIONS DOING THINGS HERE. I blame Claude, but this is temporary anyway."""
if isinstance(item, Path):
if item.suffix.lower() == ".pdf":
-                images = convert_from_path(item)
-                for i, image in enumerate(images):
-                    self._add_to_index(
-                        image,
-                        store_collection_with_index,
-                        doc_id,
-                        page_id=i + 1,
-                        metadata=metadata,
-                    )
+                with tempfile.TemporaryDirectory() as path:
+                    images = convert_from_path(
+                        item,
+                        thread_count=os.cpu_count()-1,
+                        output_folder=path,
+                        paths_only=True,
+                    )
+                    for i, image_path in enumerate(images):
+                        image = Image.open(image_path)
+                        self._add_to_index(
+                            image,
+                            store_collection_with_index,
+                            doc_id,
+                            page_id=i + 1,
+                            metadata=metadata,
+                        )
elif item.suffix.lower() in [".jpg", ".jpeg", ".png", ".bmp"]:
image = Image.open(item)
self._add_to_index(
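The rewritten PDF branch keeps memory bounded: with `paths_only=True`, pdf2image renders each page into the temporary directory and returns file paths, so pages are opened one at a time rather than all being held in memory as PIL images. A standalone sketch of the pattern (the PDF path is illustrative, and the `max(1, ...)` guard is a defensive tweak not present in the diff, since `os.cpu_count()` can return `None` or `1`):

```python3
import os
import tempfile

from pdf2image import convert_from_path
from PIL import Image

with tempfile.TemporaryDirectory() as tmpdir:
    # paths_only=True returns file paths instead of in-memory PIL images.
    page_paths = convert_from_path(
        "docs/report.pdf",
        output_folder=tmpdir,
        paths_only=True,
        thread_count=max(1, (os.cpu_count() or 2) - 1),
    )
    for page_path in page_paths:
        with Image.open(page_path) as page:
            page.load()  # force a full read before the temp dir is cleaned up
```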
@@ -512,7 +507,7 @@ def _add_to_index(
f"Document ID {doc_id} with page ID {page_id} already exists in the index"
)

-        processed_image = process_images(self.processor, [image])
+        processed_image = self.processor.process_images([image])

# Generate embedding
with torch.no_grad():
@@ -614,7 +609,7 @@ def search(
for q in queries:
# Process query
with torch.no_grad():
-            batch_query = process_queries(self.processor, [q], MOCK_IMAGE)
+            batch_query = self.processor.process_queries([q])
batch_query = {k: v.to(self.device) for k, v in batch_query.items()}
embeddings_query = self.model(**batch_query)
qs = list(torch.unbind(embeddings_query.to("cpu")))
@@ -623,7 +618,7 @@
else:
req_embeddings, req_embedding_ids = self.filter_embeddings(filter_metadata=filter_metadata)
# Compute scores
-            scores = self._score(qs, req_embeddings)
+            scores = self.processor.score(qs, req_embeddings, self.indexed_embeddings).cpu().numpy()

# Get top k relevant pages
top_pages = scores.argsort(axis=1)[0][-k:][::-1].tolist()
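These hunks track the colpali-engine 0.3 API, where preprocessing and late-interaction scoring moved onto `ColPaliProcessor` (hence the removal of the free functions `process_images`/`process_queries` and the `MOCK_IMAGE` placeholder). A standalone sketch of the new call pattern (hedged: the model name is illustrative, and `score` is assumed to take query and passage multi-vectors as in the call above):

```python3
import torch
from colpali_engine.models import ColPali, ColPaliProcessor

model = ColPali.from_pretrained("vidore/colpali-v1.2", torch_dtype=torch.bfloat16).eval()
processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.2")

with torch.no_grad():
    batch = processor.process_queries(["a query about a chart"])
    query_embeddings = model(**batch)  # one multi-vector embedding per query

qs = list(torch.unbind(query_embeddings.to("cpu")))
# With ps built analogously via processor.process_images(...) and the model:
# scores = processor.score(qs, ps).cpu().numpy()  # MaxSim late interaction
```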
@@ -681,8 +676,13 @@ def encode_image(
images.append(Image.open(os.path.join(item, file)))
elif item.lower().endswith(".pdf"):
# Process PDF
-            pdf_images = convert_from_path(item)
-            images.extend(pdf_images)
+            with tempfile.TemporaryDirectory() as path:
+                pdf_images = convert_from_path(
+                    item,
+                    thread_count=os.cpu_count()-1,
+                    output_folder=path,
+                )
+                images.extend(pdf_images)
elif item.lower().endswith(
(".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif")
):
Expand All @@ -694,7 +694,7 @@ def encode_image(
raise ValueError(f"Unsupported input type: {type(item)}")

with torch.no_grad():
-            batch = process_images(self.processor, images)
+            batch = self.processor.process_images(images)
batch = {k: v.to(self.device) for k, v in batch.items()}
embeddings = self.model(**batch)

@@ -715,7 +715,7 @@ def encode_query(self, query: Union[str, List[str]]) -> torch.Tensor:
query = [query]

with torch.no_grad():
-            batch = process_queries(self.processor, query, MOCK_IMAGE)
+            batch = self.processor.process_queries(query)
batch = {k: v.to(self.device) for k, v in batch.items()}
embeddings = self.model(**batch)

byaldi/integrations/__init__.py (8 changes: 8 additions & 0 deletions)
@@ -0,0 +1,8 @@
+__all__ = []
+
+try:
+    from byaldi.integrations._langchain import ByaldiLangChainRetriever
+
+    __all__.append("ByaldiLangChainRetriever")
+except ImportError:
+    pass
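This mirrors the guarded import in `RAGModel.py` above: when `langchain-core` is not installed, the integration is simply absent from `__all__` rather than failing at import time.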
byaldi/integrations/_langchain.py (21 changes: 21 additions & 0 deletions)
@@ -0,0 +1,21 @@
+from typing import Any, List
+
+from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
+from langchain_core.retrievers import BaseRetriever
+
+from byaldi.objects import Result
+
+
+class ByaldiLangChainRetriever(BaseRetriever):
+    model: Any
+    kwargs: dict = {}
+
+    def _get_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,  # noqa
+    ) -> List[Result]:
+        """Get documents relevant to a query."""
+        docs = self.model.search(query, **self.kwargs)
+        return docs
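Note that `_get_relevant_documents` returns byaldi `Result` objects as-is rather than LangChain `Document`s, so downstream components that expect `page_content` may need a small adapter. A hedged sketch (the `Result` field names are assumptions based on their use elsewhere in byaldi):

```python3
from langchain_core.documents import Document

def results_to_documents(results) -> list[Document]:
    # Wrap byaldi Results so standard LangChain components can consume them.
    return [
        Document(
            page_content=f"doc={r.doc_id} page={r.page_num}",
            metadata={"score": r.score, **(r.metadata or {})},
        )
        for r in results
    ]
```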
byaldi/utils.py (15 changes: 0 additions & 15 deletions)

This file was deleted.

pyproject.toml (7 changes: 5 additions & 2 deletions)
@@ -28,25 +28,28 @@ maintainers = [
]

dependencies = [
"colpali-engine==0.2.2",
"colpali-engine>=0.3.0,<0.4.0",
"ml-dtypes",
"mteb==1.6.35",
"ninja",
"pdf2image",
"srsly",
"torch",
"transformers",
"transformers>=4.42.0",
]

[project.optional-dependencies]
dev = ["pytest>=7.4.0", "ruff>=0.1.9"]
server = ["uvicorn", "fastapi"]
langchain = ["langchain-core"]

[project.urls]
"Homepage" = "https://github.com/answerdotai/byaldi"

[tool.pytest.ini_options]
filterwarnings = ["ignore::Warning"]
markers = ["slow: marks test as slow"]
testpaths = ["tests"]

[tool.ruff]
# Exclude a variety of commonly ignored directories.
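With the new `langchain` extra, the retriever's dependency stays out of the default install and can be pulled in explicitly via the standard extras syntax, e.g. `pip install "byaldi[langchain]"`.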