diff --git a/docs/documentation/.pages b/docs/documentation/.pages
index 878d647..5aa0204 100644
--- a/docs/documentation/.pages
+++ b/docs/documentation/.pages
@@ -4,3 +4,4 @@ nav:
   - Datasets: datasets.md
   - Retrieval: retrieval.md
   - Evaluation: evaluation.md
+  - FastAPI: fastapi.md
diff --git a/docs/documentation/fastapi.md b/docs/documentation/fastapi.md
new file mode 100644
index 0000000..9649b68
--- /dev/null
+++ b/docs/documentation/fastapi.md
@@ -0,0 +1,22 @@
+# Serve the embeddings of a PyLate model using FastAPI
+The `server.py` script (located in the `server` folder) creates a FastAPI server that serves the embeddings of a PyLate model.
+To use it, first install the API dependencies: `pip install "pylate[api]"`
+Then, run `python server.py` to launch the server.
+
+You can then send requests to the API like so:
+```
+curl -X POST http://localhost:8002/v1/embeddings \
+  -H "Content-Type: application/json" \
+  -d '{
+      "input": ["Query 1", "Query 2"],
+      "model": "lightonai/colbertv2.0",
+      "is_query": false
+  }'
+```
+If you want to encode queries rather than documents, simply set `is_query` to `true`.
+
+???+ tip
+    Note that the server leverages [batched](https://github.com/mixedbread-ai/batched), so you can send multiple separate calls and it will batch them dynamically to fill up the GPU.
+
+For now, the server only supports a single loaded model, which you can select with the `--model` argument when launching the server.
+
diff --git a/pylate/server/README.md b/pylate/server/README.md
new file mode 100644
index 0000000..ef5c8a2
--- /dev/null
+++ b/pylate/server/README.md
@@ -0,0 +1,21 @@
+# Serve the embeddings of a PyLate model
+The `server.py` script creates a FastAPI server that serves the embeddings of a PyLate model.
+To use it, first install the API dependencies: `pip install "pylate[api]"`
+Then, run `python server.py` to launch the server.
+
+You can then send requests to the API like so:
+```
+curl -X POST http://localhost:8002/v1/embeddings \
+  -H "Content-Type: application/json" \
+  -d '{
+      "input": ["Query 1", "Query 2"],
+      "model": "lightonai/colbertv2.0",
+      "is_query": false
+  }'
+```
+If you want to encode queries rather than documents, simply set `is_query` to `true`.
+
+Note that the server leverages [batched](https://github.com/mixedbread-ai/batched), so you can send multiple separate calls and it will batch them dynamically to fill up the GPU.
+
+For now, the server only supports a single loaded model, which you can select with the `--model` argument when launching the server.
+
diff --git a/pylate/server/server.py b/pylate/server/server.py
new file mode 100644
index 0000000..8795783
--- /dev/null
+++ b/pylate/server/server.py
@@ -0,0 +1,130 @@
+import argparse
+from typing import List
+
+import batched
+import uvicorn
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+
+from pylate import models
+
+app = FastAPI()
+
+
+class EmbeddingRequest(BaseModel):
+    """Pydantic model for the requests sent to the server.
+
+    Parameters
+    ----------
+    input
+        The input(s) to encode.
+    is_query
+        A boolean indicating whether the input is a query or a document.
+    model
+        The name of the model to use for encoding.
+    """
+
+    input: List[str] | str
+    is_query: bool = True
+    model: str = "lightonai/colbertv2.0"
+
+
+class EmbeddingResponse(BaseModel):
+    """Pydantic model for the server answer to a call.
+
+    Parameters
+    ----------
+    data
+        A list of dictionaries containing the embeddings ("embedding" key) and the type of the object ("object" key, which is always "embedding").
+    model
+        The name of the model used for encoding.
+    usage
+        An approximation of the number of tokens used to generate the embeddings (computed by splitting the input sequences on spaces).
+    """
+
+    data: List[dict]
+    model: str
+    usage: dict
+
+
+def wrap_encode_function(model, **kwargs):
+    def wrapped_encode(sentences):
+        return model.encode(sentences, **kwargs)
+
+    return wrapped_encode
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Run a FastAPI ColBERT serving server with the specified host, port, and model."
+    )
+    parser.add_argument(
+        "--host", type=str, default="0.0.0.0", help="Host to run the server on"
+    )
+    parser.add_argument(
+        "--port", type=int, default=8002, help="Port to run the server on"
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="lightonai/colbertv2.0",
+        help="Model to serve, can be a Hugging Face model or a path to a local model",
+    )
+    return parser.parse_args()
+
+
+args = parse_args()
+
+# Load the model once at startup so it is shared across every request
+model = models.ColBERT(args.model)
+# We cannot create the encode functions on the fly: batching requires reusing the same function (same memory address)
+model.encode_query = batched.aio.dynamically(wrap_encode_function(model, is_query=True))
+model.encode_document = batched.aio.dynamically(
+    wrap_encode_function(model, is_query=False)
+)
+
+
+@app.post("/v1/embeddings", response_model=EmbeddingResponse)
+async def create_embedding(request: EmbeddingRequest):
+    """API endpoint that encodes the elements of an EmbeddingRequest and returns an EmbeddingResponse.
+
+    Parameters
+    ----------
+    request
+        The EmbeddingRequest containing the elements to encode, the model to use, and whether the input is a query or a document.
+    """
+    if request.model != args.model:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Model not supported, the loaded model is {args.model}, but the request is for {request.model}",
+        )
+    try:
+        if request.is_query:
+            embeddings = await model.encode_query(
+                request.input,
+            )
+        else:
+            embeddings = await model.encode_document(
+                request.input,
+            )
+
+        # Format the response
+        data = [
+            {"object": "embedding", "embedding": embedding.tolist(), "index": i}
+            for i, embedding in enumerate(embeddings)
+        ]
+
+        # Calculate the token usage (approximate, based on whitespace splitting)
+        total_tokens = sum(len(text.split()) for text in request.input)
+
+        return EmbeddingResponse(
+            data=data,
+            model=request.model,
+            usage={"prompt_tokens": total_tokens, "total_tokens": total_tokens},
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host=args.host, port=args.port)
diff --git a/pytest.ini b/pytest.ini
index 500d872..bfa94a2 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -14,6 +14,7 @@ norecursedirs =
     build
     docs
     node_modules
+    pylate/server
 markers =
     web: tests that require using the Internet
     slow: tests that take a long time to run
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 9d76936..278a4f4 100644
--- a/setup.py
+++ b/setup.py
@@ -29,6 +29,7 @@
 eval = ["ranx >= 0.3.16", "beir >= 2.0.0"]
+api = ["fastapi >= 0.114.1", "uvicorn >= 0.30.6", "batched >= 0.1.2"]
 
 setuptools.setup(
     name="pylate",
@@ -44,7 +45,8 @@
     install_requires=base_packages,
     extras_require={
         "eval": base_packages + eval,
-        "dev": base_packages + dev + eval,
+        "api": base_packages + api,
+        "dev": base_packages + dev + eval + api,
     },
     classifiers=[
         "Programming Language :: Python :: 3",