
[CHORE] Add endpoints to simulate rate-limiting on AWS S3 buckets (#1220)

# Summary

Adds FastAPI endpoints for integration tests that simulate rate-limiting
on AWS S3 buckets.

Drive-by: refactors our code to use FastAPI sub-apps instead of Routers.
This is needed for finer-grained control over rate limiting per "app",
where each app corresponds to an S3 bucket with certain characteristics.
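
For context, a minimal sketch of the two patterns (route and handler names here are illustrative, not the actual module layout): an `APIRouter` is folded into the parent app and shares its middleware, exception handlers, and state, whereas a mounted sub-app is a full `FastAPI` instance that can carry its own — which is what per-bucket rate limiting needs.

```python
from fastapi import APIRouter, FastAPI

parent = FastAPI()

# Router pattern (before): routes are folded into the parent app and share
# its middleware, exception handlers, and state.
router = APIRouter(prefix="/router-bucket")

@router.get("/{item_id}")
async def router_get(item_id: str):
    return {"item_id": item_id}

parent.include_router(router)

# Sub-app pattern (after): each bucket is its own FastAPI app, so it can
# carry its own exception handlers (e.g. a rate limiter), then be mounted
# under a route prefix on the parent.
bucket_app = FastAPI()

@bucket_app.get("/{item_id}")
async def bucket_get(item_id: str):
    return {"item_id": item_id}

parent.mount("/subapp-bucket", bucket_app)
```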

Closes: #1179 

More work needs to be done here once we support customizable retry
policies, so that we can verify that these retry strategies work.

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
jaychia and Jay Chia authored Aug 2, 2023
1 parent e681536 commit 60512a0
Showing 8 changed files with 93 additions and 23 deletions.
15 changes: 10 additions & 5 deletions tests/integration/docker-compose/retry_server/main.py
@@ -8,16 +8,21 @@
 We provide two different buckets, with slightly different behavior:
-1. "head-retries-bucket": this bucket throws errors during HEAD operations
-2. "get-retries-bucket": this bucket throws errors during the ranged GET operations
+1. "head-retries-parquet-bucket": this bucket throws errors during HEAD operations
+2. "get-retries-parquet-bucket": this bucket throws errors during the ranged GET operations
 """

 from __future__ import annotations

 from fastapi import FastAPI

-from .routers import get_retries_bucket, head_retries_bucket
+from .routers import (
+    get_retries_parquet_bucket,
+    head_retries_parquet_bucket,
+    rate_limited_echo_gets_bucket,
+)

 app = FastAPI()
-app.include_router(get_retries_bucket.router)
-app.include_router(head_retries_bucket.router)
+app.mount(get_retries_parquet_bucket.route, get_retries_parquet_bucket.app)
+app.mount(head_retries_parquet_bucket.route, head_retries_parquet_bucket.app)
+app.mount(rate_limited_echo_gets_bucket.route, rate_limited_echo_gets_bucket.app)
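
As a sanity check of the wiring above, a hedged sketch using FastAPI's `TestClient` (the `app` import and the chosen path parameters are assumptions for illustration): requests are matched against the mount prefix first, then the remainder of the path is routed inside the sub-app.

```python
from fastapi.testclient import TestClient

from main import app  # assumption: the composed app from main.py is importable

client = TestClient(app)
# "/get-retries-parquet-bucket" selects the mounted sub-app; the sub-app then
# matches "/{status_code}/{num_errors}/{item_id}" -> 503, 2 errors, item "foo".
resp = client.head("/get-retries-parquet-bucket/503/2/foo")
assert resp.status_code == 200  # HEAD on the GET-retries bucket never errors
```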
@@ -17,4 +17,5 @@ uvicorn==0.23.2
 uvloop==0.17.0
 watchfiles==0.19.0
 websockets==11.0.3
-pyarrow
+pyarrow==12.0.1
+slowapi==0.1.8
tests/integration/docker-compose/retry_server/routers/get_retries_parquet_bucket.py
@@ -3,21 +3,22 @@
 import os
 from typing import Annotated

-from fastapi import APIRouter, Header, Response
+from fastapi import FastAPI, Header, Request, Response

 from ..utils.parquet_generation import generate_parquet_file
 from ..utils.responses import get_response

-BUCKET_NAME = "get-retries-bucket"
+BUCKET_NAME = "get-retries-parquet-bucket"
 OBJECT_KEY_URL = "/{status_code}/{num_errors}/{item_id}"
 MOCK_PARQUET_DATA_PATH = generate_parquet_file()

 ITEM_ID_TO_NUM_RETRIES: dict[tuple[str, tuple[int, int]], int] = {}

-router = APIRouter(prefix=f"/{BUCKET_NAME}")
+route = f"/{BUCKET_NAME}"
+app = FastAPI()


-@router.head(OBJECT_KEY_URL)
+@app.head(OBJECT_KEY_URL)
 async def bucket_head(status_code: int, num_errors: int, item_id: str):
     return Response(
         headers={
@@ -28,8 +29,9 @@ async def bucket_head(status_code: int, num_errors: int, item_id: str):
     )


-@router.get(OBJECT_KEY_URL)
+@app.get(OBJECT_KEY_URL)
 async def retryable_bucket_get(
+    request: Request,
     status_code: int,
     num_errors: int,
     item_id: str,
@@ -43,7 +45,7 @@ async def retryable_bucket_get(
     else:
         ITEM_ID_TO_NUM_RETRIES[key] += 1
     if ITEM_ID_TO_NUM_RETRIES[key] <= num_errors:
-        return get_response(BUCKET_NAME, status_code, num_errors, item_id)
+        return get_response(request.url, status_code)

     with open(MOCK_PARQUET_DATA_PATH.name, "rb") as f:
         f.seek(start)
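
The tail of this handler (truncated above) serves the requested byte range from the generated Parquet file. A hedged sketch of how such a ranged read typically works — the exact header parsing is elided from this diff, so treat this as an assumption:

```python
# Sketch: parse an HTTP "Range: bytes=start-end" header (as S3 ranged GETs
# use) and serve the inclusive byte slice from a local file.
def read_byte_range(path: str, range_header: str) -> bytes:
    start, end = (int(i) for i in range_header[len("bytes="):].split("-"))
    with open(path, "rb") as f:
        f.seek(start)
        return f.read(end - start + 1)  # HTTP ranges include both endpoints
```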
tests/integration/docker-compose/retry_server/routers/head_retries_parquet_bucket.py
@@ -3,30 +3,31 @@
 import os
 from typing import Annotated

-from fastapi import APIRouter, Header, Response
+from fastapi import FastAPI, Header, Request, Response

 from ..utils.parquet_generation import generate_parquet_file
 from ..utils.responses import get_response

-BUCKET_NAME = "head-retries-bucket"
+BUCKET_NAME = "head-retries-parquet-bucket"
 OBJECT_KEY_URL = "/{status_code}/{num_errors}/{item_id}"
 MOCK_PARQUET_DATA_PATH = generate_parquet_file()

 ITEM_ID_TO_NUM_RETRIES: dict[str, int] = {}

-router = APIRouter(prefix=f"/{BUCKET_NAME}")
+route = f"/{BUCKET_NAME}"
+app = FastAPI()


-@router.head(OBJECT_KEY_URL)
-async def retryable_bucket_head(status_code: int, num_errors: int, item_id: str):
+@app.head(OBJECT_KEY_URL)
+async def retryable_bucket_head(request: Request, status_code: int, num_errors: int, item_id: str):
     """Reading of Parquet starts with a head request, which potentially must be retried as well"""
     key = item_id
     if key not in ITEM_ID_TO_NUM_RETRIES:
         ITEM_ID_TO_NUM_RETRIES[key] = 1
     else:
         ITEM_ID_TO_NUM_RETRIES[key] += 1
     if ITEM_ID_TO_NUM_RETRIES[key] <= num_errors:
-        return get_response(BUCKET_NAME, status_code, num_errors, item_id)
+        return get_response(request.url, status_code)

     return Response(
         headers={
@@ -37,7 +38,7 @@ async def retryable_bucket_head(status_code: int, num_errors: int, item_id: str)
     )


-@router.get(OBJECT_KEY_URL)
+@app.get(OBJECT_KEY_URL)
 async def bucket_get(
     status_code: int,
     num_errors: int,
tests/integration/docker-compose/retry_server/routers/rate_limited_echo_gets_bucket.py
@@ -0,0 +1,38 @@
from __future__ import annotations

from fastapi import FastAPI, Request, Response
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address

from ..utils.responses import get_response

BUCKET_NAME = "80-per-second-rate-limited-gets-bucket"
OBJECT_KEY_URL = "/{item_id}"

route = f"/{BUCKET_NAME}"
app = FastAPI()


def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> Response:
    return get_response(request.url, status_code=503)


limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler)


@app.get(OBJECT_KEY_URL)
@limiter.shared_limit(limit_value="80/second", scope="my_shared_limit")
async def rate_limited_bucket_get(request: Request, item_id: str):
    """This endpoint will just echo the `item_id` and return that as the response body"""
    result = item_id.encode("utf-8")
    return Response(
        status_code=200,
        content=result,
        headers={
            "Content-Length": str(len(result)),
            "Content-Type": "binary/octet-stream",
        },
    )
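
slowapi's `shared_limit` pools every route registered under the same `scope` into one 80-requests-per-second budget, and the handler above converts the resulting `RateLimitExceeded` into an S3-style 503 response. One hedged way to observe this locally (the host/port and the use of the `requests` package are assumptions):

```python
import requests  # assumption: requests is installed

# Hammer the echo endpoint; requests that land inside the same 1-second
# window beyond the 80th should come back as S3-style 503s.
url = "http://localhost:8000/80-per-second-rate-limited-gets-bucket/foo"
codes = [requests.get(url).status_code for _ in range(200)]
print({code: codes.count(code) for code in sorted(set(codes))})  # e.g. {200: ..., 503: ...}
```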
tests/integration/docker-compose/retry_server/utils/responses.py
@@ -3,14 +3,14 @@
 from fastapi import Response


-def get_response(bucket_name: str, status_code: int, num_errors: int, item_id: str):
+def get_response(url: str, status_code: int):
     return Response(
         status_code=status_code,
         content=f"""<?xml version="1.0" encoding="UTF-8"?>
 <Error>
 <Code></Code>
 <Message>This is a mock error message</Message>
-<Resource>/{bucket_name}/{status_code}/{num_errors}/{item_id}</Resource>
+<Resource>{url}</Resource>
 <RequestId>4442587FB7D0A2F9</RequestId>
 </Error>""",
         headers={
5 changes: 3 additions & 2 deletions tests/integration/io/parquet/test_reads_local_fixtures.py
@@ -6,7 +6,7 @@

 from daft.table import Table

-BUCKETS = ["head-retries-bucket", "get-retries-bucket"]
+BUCKETS = ["head-retries-parquet-bucket", "get-retries-parquet-bucket"]


 @pytest.mark.integration()
@@ -26,7 +26,7 @@ def test_non_retryable_errors(retry_server_s3_config, status_code: int, bucket:
         500,
         503,
         504,
-        # TODO: We should also retry correctly on these error codes (PyArrow does retry appropriately here)
+        # TODO: [IO-RETRIES] We should also retry correctly on these error codes (PyArrow does retry appropriately here)
         # These are marked as retryable error codes, see:
         # https://github.com/aws/aws-sdk-cpp/blob/8a9550f1db04b33b3606602ba181d68377f763df/src/aws-cpp-sdk-core/include/aws/core/http/HttpResponse.h#L113-L131
         # 509,
@@ -41,6 +41,7 @@
 @pytest.mark.parametrize("bucket", BUCKETS)
 def test_retryable_errors(retry_server_s3_config, status_code: int, bucket: str):
     # By default the SDK retries 3 times, so we should be able to tolerate NUM_ERRORS=2
+    # Tweak this variable appropriately to match the retry policy
     NUM_ERRORS = 2
     data_path = f"s3://{bucket}/{status_code}/{NUM_ERRORS}/{uuid.uuid4()}"
22 changes: 22 additions & 0 deletions tests/integration/io/test_url_download_s3_local_retry_server.py
@@ -0,0 +1,22 @@
from __future__ import annotations

import pytest

import daft


@pytest.mark.integration()
@pytest.mark.skip(
    reason="""[IO-RETRIES] This currently fails: we need better retry policies to have this work consistently.
Currently, if all the retries for a given URL happens to land in the same 1-second window, the request fails.
We should be able to get around this with a more generous retry policy, with larger increments between backoffs.
"""
)
def test_url_download_local_retry_server(retry_server_s3_config):
    bucket = "80-per-second-rate-limited-gets-bucket"
    data = {"urls": [f"s3://{bucket}/foo{i}" for i in range(100)]}
    df = daft.from_pydict(data)
    df = df.with_column(
        "data", df["urls"].url.download(io_config=retry_server_s3_config, use_native_downloader=True, on_error="null")
    )
    assert df.to_pydict() == {**data, "data": [f"foo{i}".encode() for i in range(100)]}
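
The skip reason comes down to retry timing: if every retry for a URL fires inside the same 1-second rate-limit window, the whole retry budget is burned on 503s. A hedged sketch of the kind of policy the note asks for — exponential backoff with jitter, not daft's actual retry implementation:

```python
import random
import time


def retry_with_backoff(fetch, max_attempts: int = 5, base_delay: float = 0.1):
    """Exponential backoff with full jitter: later attempts wait longer and
    at randomized offsets, spreading them across rate-limit windows."""
    for attempt in range(max_attempts):
        try:
            return fetch()
        except Exception:
            if attempt == max_attempts - 1:
                raise
            # Sleep somewhere in [0, base_delay * 2**attempt) seconds.
            time.sleep(random.uniform(0, base_delay * (2**attempt)))
```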
