diff --git a/vector-serve/app/routes/transform.py b/vector-serve/app/routes/transform.py
index 8b5527d..aec5e5b 100644
--- a/vector-serve/app/routes/transform.py
+++ b/vector-serve/app/routes/transform.py
@@ -3,9 +3,11 @@
 from typing import TYPE_CHECKING, Any, List
 
 from app.models import model_org_name, get_model, parse_header
+from app.utils.chunking import recursive_text_chunk
 from fastapi import APIRouter, Header, HTTPException, Request
 from pydantic import BaseModel, conlist
 
+
 router = APIRouter(tags=["transform"])
 
 logging.basicConfig(level=logging.DEBUG)
@@ -40,6 +42,15 @@ def batch_transform(
     request: Request, payload: Batch, authorization: str = Header(None)
 ) -> ResponseModel:
     logging.info({"batch-predict-len": len(payload.input)})
-    batches = chunk_list(payload.input, BATCH_SIZE)
+
+    # Split each document into overlapping chunks before batching so long
+    # inputs are embedded piece by piece, then batch the chunked input.
+    chunked_input: list[str] = []
+    for doc in payload.input:
+        chunked_input.extend(
+            recursive_text_chunk(doc, chunk_size=1000, chunk_overlap=200)
+        )
+
+    batches = chunk_list(chunked_input, BATCH_SIZE)
     num_batches = len(batches)
     responses: list[list[float]] = []
diff --git a/vector-serve/app/utils/chunking.py b/vector-serve/app/utils/chunking.py
new file mode 100644
index 0000000..3dcc781
--- /dev/null
+++ b/vector-serve/app/utils/chunking.py
@@ -0,0 +1,54 @@
+from typing import List, Optional
+
+
+def recursive_text_chunk(
+    text: str,
+    chunk_size: int = 1000,
+    chunk_overlap: int = 200,
+    separators: Optional[List[str]] = None,
+) -> List[str]:
+    """Split ``text`` into chunks of at most ``chunk_size`` characters,
+    preferring to break on the highest-priority separator and overlapping
+    consecutive chunks by roughly ``chunk_overlap`` characters."""
+
+    # Avoid a mutable default argument; fall back to the standard set.
+    if separators is None:
+        separators = ["\n\n", "\n", " ", ""]
+
+    chunks: List[str] = []
+    current_position = 0
+
+    while current_position < len(text):
+        window_end = min(current_position + chunk_size, len(text))
+
+        # Final window: emit the remainder and stop.
+        if window_end == len(text):
+            tail = text[current_position:].strip()
+            if tail:
+                chunks.append(tail)
+            break
+
+        # Split at the last occurrence of the highest-priority separator
+        # inside the window, falling back to a hard cut at the window edge.
+        split_at = window_end
+        separator_len = 0
+        for separator in separators:
+            if not separator:
+                break
+            found = text.rfind(separator, current_position + 1, window_end)
+            if found != -1:
+                split_at = found
+                separator_len = len(separator)
+                break
+
+        chunk = text[current_position:split_at].strip()
+        if chunk:
+            chunks.append(chunk)
+
+        # Rewind by the overlap, but always make forward progress so the
+        # loop cannot stall when the overlap exceeds the step size.
+        current_position = max(
+            split_at + separator_len - chunk_overlap, current_position + 1
+        )
+
+    return chunks
diff --git a/vector-serve/tests/test_endpoints.py b/vector-serve/tests/test_endpoints.py
index e1f733c..05d7a2c 100644
--- a/vector-serve/tests/test_endpoints.py
+++ b/vector-serve/tests/test_endpoints.py
@@ -1,18 +1,23 @@
 from fastapi.testclient import TestClient
 from fastapi import FastAPI
 
+
 def test_ready_endpoint(test_client):
     response = test_client.get("/ready")
     assert response.status_code == 200
     assert response.json() == {"ready": True}
 
+
 def test_alive_endpoint(test_client):
     response = test_client.get("/alive")
     assert response.status_code == 200
     assert response.json() == {"alive": True}
 
+
 def test_model_info(test_client):
-    response = test_client.get("/v1/info", params={"model_name": "sentence-transformers/all-MiniLM-L6-v2"})
+    response = test_client.get(
+        "/v1/info", params={"model_name": "sentence-transformers/all-MiniLM-L6-v2"}
+    )
     assert response.status_code == 200
 
 
@@ -20,3 +25,59 @@ def test_metrics_endpoint(test_client):
     response = test_client.get("/metrics")
     assert response.status_code == 200
     assert "all-MiniLM-L6-v2" in response.text
+
+
+# Simulate a large document
+long_text = "This is a very long document. " * 1000
+
+
+def test_chunking_basic(test_client):
+    payload = {"input": [long_text], "model": "all-MiniLM-L6-v2", "normalize": False}
+    response = test_client.post("/v1/embeddings", json=payload)
+
+    assert response.status_code == 200
+    response_data = response.json()
+
+    # The long document should be split into multiple chunks, each with
+    # its own embedding.
+    assert len(response_data["data"]) > 1
+    assert "embedding" in response_data["data"][0]
+
+
+def test_chunking_small_input(test_client):
+    small_text = "Short text."
+    payload = {"input": [small_text], "model": "all-MiniLM-L6-v2", "normalize": False}
+    response = test_client.post("/v1/embeddings", json=payload)
+
+    assert response.status_code == 200
+    response_data = response.json()
+
+    assert len(response_data["data"]) == 1
+    assert "embedding" in response_data["data"][0]
+
+
+def test_chunk_overlap(test_client):
+    payload = {"input": [long_text], "model": "all-MiniLM-L6-v2", "normalize": False}
+    response = test_client.post("/v1/embeddings", json=payload)
+
+    assert response.status_code == 200
+    response_data = response.json()
+
+    # The exact chunk count depends on where separators fall, so assert
+    # a lower bound implied by the chunk size rather than an exact count.
+    chunk_size = 1000
+    min_chunks = len(long_text) // chunk_size
+
+    assert len(response_data["data"]) >= min_chunks
+
+
+def test_large_input(test_client):
+    large_text = "Lorem ipsum " * 5000
+    payload = {"input": [large_text], "model": "all-MiniLM-L6-v2", "normalize": False}
+    response = test_client.post("/v1/embeddings", json=payload)
+
+    assert response.status_code == 200
+    response_data = response.json()
+
+    assert len(response_data["data"]) > 1
+    assert "embedding" in response_data["data"][0]