diff --git a/python/lsst/daf/butler/remote_butler/_query_driver.py b/python/lsst/daf/butler/remote_butler/_query_driver.py
index 5cd2d0a793..d290e62ffe 100644
--- a/python/lsst/daf/butler/remote_butler/_query_driver.py
+++ b/python/lsst/daf/butler/remote_butler/_query_driver.py
@@ -147,8 +147,9 @@ def execute(self, result_spec: ResultSpec, tree: QueryTree) -> Iterator[ResultPa
             # There is one result page JSON object per line of the
             # response.
             for line in response.iter_lines():
-                result_chunk = _QueryResultTypeAdapter.validate_json(line)
-                yield _convert_query_result_page(result_spec, result_chunk, universe)
+                result_chunk: QueryExecuteResultData = _QueryResultTypeAdapter.validate_json(line)
+                if result_chunk.type != "keep-alive":
+                    yield _convert_query_result_page(result_spec, result_chunk, universe)
                 if self._closed:
                     raise RuntimeError(
                         "Cannot continue query result iteration: query context has been closed"
diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
index 909088e198..b5b73e04b6 100644
--- a/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
+++ b/python/lsst/daf/butler/remote_butler/server/handlers/_external_query.py
@@ -29,6 +29,7 @@
 
 __all__ = ("query_router",)
 
+import asyncio
 from collections.abc import AsyncIterator, Iterable, Iterator
 from contextlib import ExitStack, contextmanager
 from typing import NamedTuple
@@ -46,6 +47,7 @@
     QueryExplainRequestModel,
     QueryExplainResponseModel,
     QueryInputs,
+    QueryKeepAliveModel,
 )
 
 from ....queries.driver import QueryDriver, QueryTree, ResultPage, ResultSpec
@@ -88,26 +90,70 @@ def query_execute(
     assert False, "This line is unreachable."
 
 
+# Instead of declaring this as a sync generator with 'def', it's async to
+# give us more control over the lifetime of exit_stack. StreamingResponse
+# ensures that this async generator is cancelled if the client
+# disconnects or another error occurs, ensuring that clean-up logic runs.
+#
+# If it was sync, it would get wrapped in an async function internal to
+# FastAPI that does not guarantee that the generator is fully iterated or
+# closed.
+# (There is an example in the FastAPI docs showing StreamingResponse with a
+# sync generator with a context manager, but after reading the FastAPI
+# source code I believe that for sync generators it will leak the context
+# manager if the client disconnects, and that it would be
+# difficult/impossible for them to fix this in the general case within
+# FastAPI.)
 async def _stream_query_pages(
     exit_stack: ExitStack, spec: ResultSpec, pages: Iterable[ResultPage]
 ) -> AsyncIterator[str]:
-    # Instead of declaring this as a sync generator with 'def', it's async to
-    # give us more control over the lifetime of exit_stack. StreamingResponse
-    # ensures that this async generator is cancelled if the client
-    # disconnects or another error occurs, ensuring that clean-up logic runs.
-    #
-    # If it was sync, it would get wrapped in an async function internal to
-    # FastAPI that does not guarantee that the generator is fully iterated or
-    # closed.
-    # (There is an example in the FastAPI docs showing StreamingResponse with a
-    # sync generator with a context manager, but after reading the FastAPI
-    # source code I believe that for sync generators it will leak the context
-    # manager if the client disconnects, and that it would be
-    # difficult/impossible for them to fix this in the general case within
-    # FastAPI.)
+    """Stream the query output with one page object per line, as
+    newline-delimited JSON records in the "JSON Lines" format
+    (https://jsonlines.org/).
+
+    When it takes longer than 15 seconds to get a response from the DB,
+    sends a keep-alive message to prevent clients from timing out.
+    """
+    # Ensure that the database connection is cleaned up by taking control of
+    # exit_stack.
     async with contextmanager_in_threadpool(exit_stack):
-        async for chunk in iterate_in_threadpool(serialize_query_pages(spec, pages)):
-            yield chunk
+        iterator = iterate_in_threadpool(serialize_query_pages(spec, pages))
+        done = False
+        while not done:
+            # Read the next value from the iterator, possibly with some
+            # additional keep-alive messages if it takes a long time.
+            async for message in _fetch_next_with_keepalives(iterator):
+                if message is None:
+                    done = True
+                else:
+                    yield message
+                    yield "\n"
+
+
+async def _fetch_next_with_keepalives(iterator: AsyncIterator[str]) -> AsyncIterator[str | None]:
+    """Read the next value from the given iterator and yield it. Yields a
+    keep-alive message every 15 seconds while waiting for the iterator to
+    return a value. Yields `None` if there is nothing left to read from the
+    iterator.
+    """
+    try:
+        future = asyncio.ensure_future(anext(iterator, None))
+        ready = False
+        while not ready:
+            (finished_task, pending_task) = await asyncio.wait([future], timeout=15)
+            if pending_task:
+                # Hit the timeout, send a keep-alive and keep waiting.
+                yield QueryKeepAliveModel().model_dump_json()
+            else:
+                # The next value from the iterator is ready to read.
+                ready = True
+    finally:
+        # Even if we get cancelled above, we need to wait for this iteration to
+        # complete so we don't have a dangling thread using a database
+        # connection that the caller is about to clean up.
+        result = await future
+
+    yield result
 
 
 @query_router.post(
diff --git a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py
index e2db3399d5..b51599b946 100644
--- a/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py
+++ b/python/lsst/daf/butler/remote_butler/server/handlers/_query_serialization.py
@@ -53,18 +53,15 @@ def serialize_query_pages(
     spec: ResultSpec, pages: Iterable[ResultPage]
 ) -> Iterator[str]:  # numpydoc ignore=PR01
     """Serialize result pages to pages of result data in JSON format. The
-    output contains one page object per line, as newline-delimited JSON records
-    in the "JSON Lines" format (https://jsonlines.org/).
+    output contains one page object per iteration.
     """
     try:
         for page in pages:
             yield _convert_query_page(spec, page).model_dump_json()
-            yield "\n"
     except ButlerUserError as e:
         # If a user-facing error occurs, serialize it and send it to the
         # client.
         yield QueryErrorResultModel(error=serialize_butler_user_error(e)).model_dump_json()
-        yield "\n"
 
 
 def _convert_query_page(spec: ResultSpec, page: ResultPage) -> QueryExecuteResultData:
diff --git a/python/lsst/daf/butler/remote_butler/server_models.py b/python/lsst/daf/butler/remote_butler/server_models.py
index 876bfe9839..5ae692506f 100644
--- a/python/lsst/daf/butler/remote_butler/server_models.py
+++ b/python/lsst/daf/butler/remote_butler/server_models.py
@@ -299,12 +299,24 @@ class QueryErrorResultModel(pydantic.BaseModel):
     error: ErrorResponseModel
 
 
+class QueryKeepAliveModel(pydantic.BaseModel):
+    """Result model for /query/execute used to keep connection alive.
+
+    Some queries require a significant start-up time before they can start
+    returning results, or a long processing time for each chunk of rows. This
+    message signals that the server is still fetching the data.
+    """
+
+    type: Literal["keep-alive"] = "keep-alive"
+
+
 QueryExecuteResultData: TypeAlias = Annotated[
     DataCoordinateResultModel
     | DimensionRecordsResultModel
     | DatasetRefResultModel
     | GeneralResultModel
-    | QueryErrorResultModel,
+    | QueryErrorResultModel
+    | QueryKeepAliveModel,
     pydantic.Field(discriminator="type"),
 ]
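Below is a minimal, self-contained sketch (not part of the patch) of the keep-alive technique that _fetch_next_with_keepalives implements on the server: wrap the "next item" call of a slow iterator in a task, poll it with asyncio.wait(timeout=...), and emit a placeholder record whenever the timeout expires before a real value arrives. All names here (slow_pages, stream_with_keepalive, TIMEOUT_SECONDS) are illustrative and do not exist in the Butler code, and the timeout is shortened so the demo finishes quickly.

import asyncio
from collections.abc import AsyncIterator

TIMEOUT_SECONDS = 0.5  # the server code above uses 15 seconds


async def slow_pages() -> AsyncIterator[str]:
    # Stand-in for the threadpool iterator that reads result pages from the DB.
    for i in range(2):
        await asyncio.sleep(1.2)  # simulate a slow database fetch
        yield f'{{"type": "data", "page": {i}}}'


async def stream_with_keepalive(pages: AsyncIterator[str]) -> AsyncIterator[str]:
    done = False
    while not done:
        # Fetch the next page in a task so we can poll it with a timeout.
        future = asyncio.ensure_future(anext(pages, None))
        while not future.done():
            _, pending = await asyncio.wait([future], timeout=TIMEOUT_SECONDS)
            if pending:
                # No page yet; emit a keep-alive so the client's read timeout resets.
                yield '{"type": "keep-alive"}'
        value = future.result()
        if value is None:
            done = True
        else:
            yield value


async def main() -> None:
    async for line in stream_with_keepalive(slow_pages()):
        print(line)


asyncio.run(main())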
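A second sketch (also not part of the patch) shows how the discriminated union added to server_models.py lets a client validate each JSON line and skip keep-alive records, mirroring the result_chunk.type != "keep-alive" check in _query_driver.py. The model names DataModel and KeepAliveModel are simplified stand-ins, not the real Butler models.

from typing import Annotated, Literal, TypeAlias

import pydantic


class DataModel(pydantic.BaseModel):
    type: Literal["data"] = "data"
    page: int


class KeepAliveModel(pydantic.BaseModel):
    type: Literal["keep-alive"] = "keep-alive"


ResultData: TypeAlias = Annotated[DataModel | KeepAliveModel, pydantic.Field(discriminator="type")]
_adapter = pydantic.TypeAdapter(ResultData)

# Each line of the streamed response is one JSON record; keep-alives carry no data.
for line in [b'{"type": "keep-alive"}', b'{"type": "data", "page": 0}']:
    chunk = _adapter.validate_json(line)
    if chunk.type != "keep-alive":
        print("got page", chunk.page)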