Skip to content

Commit

Permalink
Merge pull request #12 from AnswerDotAI/feat/resize
Browse files Browse the repository at this point in the history
Feat/resize
  • Loading branch information
bclavie authored Sep 9, 2024
2 parents d6349d2 + 96de3d4 commit 8d29b24
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 7 deletions.
6 changes: 6 additions & 0 deletions byaldi/RAGModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def index(
List[Dict[str, Union[str, int]]],
]
] = None,
max_image_width: Optional[int] = None,
max_image_height: Optional[int] = None,
**kwargs,
):
"""Build an index from input documents.
Expand All @@ -115,6 +118,9 @@ def index(
store_collection_with_index,
overwrite=overwrite,
metadata=metadata,
max_image_width=max_image_width,
max_image_height=max_image_height,
**kwargs,
)

def add_to_index(
Expand Down
42 changes: 35 additions & 7 deletions byaldi/colpali.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
)

if verbose > 0:
print(f"Verbosity is set to {verbose} (active). Pass verbose=0 to make quieter.")
print(f"Verbosity is set to {verbose} ({'active' if verbose == 1 else 'loud'}). Pass verbose=0 to make quieter.")

self.pretrained_model_name_or_path = pretrained_model_name_or_path
self.model_name = self.pretrained_model_name_or_path
Expand Down Expand Up @@ -107,6 +107,9 @@ def __init__(
self.full_document_collection = index_config.get(
"full_document_collection", False
)
self.resize_stored_images = index_config.get("resize_stored_images", False)
self.max_image_width = index_config.get("max_image_width", None)
self.max_image_height = index_config.get("max_image_height", None)
if self.full_document_collection:
collection_path = index_path / "collection"
json_files = sorted(
Expand Down Expand Up @@ -185,7 +188,6 @@ def from_index(
):
index_path = Path(index_root) / Path(index_path)
index_config = srsly.read_gzip_json(index_path / "index_config.json.gz")
print(index_config)

instance = cls(
pretrained_model_name_or_path=index_config["model_name"],
Expand Down Expand Up @@ -221,6 +223,9 @@ def _export_index(self):
"model_name": self.model_name,
"full_document_collection": self.full_document_collection,
"highest_doc_id": self.highest_doc_id,
"resize_stored_images": True if self.max_image_width and self.max_image_height else False,
"max_image_width": self.max_image_width,
"max_image_height": self.max_image_height,
"library_version": VERSION,
}
srsly.write_gzip_json(index_path / "index_config.json.gz", index_config)
Expand Down Expand Up @@ -253,13 +258,15 @@ def index(
store_collection_with_index: bool = False,
overwrite: bool = False,
metadata: Optional[List[Dict[str, Union[str, int]]]] = None,
max_image_width: Optional[int] = None,
max_image_height: Optional[int] = None,
) -> Dict[int, str]:
if (
self.index_name is not None
and (index_name is None or self.index_name == index_name)
and not overwrite
):
print(
raise ValueError(
f"An index named {self.index_name} is already loaded.",
"Use add_to_index() to add to it or search() to query it.",
"Pass a new index_name to create a new index.",
Expand All @@ -274,18 +281,20 @@ def index(
index_path = Path(self.index_root) / Path(index_name)
if index_path.exists():
if overwrite is False:
print(f"An index named {index_name} already exists.")
print("Use overwrite=True to delete the existing index and build a new one.")
print("Exiting indexing without doing anything...")
raise ValueError(f"An index named {index_name} already exists.",
"Use overwrite=True to delete the existing index and build a new one.",
"Exiting indexing without doing anything...")
return None
else:
print(f"overwrite is on. Deleting existing index {index_name} to build a new one.")
shutil.rmtree(index_path)

self.index_name = index_name
self.max_image_width = max_image_width
self.max_image_height = max_image_height

input_path = Path(input_path)
if not hasattr(self, "highest_doc_id"):
if not hasattr(self, "highest_doc_id") or overwrite is True:
self.highest_doc_id = -1

if input_path.is_dir():
Expand Down Expand Up @@ -426,6 +435,25 @@ def _add_to_index(
import base64
import io

# Resize image while maintaining aspect ratio
if self.max_image_width and self.max_image_height:
img_width, img_height = image.size
aspect_ratio = img_width / img_height
if img_width > self.max_image_width:
new_width = self.max_image_width
new_height = int(new_width / aspect_ratio)
else:
new_width = img_width
new_height = img_height
if new_height > self.max_image_height:
new_height = self.max_image_height
new_width = int(new_height * aspect_ratio)
if self.verbose > 2:
print(f"Resizing image to {new_width}x{new_height}" ,
f"(aspect ratio {aspect_ratio:.2f}, original size {img_width}x{img_height},"
f"compression {new_width/img_width * new_height/img_height:.2f})")
image = image.resize((new_width, new_height), Image.LANCZOS)

buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
Expand Down

0 comments on commit 8d29b24

Please sign in to comment.