Skip to content

Commit

Permalink
feat: add guardrails and instructions in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tonywu71 committed Sep 17, 2024
1 parent cd011c8 commit f550289
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tests/all.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
from pathlib import Path

from colpali_engine.utils.torch_utils import get_torch_device

from byaldi import RAGMultiModalModel

device = get_torch_device("auto")
print(f"Using device: {device}")

path_document_1 = Path("docs/attention.pdf")
path_document_2 = Path("docs/attention_copy.pdf")


def test_single_pdf():
print("Testing single PDF indexing and retrieval...")

# Initialize the model
model = RAGMultiModalModel.from_pretrained("vidore/colpali-v1.2", device=device)

if not Path("docs/attention.pdf").is_file():
raise FileNotFoundError(
f"Please download the PDF file from https://arxiv.org/pdf/1706.03762 and move it to {path_document_1}."
)

# Index a single PDF
model.index(
input_path="docs/attention.pdf",
Expand Down Expand Up @@ -56,6 +66,15 @@ def test_multi_document():
# Initialize the model
model = RAGMultiModalModel.from_pretrained("vidore/colpali")

if not Path("docs/attention.pdf").is_file():
raise FileNotFoundError(
f"Please download the PDF file from https://arxiv.org/pdf/1706.03762 and move it to {path_document_1}."
)
if not Path("docs/attention_copy.pdf").is_file():
raise FileNotFoundError(
f"Please download the PDF file from https://arxiv.org/pdf/1706.03762 and move it to {path_document_2}."
)

# Index a directory of documents
model.index(
input_path="docs/",
Expand Down Expand Up @@ -137,6 +156,15 @@ def test_add_to_index():


if __name__ == "__main__":
print("Starting tests...")

print("/n/n----------------- Single PDF test -----------------n")
test_single_pdf()

print("/n/n----------------- Multi document test -----------------n")
test_multi_document()

print("/n/n----------------- Add to index test -----------------n")
test_add_to_index()

print("All tests completed.")

0 comments on commit f550289

Please sign in to comment.