Skip to content

Commit

Permalink
Merge pull request #1065 from lsst/tickets/DM-45908
Browse files Browse the repository at this point in the history
DM-45908: Fix client-side HTTP timeouts when communicating with Butler server
  • Loading branch information
dhirving authored Aug 30, 2024
2 parents bc1af40 + 7efb63b commit 3ea58f3
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 85 deletions.
13 changes: 12 additions & 1 deletion python/lsst/daf/butler/remote_butler/_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,18 @@ def __init__(self, server_url: str, http_client: httpx.Client | None = None):
if http_client is not None:
self.http_client = http_client
else:
self.http_client = httpx.Client()
self.http_client = httpx.Client(
# This timeout is fairly conservative. This value isn't the
# maximum amount of time the request can take -- it's the
# maximum amount of time to wait after receiving the last chunk
# of data from the server.
#
# Long-running, streamed queries send a keep-alive every 15
# seconds. However, unstreamed operations like
# queryCollections can potentially take a while if the database
# is under duress.
timeout=120 # seconds
)
self._cache = RemoteButlerCache()

@staticmethod
Expand Down
14 changes: 12 additions & 2 deletions python/lsst/daf/butler/remote_butler/_query_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,11 @@ def execute(self, result_spec: ResultSpec, tree: QueryTree) -> Iterator[ResultPa
# There is one result page JSON object per line of the
# response.
for line in response.iter_lines():
result_chunk = _QueryResultTypeAdapter.validate_json(line)
yield _convert_query_result_page(result_spec, result_chunk, universe)
result_chunk: QueryExecuteResultData = _QueryResultTypeAdapter.validate_json(line)
if result_chunk.type == "keep-alive":
_received_keep_alive()
else:
yield _convert_query_result_page(result_spec, result_chunk, universe)
if self._closed:
raise RuntimeError(
"Cannot continue query result iteration: query context has been closed"
Expand Down Expand Up @@ -279,3 +282,10 @@ def _convert_general_result(spec: GeneralResultSpec, model: GeneralResultModel)
for row in model.rows
]
return GeneralResultPage(spec=spec, rows=rows)


def _received_keep_alive() -> None:
"""Do nothing. Gives a place for unit tests to hook in for testing
keepalive behavior.
"""
pass
2 changes: 0 additions & 2 deletions python/lsst/daf/butler/remote_butler/server/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

import safir.dependencies.logger
from fastapi import FastAPI, Request, Response
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.staticfiles import StaticFiles
from safir.logging import configure_logging, configure_uvicorn_logging

Expand All @@ -54,7 +53,6 @@ def create_app() -> FastAPI:
config = load_config()

app = FastAPI()
app.add_middleware(GZipMiddleware, minimum_size=1000)

# A single instance of the server can serve data from multiple Butler
# repositories. This 'repository' path placeholder is consumed by
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@

__all__ = ("query_router",)

from collections.abc import AsyncIterator, Iterable, Iterator
from contextlib import ExitStack, contextmanager
import asyncio
from collections.abc import AsyncIterator, Iterator
from contextlib import contextmanager
from typing import NamedTuple

from fastapi import APIRouter, Depends
Expand All @@ -42,72 +43,112 @@
QueryAnyResponseModel,
QueryCountRequestModel,
QueryCountResponseModel,
QueryErrorResultModel,
QueryExecuteRequestModel,
QueryExecuteResultData,
QueryExplainRequestModel,
QueryExplainResponseModel,
QueryInputs,
QueryKeepAliveModel,
)

from ....queries.driver import QueryDriver, QueryTree, ResultPage, ResultSpec
from ...._exceptions import ButlerUserError
from ....queries.driver import QueryDriver, QueryTree, ResultSpec
from ..._errors import serialize_butler_user_error
from .._dependencies import factory_dependency
from .._factory import Factory
from ._query_serialization import serialize_query_pages
from ._query_serialization import convert_query_page

query_router = APIRouter()

# Alias this function so we can mock it during unit tests.
_timeout = asyncio.timeout


@query_router.post("/v1/query/execute", summary="Query the Butler database and return full results")
def query_execute(
async def query_execute(
request: QueryExecuteRequestModel, factory: Factory = Depends(factory_dependency)
) -> StreamingResponse:
# Managing the lifetime of the query context object is a little tricky. We
# need to enter the context here, so that we can immediately deal with any
# exceptions raised by query set-up. We eventually transfer control to an
# iterator consumed by FastAPI's StreamingResponse handler, which will
# start iterating after this function returns. So we use this ExitStack
# instance to hand over the context manager to the iterator.
with ExitStack() as exit_stack:
ctx = exit_stack.enter_context(_get_query_context(factory, request.query))
spec = request.result_spec.to_result_spec(ctx.driver.universe)
response_pages = ctx.driver.execute(spec, ctx.tree)

# We write the response incrementally, one page at a time, as
# newline-separated chunks of JSON. This allows clients to start
# reading results earlier and prevents the server from exhausting
# all its memory buffering rows from large queries.
output_generator = _stream_query_pages(
# Transfer control of the context manager to
# _stream_query_pages.
exit_stack.pop_all(),
spec,
response_pages,
)
return StreamingResponse(output_generator, media_type="application/jsonlines")

# Mypy thinks that ExitStack might swallow an exception.
assert False, "This line is unreachable."


async def _stream_query_pages(
exit_stack: ExitStack, spec: ResultSpec, pages: Iterable[ResultPage]
) -> AsyncIterator[str]:
# Instead of declaring this as a sync generator with 'def', it's async to
# give us more control over the lifetime of exit_stack. StreamingResponse
# ensures that this async generator is cancelled if the client
# disconnects or another error occurs, ensuring that clean-up logic runs.
#
# If it was sync, it would get wrapped in an async function internal to
# FastAPI that does not guarantee that the generator is fully iterated or
# closed.
# (There is an example in the FastAPI docs showing StreamingResponse with a
# sync generator with a context manager, but after reading the FastAPI
# source code I believe that for sync generators it will leak the context
# manager if the client disconnects, and that it would be
# difficult/impossible for them to fix this in the general case within
# FastAPI.)
async with contextmanager_in_threadpool(exit_stack):
async for chunk in iterate_in_threadpool(serialize_query_pages(spec, pages)):
yield chunk
# We write the response incrementally, one page at a time, as
# newline-separated chunks of JSON. This allows clients to start
# reading results earlier and prevents the server from exhausting
# all its memory buffering rows from large queries.
output_generator = _stream_query_pages(request, factory)
return StreamingResponse(
output_generator,
media_type="application/jsonlines",
headers={
# Instruct the Kubernetes ingress to not buffer the response,
# so that keep-alives reach the client promptly.
"X-Accel-Buffering": "no"
},
)


async def _stream_query_pages(request: QueryExecuteRequestModel, factory: Factory) -> AsyncIterator[str]:
"""Stream the query output with one page object per line, as
newline-delimited JSON records in the "JSON Lines" format
(https://jsonlines.org/).
When it takes longer than 15 seconds to get a response from the DB,
sends a keep-alive message to prevent clients from timing out.
"""
# `None` signals that there is no more data to send.
queue = asyncio.Queue[QueryExecuteResultData | None](1)
async with asyncio.TaskGroup() as tg:
# Run a background task to read from the DB and insert the result pages
# into a queue.
tg.create_task(_enqueue_query_pages(queue, request, factory))
# Read the result pages from the queue and send them to the client,
# inserting a keep-alive message every 15 seconds if we are waiting a
# long time for the database.
async for message in _dequeue_query_pages_with_keepalive(queue):
yield message.model_dump_json() + "\n"


async def _enqueue_query_pages(
queue: asyncio.Queue[QueryExecuteResultData | None], request: QueryExecuteRequestModel, factory: Factory
) -> None:
"""Set up a QueryDriver to run the query, and copy the results into a
queue. Send `None` to the queue when there is no more data to read.
"""
try:
async with contextmanager_in_threadpool(_get_query_context(factory, request.query)) as ctx:
spec = request.result_spec.to_result_spec(ctx.driver.universe)
async for page in iterate_in_threadpool(_retrieve_query_pages(ctx, spec)):
await queue.put(page)
except ButlerUserError as e:
# If a user-facing error occurs, serialize it and send it to the
# client.
await queue.put(QueryErrorResultModel(error=serialize_butler_user_error(e)))

# Signal that there is no more data to read.
await queue.put(None)


def _retrieve_query_pages(ctx: _QueryContext, spec: ResultSpec) -> Iterator[QueryExecuteResultData]:
"""Execute the database query and and return pages of results."""
pages = ctx.driver.execute(spec, ctx.tree)
for page in pages:
yield convert_query_page(spec, page)


async def _dequeue_query_pages_with_keepalive(
queue: asyncio.Queue[QueryExecuteResultData | None],
) -> AsyncIterator[QueryExecuteResultData]:
"""Read and return messages from the given queue until the end-of-stream
message `None` is reached. If the producer is taking a long time, returns
a keep-alive message every 15 seconds while we are waiting.
"""
while True:
try:
async with _timeout(15):
message = await queue.get()
if message is None:
return
yield message
except TimeoutError:
yield QueryKeepAliveModel()


@query_router.post(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@

from __future__ import annotations

from collections.abc import Iterable, Iterator

from ...._exceptions import ButlerUserError
from ....queries.driver import (
DataCoordinateResultPage,
DatasetRefResultPage,
Expand All @@ -38,45 +35,30 @@
ResultPage,
ResultSpec,
)
from ..._errors import serialize_butler_user_error
from ...server_models import (
DataCoordinateResultModel,
DatasetRefResultModel,
DimensionRecordsResultModel,
GeneralResultModel,
QueryErrorResultModel,
QueryExecuteResultData,
)


def serialize_query_pages(
spec: ResultSpec, pages: Iterable[ResultPage]
) -> Iterator[str]: # numpydoc ignore=PR01
"""Serialize result pages to pages of result data in JSON format. The
output contains one page object per line, as newline-delimited JSON records
in the "JSON Lines" format (https://jsonlines.org/).
"""
try:
for page in pages:
yield _convert_query_page(spec, page).model_dump_json()
yield "\n"
except ButlerUserError as e:
# If a user-facing error occurs, serialize it and send it to the
# client.
yield QueryErrorResultModel(error=serialize_butler_user_error(e)).model_dump_json()
yield "\n"


def _convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResultData:
def convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResultData:
"""Convert pages of result data from the query system to a serializable
format.
Parameters
----------
spec : `ResultSpec`
Definition of the output format for the results.
pages : `ResultPage`
page : `ResultPage`
Raw page of data from the query driver.
Returns
-------
model : `QueryExecuteResultData`
Serializable pydantic model version of the page.
"""
match spec.result_type:
case "dimension_record":
Expand Down
14 changes: 13 additions & 1 deletion python/lsst/daf/butler/remote_butler/server_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,24 @@ class QueryErrorResultModel(pydantic.BaseModel):
error: ErrorResponseModel


class QueryKeepAliveModel(pydantic.BaseModel):
"""Result model for /query/execute used to keep connection alive.
Some queries require a significant start-up time before they can start
returning results, or a long processing time for each chunk of rows. This
message signals that the server is still fetching the data.
"""

type: Literal["keep-alive"] = "keep-alive"


QueryExecuteResultData: TypeAlias = Annotated[
DataCoordinateResultModel
| DimensionRecordsResultModel
| DatasetRefResultModel
| GeneralResultModel
| QueryErrorResultModel,
| QueryErrorResultModel
| QueryKeepAliveModel,
pydantic.Field(discriminator="type"),
]

Expand Down
42 changes: 41 additions & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

try:
# Failing to import any of these should disable the tests.
import lsst.daf.butler.remote_butler._query_driver
import lsst.daf.butler.remote_butler.server.handlers._external_query
import safir.dependencies.logger
from fastapi.testclient import TestClient
from lsst.daf.butler.remote_butler import RemoteButler
Expand All @@ -47,7 +49,7 @@
create_test_server = None
reason_text = str(e)

from unittest.mock import NonCallableMock, patch
from unittest.mock import DEFAULT, NonCallableMock, patch

from lsst.daf.butler import (
Butler,
Expand Down Expand Up @@ -402,6 +404,28 @@ async def get_logger():
self.assertEqual(kwargs["clientRequestId"], "request-id")
self.assertEqual(kwargs["user"], "user-name")

def test_query_keepalive(self):
"""Test that long-running queries stream keep-alive messages to stop
the HTTP connection from closing before they are able to return
results.
"""
# Normally it takes 15 seconds for a timeout -- mock it to trigger
# immediately instead.
with patch.object(
lsst.daf.butler.remote_butler.server.handlers._external_query, "_timeout"
) as mock_timeout:
# Hook into QueryDriver to track the number of keep-alives we have
# seen.
with patch.object(
lsst.daf.butler.remote_butler._query_driver, "_received_keep_alive"
) as mock_keep_alive:
mock_timeout.side_effect = _timeout_twice()
with self.butler._query() as query:
datasets = list(query.datasets("bias", "imported_g"))
self.assertEqual(len(datasets), 3)
self.assertGreaterEqual(mock_timeout.call_count, 3)
self.assertGreaterEqual(mock_keep_alive.call_count, 2)


def _create_corrupted_dataset(repo: MetricTestRepo) -> DatasetRef:
run = "corrupted-run"
Expand All @@ -418,5 +442,21 @@ def _create_simple_dataset(butler: Butler) -> DatasetRef:
return ref


def _timeout_twice():
"""Return a mock side-effect function that raises a timeout error the first
two times it is called.
"""
count = 0

def timeout(*args):
nonlocal count
count += 1
if count <= 2:
raise TimeoutError()
return DEFAULT

return timeout


if __name__ == "__main__":
unittest.main()

0 comments on commit 3ea58f3

Please sign in to comment.