Adding a FastAPI server (#60)
* v1 API server

* Changing model_name to model

* Add a readme

* Add docstring

* Adding the server in the documentation

* Add server to norecursedirs, add docstring for the endpoint method

* Enhancement of the documentation
NohTow authored Oct 2, 2024
1 parent d526d98 commit 3f71cfc
Showing 6 changed files with 178 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/documentation/.pages
@@ -4,3 +4,4 @@ nav:
- Datasets: datasets.md
- Retrieval: retrieval.md
- Evaluation: evaluation.md
- FastAPI: fastapi.md
22 changes: 22 additions & 0 deletions docs/documentation/fastapi.md
@@ -0,0 +1,22 @@
# Serve the embeddings of a PyLate model using FastAPI
The ```server.py``` script (located in the ```server``` folder) allows you to create a FastAPI server that serves the embeddings of a PyLate model.
To use it, first install the API dependencies: ```pip install "pylate[api]"```
Then run ```python server.py``` to launch the server.

You can then send requests to the API like so:
```
curl -X POST http://localhost:8002/v1/embeddings \
-H "Content-Type: application/json" \
-d '{
"input": ["Query 1", "Query 2"],
"model": "lightonai/colbertv2.0",
"is_query": false
}'
```
If you want to encode queries, simply set ```is_query``` to ```true```.
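
If you prefer calling the endpoint from Python, here is a minimal sketch that mirrors the curl call above using the ```requests``` library (note that ```requests``` is not part of the ```pylate[api]``` extra, so you may need to install it separately):

```python
import requests

# Mirror of the curl example above, targeting the default host and port.
response = requests.post(
    "http://localhost:8002/v1/embeddings",
    json={
        "input": ["Query 1", "Query 2"],
        "model": "lightonai/colbertv2.0",
        "is_query": False,
    },
)
response.raise_for_status()

payload = response.json()
# Each entry in "data" holds one multi-vector embedding under the "embedding" key.
for item in payload["data"]:
    print(item["index"], len(item["embedding"]))
```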

???+ tip
    Note that the server leverages [batched](https://github.com/mixedbread-ai/batched), so you can get batch processing by sending multiple separate calls: the server groups them into batches dynamically to fill up the GPU.

For now, the server only supports a single loaded model, which you can select with the ```--model``` argument when launching the server.

21 changes: 21 additions & 0 deletions pylate/server/README.md
@@ -0,0 +1,21 @@
# Serve the embeddings of a PyLate model
The ```server.py``` script allows you to create a FastAPI server that serves the embeddings of a PyLate model.
To use it, first install the API dependencies: ```pip install "pylate[api]"```
Then run ```python server.py``` to launch the server.

You can then send requests to the API like so:
```
curl -X POST http://localhost:8002/v1/embeddings \
-H "Content-Type: application/json" \
-d '{
"input": ["Query 1", "Query 2"],
"model": "lightonai/colbertv2.0",
"is_query": false
}'
```
If you want to encode queries, simply set ```is_query``` to ```true```.

Note that the server leverages [batched](https://github.com/mixedbread-ai/batched), so you can get batch processing by sending multiple separate calls: the server groups them into batches dynamically to fill up the GPU, as illustrated in the sketch below.
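
As a rough illustration of that dynamic batching (assuming the server runs locally on its default port and that the ```requests``` package is available), the sketch below fires several independent calls in parallel and lets the server group them into batches:

```python
from concurrent.futures import ThreadPoolExecutor

import requests

URL = "http://localhost:8002/v1/embeddings"  # default host/port of server.py


def embed(text: str) -> list:
    """Send a single document to the server and return its multi-vector embedding."""
    response = requests.post(
        URL,
        json={"input": [text], "model": "lightonai/colbertv2.0", "is_query": False},
    )
    response.raise_for_status()
    return response.json()["data"][0]["embedding"]


documents = [f"Document {i}" for i in range(32)]

# Independent requests sent concurrently; the server batches them dynamically on the GPU.
with ThreadPoolExecutor(max_workers=8) as pool:
    embeddings = list(pool.map(embed, documents))

print(f"Received {len(embeddings)} embeddings")
```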

For now, the server only supports a single loaded model, which you can select with the ```--model``` argument when launching the server.

130 changes: 130 additions & 0 deletions pylate/server/server.py
@@ -0,0 +1,130 @@
import argparse
from typing import List

import batched
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from pylate import models

app = FastAPI()


class EmbeddingRequest(BaseModel):
"""PyDantic model for the requests sent to the server.
Parameters
----------
input
The input(s) to encode.
is_query
A boolean indicating if the input is a query or a document.
model
        The name of the model to use for encoding.
"""

input: List[str] | str
is_query: bool = True
model: str = "lightonai/colbertv2.0"


class EmbeddingResponse(BaseModel):
"""PyDantic model for the server answer to a call.
Parameters
----------
data
        A list of dictionaries, each containing an embedding (the "embedding" key), the index of the corresponding input (the "index" key), and the type of the object (the "object" key, which is always "embedding").
model
The name of the model used for encoding.
usage
An approximation of the number of tokens used to generate the embeddings (computed by splitting the input sequences on spaces).
"""

data: List[dict]
model: str
usage: dict


def wrap_encode_function(model, **kwargs):
def wrapped_encode(sentences):
return model.encode(sentences, **kwargs)

return wrapped_encode


def parse_args():
parser = argparse.ArgumentParser(
description="Run FastAPI ColBERT serving server with specified host, port, and model."
)
parser.add_argument(
"--host", type=str, default="0.0.0.0", help="Host to run the server on"
)
parser.add_argument(
"--port", type=int, default=8002, help="Port to run the server on"
)
parser.add_argument(
"--model",
type=str,
default="lightonai/colbertv2.0",
help="Model to serve, can be an HF model or a path to a model",
)
return parser.parse_args()


args = parse_args()

# We need to load the model here so it is shared for every request
model = models.ColBERT(args.model)
# We cannot create the function on the fly because batching requires the same function object (memory address) across calls
model.encode_query = batched.aio.dynamically(wrap_encode_function(model, is_query=True))
model.encode_document = batched.aio.dynamically(
wrap_encode_function(model, is_query=False)
)


@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embedding(request: EmbeddingRequest):
"""API endpoint that encode the elements of an EmbeddingRequest and returns an EmbeddingResponse.
Parameters
----------
request
The EmbeddingRequest containing the elements to encode, the model to use, and whether the input is a query or a document.
"""
if request.model != args.model:
raise HTTPException(
status_code=400,
detail=f"Model not supported, the loaded model is {args.model}, but the request is for {request.model}",
)
try:
if request.is_query:
embeddings = await model.encode_query(
request.input,
)
else:
embeddings = await model.encode_document(
request.input,
)

# Format response
data = [
{"object": "embedding", "embedding": embedding.tolist(), "index": i}
for i, embedding in enumerate(embeddings)
]

        # Calculate token usage (approximate, by splitting on whitespace)
        inputs = [request.input] if isinstance(request.input, str) else request.input
        total_tokens = sum(len(text.split()) for text in inputs)

return EmbeddingResponse(
data=data,
model=request.model,
usage={"prompt_tokens": total_tokens, "total_tokens": total_tokens},
)
except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
uvicorn.run(app, host=args.host, port=args.port)
1 change: 1 addition & 0 deletions pytest.ini
@@ -14,6 +14,7 @@ norecursedirs =
build
docs
node_modules
pylate/server
markers =
web: tests that require using the Internet
slow: tests that take a long time to run
4 changes: 3 additions & 1 deletion setup.py
@@ -29,6 +29,7 @@

eval = ["ranx >= 0.3.16", "beir >= 2.0.0"]

api = ["fastapi >= 0.114.1", "uvicorn >= 0.30.6", "batched >= 0.1.2"]

setuptools.setup(
name="pylate",
@@ -44,7 +45,8 @@
install_requires=base_packages,
extras_require={
"eval": base_packages + eval,
"dev": base_packages + dev + eval,
"api": base_packages + api,
"dev": base_packages + dev + eval + api,
},
classifiers=[
"Programming Language :: Python :: 3",
