feat: add chunking support to vectorize.table()
asr2003 authored Oct 17, 2024
1 parent 97b2419 commit eca8d23
Showing 3 changed files with 101 additions and 1 deletion.
9 changes: 9 additions & 0 deletions vector-serve/app/routes/transform.py
@@ -3,9 +3,11 @@
from typing import TYPE_CHECKING, Any, List

from app.models import model_org_name, get_model, parse_header
from app.utils.chunking import recursive_text_chunk
from fastapi import APIRouter, Header, HTTPException, Request
from pydantic import BaseModel, conlist


router = APIRouter(tags=["transform"])

logging.basicConfig(level=logging.DEBUG)
@@ -40,6 +42,13 @@ def batch_transform(
request: Request, payload: Batch, authorization: str = Header(None)
) -> ResponseModel:
logging.info({"batch-predict-len": len(payload.input)})

    # Split each incoming document into overlapping chunks before batching.
    chunked_input = []
    for doc in payload.input:
        chunked_input.extend(
            recursive_text_chunk(doc, chunk_size=1000, chunk_overlap=200)
        )

    batches = chunk_list(chunked_input, BATCH_SIZE)
num_batches = len(batches)
responses: list[list[float]] = []
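Note: chunk_list above is outside this diff. A minimal sketch of what it presumably does, assuming it simply splits a flat list into fixed-size batches:

from typing import Any, List


def chunk_list(items: List[Any], batch_size: int) -> List[List[Any]]:
    # Hypothetical sketch: return consecutive batches of at most
    # batch_size elements from a flat list.
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]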
32 changes: 32 additions & 0 deletions vector-serve/app/utils/chunking.py
@@ -0,0 +1,32 @@
from typing import List, Optional


def recursive_text_chunk(
    text: str,
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
    separators: Optional[List[str]] = None,
) -> List[str]:
    """Recursively split text into smaller chunks with overlap."""

    # A mutable default argument would be shared across calls; use None instead.
    if separators is None:
        separators = ["\n\n", "\n", " ", ""]

    chunks: List[str] = []
    current_position = 0

    while current_position < len(text):
        window_end = min(current_position + chunk_size, len(text))
        next_chunk = None
        next_position = current_position + chunk_size - chunk_overlap
        for separator in separators:
            # Prefer the latest separator inside the window so chunks break
            # on natural boundaries where possible.
            next_split = text.rfind(separator, current_position, window_end)
            if next_split > current_position:
                next_chunk = text[current_position:next_split].strip()
                next_position = next_split + len(separator) - chunk_overlap
                break

        if next_chunk is None:
            # No separator found in the window: fall back to a hard split.
            next_chunk = text[current_position:window_end].strip()

        if next_chunk:
            chunks.append(next_chunk)

        # Always move forward, even when subtracting the overlap would push
        # the cursor back past the previous position (prevents infinite loops).
        current_position = max(next_position, current_position + 1)

    return chunks
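A quick usage sketch for recursive_text_chunk (hypothetical values; exact chunk boundaries depend on where separators fall in the input):

from app.utils.chunking import recursive_text_chunk

text = "First sentence. Second sentence. " * 50  # ~1,650 characters
chunks = recursive_text_chunk(text, chunk_size=200, chunk_overlap=40)

# Every chunk fits within chunk_size, and consecutive chunks repeat up to
# chunk_overlap characters of trailing context from the previous chunk.
assert all(len(chunk) <= 200 for chunk in chunks)
assert len(chunks) > len(text) // 200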
61 changes: 60 additions & 1 deletion vector-serve/tests/test_endpoints.py
@@ -1,22 +1,81 @@
from fastapi.testclient import TestClient


def test_ready_endpoint(test_client):
response = test_client.get("/ready")
assert response.status_code == 200
assert response.json() == {"ready": True}


def test_alive_endpoint(test_client):
response = test_client.get("/alive")
assert response.status_code == 200
assert response.json() == {"alive": True}


def test_model_info(test_client):
response = test_client.get("/v1/info", params={"model_name": "sentence-transformers/all-MiniLM-L6-v2"})
response = test_client.get(
"/v1/info", params={"model_name": "sentence-transformers/all-MiniLM-L6-v2"}
)
assert response.status_code == 200


def test_metrics_endpoint(test_client):
response = test_client.get("/metrics")
assert response.status_code == 200
assert "all-MiniLM-L6-v2" in response.text


# Simulate a large document
long_text = "This is a very long document. " * 1000


def test_chunking_basic(test_client):
payload = {"input": [long_text], "model": "all-MiniLM-L6-v2", "normalize": False}
response = test_client.post("/v1/embeddings", json=payload)

assert response.status_code == 200
response_data = response.json()

assert len(response_data["data"]) > 0
assert "embedding" in response_data["data"][0]
assert len(response_data["data"]) > 1


def test_chunking_small_input(test_client):
small_text = "Short text."
payload = {"input": [small_text], "model": "all-MiniLM-L6-v2", "normalize": False}
response = test_client.post("/v1/embeddings", json=payload)

assert response.status_code == 200
response_data = response.json()

assert len(response_data["data"]) == 1
assert "embedding" in response_data["data"][0]


def test_chunk_overlap(test_client):
payload = {"input": [long_text], "model": "all-MiniLM-L6-v2", "normalize": False}
response = test_client.post("/v1/embeddings", json=payload)

assert response.status_code == 200
response_data = response.json()

    chunk_size = 1000
    overlap_size = 200
    # Separator-aware splitting makes the exact count vary, but the overlap
    # keeps each stride well under chunk_size, so a naive split is a safe
    # lower bound on the number of chunks produced.
    min_chunks = len(long_text) // chunk_size

    assert len(response_data["data"]) >= min_chunks


def test_large_input(test_client):
large_text = "Lorem ipsum " * 5000
payload = {"input": [large_text], "model": "all-MiniLM-L6-v2", "normalize": False}
response = test_client.post("/v1/embeddings", json=payload)

assert response.status_code == 200
response_data = response.json()

assert len(response_data["data"]) > 1
assert "embedding" in response_data["data"][0]

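The tests above rely on a test_client pytest fixture that is not part of this diff. A minimal sketch of how it is presumably wired up in conftest.py, assuming the FastAPI app is importable as app.main.app:

import pytest
from fastapi.testclient import TestClient

from app.main import app  # assumed import path for the FastAPI app


@pytest.fixture
def test_client():
    # Using TestClient as a context manager runs startup/shutdown events,
    # which model preloading likely depends on.
    with TestClient(app) as client:
        yield client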