From 1894693bda20bf9c44d54ecbb499282aa7f618e5 Mon Sep 17 00:00:00 2001 From: Dan Allan Date: Tue, 25 Jul 2023 08:41:35 -0400 Subject: [PATCH] Remove vendored copy of starlette test client (#529) * Remove vendored copy, no longer needed. * Test that multiple chunks come from streaming response. * Fix misuse of zstandard. --- tiled/_tests/test_export.py | 4 +- tiled/client/_testclient.py | 805 ------------------------------------ tiled/client/context.py | 7 +- tiled/client/decoders.py | 5 +- 4 files changed, 10 insertions(+), 811 deletions(-) delete mode 100644 tiled/client/_testclient.py diff --git a/tiled/_tests/test_export.py b/tiled/_tests/test_export.py index 409b62635..22d4c979b 100644 --- a/tiled/_tests/test_export.py +++ b/tiled/_tests/test_export.py @@ -93,7 +93,9 @@ def test_streaming_export(client): client["C"].export(buffer, format="application/json-seq") # Verify that output is valid newline-delimited JSON. buffer.seek(0) - for line in buffer.read().decode().splitlines(): + lines = buffer.read().decode().splitlines() + assert len(lines) == 100 + for line in lines: json.loads(line) diff --git a/tiled/client/_testclient.py b/tiled/client/_testclient.py deleted file mode 100644 index 20afcffea..000000000 --- a/tiled/client/_testclient.py +++ /dev/null @@ -1,805 +0,0 @@ -""" -Vendored from starlette.testclient and starlette._utils -commit 3ac06f04c1096052d7492401b33286f6c71fb7ad -because TestClient was recently rewritten to use httpx instead of requests -and that rewrite has not yet been released. -""" -import asyncio -import contextlib -import functools -import inspect -import io -import json -import math -import queue -import sys -import typing -import warnings -from concurrent.futures import Future -from types import GeneratorType -from urllib.parse import unquote, urljoin - -import anyio -import httpx -from anyio.streams.stapled import StapledObjectStream -from starlette.types import ASGIApp, Message, Receive, Scope, Send -from starlette.websockets import WebSocketDisconnect - -if sys.version_info >= (3, 8): # pragma: no cover - from typing import TypedDict -else: # pragma: no cover - from typing_extensions import TypedDict - -_PortalFactoryType = typing.Callable[ - [], typing.ContextManager[anyio.abc.BlockingPortal] -] - -ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]] -ASGI2App = typing.Callable[[Scope], ASGIInstance] -ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] - - -_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]] - - -def is_async_callable(obj: typing.Any) -> bool: - while isinstance(obj, functools.partial): - obj = obj.func - - return asyncio.iscoroutinefunction(obj) or ( - callable(obj) and asyncio.iscoroutinefunction(obj.__call__) - ) - - -def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool: - if inspect.isclass(app): - return hasattr(app, "__await__") - return is_async_callable(app) - - -class _WrapASGI2: - """ - Provide an ASGI3 interface onto an ASGI2 app. - """ - - def __init__(self, app: ASGI2App) -> None: - self.app = app - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - instance = self.app(scope) - await instance(receive, send) - - -class _AsyncBackend(TypedDict): - backend: str - backend_options: typing.Dict[str, typing.Any] - - -class _Upgrade(Exception): - def __init__(self, session: "WebSocketTestSession") -> None: - self.session = session - - -class WebSocketTestSession: - def __init__( - self, - app: ASGI3App, - scope: Scope, - portal_factory: _PortalFactoryType, - ) -> None: - self.app = app - self.scope = scope - self.accepted_subprotocol = None - self.portal_factory = portal_factory - self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue() - self._send_queue: "queue.Queue[typing.Any]" = queue.Queue() - self.extra_headers = None - - def __enter__(self) -> "WebSocketTestSession": - self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context(self.portal_factory()) - - try: - _: "Future[None]" = self.portal.start_task_soon(self._run) - self.send({"type": "websocket.connect"}) - message = self.receive() - self._raise_on_close(message) - except Exception: - self.exit_stack.close() - raise - self.accepted_subprotocol = message.get("subprotocol", None) - self.extra_headers = message.get("headers", None) - return self - - def __exit__(self, *args: typing.Any) -> None: - try: - self.close(1000) - finally: - self.exit_stack.close() - while not self._send_queue.empty(): - message = self._send_queue.get() - if isinstance(message, BaseException): - raise message - - async def _run(self) -> None: - """ - The sub-thread in which the websocket session runs. - """ - scope = self.scope - receive = self._asgi_receive - send = self._asgi_send - try: - await self.app(scope, receive, send) - except BaseException as exc: - self._send_queue.put(exc) - raise - - async def _asgi_receive(self) -> Message: - while self._receive_queue.empty(): - await anyio.sleep(0) - return self._receive_queue.get() - - async def _asgi_send(self, message: Message) -> None: - self._send_queue.put(message) - - def _raise_on_close(self, message: Message) -> None: - if message["type"] == "websocket.close": - raise WebSocketDisconnect( - message.get("code", 1000), message.get("reason", "") - ) - - def send(self, message: Message) -> None: - self._receive_queue.put(message) - - def send_text(self, data: str) -> None: - self.send({"type": "websocket.receive", "text": data}) - - def send_bytes(self, data: bytes) -> None: - self.send({"type": "websocket.receive", "bytes": data}) - - def send_json(self, data: typing.Any, mode: str = "text") -> None: - assert mode in ["text", "binary"] - text = json.dumps(data) - if mode == "text": - self.send({"type": "websocket.receive", "text": text}) - else: - self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")}) - - def close(self, code: int = 1000) -> None: - self.send({"type": "websocket.disconnect", "code": code}) - - def receive(self) -> Message: - message = self._send_queue.get() - if isinstance(message, BaseException): - raise message - return message - - def receive_text(self) -> str: - message = self.receive() - self._raise_on_close(message) - return message["text"] - - def receive_bytes(self) -> bytes: - message = self.receive() - self._raise_on_close(message) - return message["bytes"] - - def receive_json(self, mode: str = "text") -> typing.Any: - assert mode in ["text", "binary"] - message = self.receive() - self._raise_on_close(message) - if mode == "text": - text = message["text"] - else: - text = message["bytes"].decode("utf-8") - return json.loads(text) - - -class _TestClientTransport(httpx.BaseTransport): - def __init__( - self, - app: ASGI3App, - portal_factory: _PortalFactoryType, - raise_server_exceptions: bool = True, - root_path: str = "", - ) -> None: - self.app = app - self.raise_server_exceptions = raise_server_exceptions - self.root_path = root_path - self.portal_factory = portal_factory - - def handle_request(self, request: httpx.Request) -> httpx.Response: - scheme = request.url.scheme - netloc = unquote(request.url.netloc.decode(encoding="ascii")) - path = request.url.path - raw_path = request.url.raw_path - query = unquote(request.url.query.decode(encoding="ascii")) - - default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] - - if ":" in netloc: - host, port_string = netloc.split(":", 1) - port = int(port_string) - else: - host = netloc - port = default_port - - # Include the 'host' header. - if "host" in request.headers: - headers: typing.List[typing.Tuple[bytes, bytes]] = [] - elif port == default_port: # pragma: no cover - headers = [(b"host", host.encode())] - else: # pragma: no cover - headers = [(b"host", (f"{host}:{port}").encode())] - - # Include other request headers. - headers += [ - (key.lower().encode(), value.encode()) - for key, value in request.headers.items() - ] - - scope: typing.Dict[str, typing.Any] - - if scheme in {"ws", "wss"}: - subprotocol = request.headers.get("sec-websocket-protocol", None) - if subprotocol is None: - subprotocols: typing.Sequence[str] = [] - else: - subprotocols = [value.strip() for value in subprotocol.split(",")] - scope = { - "type": "websocket", - "path": unquote(path), - "raw_path": raw_path, - "root_path": self.root_path, - "scheme": scheme, - "query_string": query.encode(), - "headers": headers, - "client": ["testclient", 50000], - "server": [host, port], - "subprotocols": subprotocols, - } - session = WebSocketTestSession(self.app, scope, self.portal_factory) - raise _Upgrade(session) - - scope = { - "type": "http", - "http_version": "1.1", - "method": request.method, - "path": unquote(path), - "raw_path": raw_path, - "root_path": self.root_path, - "scheme": scheme, - "query_string": query.encode(), - "headers": headers, - "client": ["testclient", 50000], - "server": [host, port], - "extensions": {"http.response.template": {}}, - } - - request_complete = False - response_started = False - response_complete: anyio.Event - raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()} - template = None - context = None - - async def receive() -> Message: - nonlocal request_complete - - if request_complete: - if not response_complete.is_set(): - await response_complete.wait() - return {"type": "http.disconnect"} - - body = request.read() - if isinstance(body, str): - body_bytes: bytes = body.encode("utf-8") # pragma: no cover - elif body is None: - body_bytes = b"" # pragma: no cover - elif isinstance(body, GeneratorType): - try: # pragma: no cover - chunk = body.send(None) - if isinstance(chunk, str): - chunk = chunk.encode("utf-8") - return {"type": "http.request", "body": chunk, "more_body": True} - except StopIteration: # pragma: no cover - request_complete = True - return {"type": "http.request", "body": b""} - else: - body_bytes = body - - request_complete = True - return {"type": "http.request", "body": body_bytes} - - async def send(message: Message) -> None: - nonlocal raw_kwargs, response_started, template, context - - if message["type"] == "http.response.start": - assert ( - not response_started - ), 'Received multiple "http.response.start" messages.' - raw_kwargs["status_code"] = message["status"] - raw_kwargs["headers"] = [ - (key.decode(), value.decode()) - for key, value in message.get("headers", []) - ] - response_started = True - elif message["type"] == "http.response.body": - assert ( - response_started - ), 'Received "http.response.body" without "http.response.start".' - assert ( - not response_complete.is_set() - ), 'Received "http.response.body" after response completed.' - body = message.get("body", b"") - more_body = message.get("more_body", False) - if request.method != "HEAD": - raw_kwargs["stream"].write(body) - if not more_body: - raw_kwargs["stream"].seek(0) - response_complete.set() - elif message["type"] == "http.response.template": - template = message["template"] - context = message["context"] - - try: - with self.portal_factory() as portal: - response_complete = portal.call(anyio.Event) - portal.call(self.app, scope, receive, send) - except BaseException as exc: - if self.raise_server_exceptions: - raise exc - - if self.raise_server_exceptions: - assert response_started, "TestClient did not receive any response." - elif not response_started: - raw_kwargs = { - "status_code": 500, - "headers": [], - "stream": io.BytesIO(), - } - - raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read()) - - response = httpx.Response(**raw_kwargs, request=request) - if template is not None: - response.template = template # type: ignore[attr-defined] - response.context = context # type: ignore[attr-defined] - return response - - -class TestClient(httpx.Client): - __test__ = False - task: "Future[None]" - portal: typing.Optional[anyio.abc.BlockingPortal] = None - - def __init__( - self, - app: ASGIApp, - base_url: str = "http://testserver", - raise_server_exceptions: bool = True, - root_path: str = "", - backend: str = "asyncio", - backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None, - cookies: httpx._client.CookieTypes = None, - event_hooks=None, - timeout=None, - params=None, - headers=None, - ) -> None: - self.async_backend = _AsyncBackend( - backend=backend, backend_options=backend_options or {} - ) - if _is_asgi3(app): - app = typing.cast(ASGI3App, app) - asgi_app = app - else: - app = typing.cast(ASGI2App, app) # type: ignore[assignment] - asgi_app = _WrapASGI2(app) # type: ignore[arg-type] - self.app = asgi_app - transport = _TestClientTransport( - self.app, - portal_factory=self._portal_factory, - raise_server_exceptions=raise_server_exceptions, - root_path=root_path, - ) - super().__init__( - app=self.app, - base_url=base_url, - headers={"user-agent": "testclient"}, - params=params, - transport=transport, - follow_redirects=True, - cookies=cookies, - event_hooks=event_hooks, - timeout=timeout, - ) - - @contextlib.contextmanager - def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]: - if self.portal is not None: - yield self.portal - else: - with anyio.start_blocking_portal(**self.async_backend) as portal: - yield portal - - def _choose_redirect_arg( - self, - follow_redirects: typing.Optional[bool], - allow_redirects: typing.Optional[bool], - ) -> typing.Union[bool, httpx._client.UseClientDefault]: - redirect: typing.Union[ - bool, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT - if allow_redirects is not None: - message = ( - "The `allow_redirects` argument is deprecated. " - "Use `follow_redirects` instead." - ) - warnings.warn(message, DeprecationWarning) - redirect = allow_redirects - if follow_redirects is not None: - redirect = follow_redirects - elif allow_redirects is not None and follow_redirects is not None: - raise RuntimeError( # pragma: no cover - "Cannot use both `allow_redirects` and `follow_redirects`." - ) - return redirect - - def request( # type: ignore[override] - self, - method: str, - url: httpx._types.URLTypes, - *, - content: typing.Optional[httpx._types.RequestContent] = None, - data: typing.Optional[_RequestData] = None, - files: typing.Optional[httpx._types.RequestFiles] = None, - json: typing.Any = None, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, - ) -> httpx.Response: - url = self.base_url.join(url) - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) - return super().request( - method, - url, - content=content, - data=data, # type: ignore[arg-type] - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=redirect, - timeout=timeout, - extensions=extensions, - ) - - def get( # type: ignore[override] - self, - url: httpx._types.URLTypes, - *, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, - ) -> httpx.Response: - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) - return super().get( - url, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=redirect, - timeout=timeout, - extensions=extensions, - ) - - def options( # type: ignore[override] - self, - url: httpx._types.URLTypes, - *, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, - ) -> httpx.Response: - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) - return super().options( - url, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=redirect, - timeout=timeout, - extensions=extensions, - ) - - def head( # type: ignore[override] - self, - url: httpx._types.URLTypes, - *, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, - ) -> httpx.Response: - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) - return super().head( - url, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=redirect, - timeout=timeout, - extensions=extensions, - ) - - def post( # type: ignore[override] - self, - url: httpx._types.URLTypes, - *, - content: typing.Optional[httpx._types.RequestContent] = None, - data: typing.Optional[_RequestData] = None, - files: typing.Optional[httpx._types.RequestFiles] = None, - json: typing.Any = None, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, - ) -> httpx.Response: - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) - return super().post( - url, - content=content, - data=data, # type: ignore[arg-type] - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=redirect, - timeout=timeout, - extensions=extensions, - ) - - def put( # type: ignore[override] - self, - url: httpx._types.URLTypes, - *, - content: typing.Optional[httpx._types.RequestContent] = None, - data: typing.Optional[_RequestData] = None, - files: typing.Optional[httpx._types.RequestFiles] = None, - json: typing.Any = None, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, - ) -> httpx.Response: - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) - return super().put( - url, - content=content, - data=data, # type: ignore[arg-type] - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=redirect, - timeout=timeout, - extensions=extensions, - ) - - def patch( # type: ignore[override] - self, - url: httpx._types.URLTypes, - *, - content: typing.Optional[httpx._types.RequestContent] = None, - data: typing.Optional[_RequestData] = None, - files: typing.Optional[httpx._types.RequestFiles] = None, - json: typing.Any = None, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, - ) -> httpx.Response: - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) - return super().patch( - url, - content=content, - data=data, # type: ignore[arg-type] - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=redirect, - timeout=timeout, - extensions=extensions, - ) - - def delete( # type: ignore[override] - self, - url: httpx._types.URLTypes, - *, - params: typing.Optional[httpx._types.QueryParamTypes] = None, - headers: typing.Optional[httpx._types.HeaderTypes] = None, - cookies: typing.Optional[httpx._types.CookieTypes] = None, - auth: typing.Union[ - httpx._types.AuthTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - follow_redirects: typing.Optional[bool] = None, - allow_redirects: typing.Optional[bool] = None, - timeout: typing.Union[ - httpx._client.TimeoutTypes, httpx._client.UseClientDefault - ] = httpx._client.USE_CLIENT_DEFAULT, - extensions: typing.Optional[typing.Dict[str, typing.Any]] = None, - ) -> httpx.Response: - redirect = self._choose_redirect_arg(follow_redirects, allow_redirects) - return super().delete( - url, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=redirect, - timeout=timeout, - extensions=extensions, - ) - - def websocket_connect( - self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any - ) -> typing.Any: - url = urljoin("ws://testserver", url) - headers = kwargs.get("headers", {}) - headers.setdefault("connection", "upgrade") - headers.setdefault("sec-websocket-key", "testserver==") - headers.setdefault("sec-websocket-version", "13") - if subprotocols is not None: - headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) - kwargs["headers"] = headers - try: - super().request("GET", url, **kwargs) - except _Upgrade as exc: - session = exc.session - else: - raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover - - return session - - def __enter__(self) -> "TestClient": - with contextlib.ExitStack() as stack: - self.portal = portal = stack.enter_context( - anyio.start_blocking_portal(**self.async_backend) - ) - - @stack.callback - def reset_portal() -> None: - self.portal = None - - self.stream_send = StapledObjectStream( - *anyio.create_memory_object_stream(math.inf) - ) - self.stream_receive = StapledObjectStream( - *anyio.create_memory_object_stream(math.inf) - ) - self.task = portal.start_task_soon(self.lifespan) - portal.call(self.wait_startup) - - @stack.callback - def wait_shutdown() -> None: - portal.call(self.wait_shutdown) - - self.exit_stack = stack.pop_all() - - return self - - def __exit__(self, *args: typing.Any) -> None: - self.exit_stack.close() - - async def lifespan(self) -> None: - scope = {"type": "lifespan"} - try: - await self.app(scope, self.stream_receive.receive, self.stream_send.send) - finally: - await self.stream_send.send(None) - - async def wait_startup(self) -> None: - await self.stream_receive.send({"type": "lifespan.startup"}) - - async def receive() -> typing.Any: - message = await self.stream_send.receive() - if message is None: - self.task.result() - return message - - message = await receive() - assert message["type"] in ( - "lifespan.startup.complete", - "lifespan.startup.failed", - ) - if message["type"] == "lifespan.startup.failed": - await receive() - - async def wait_shutdown(self) -> None: - async def receive() -> typing.Any: - message = await self.stream_send.receive() - if message is None: - self.task.result() - return message - - async with self.stream_send: - await self.stream_receive.send({"type": "lifespan.shutdown"}) - message = await receive() - assert message["type"] in ( - "lifespan.shutdown.complete", - "lifespan.shutdown.failed", - ) - if message["type"] == "lifespan.shutdown.failed": - await receive() diff --git a/tiled/client/context.py b/tiled/client/context.py index 4b2703f8d..018765028 100644 --- a/tiled/client/context.py +++ b/tiled/client/context.py @@ -103,14 +103,17 @@ def __init__( # Do this in the setter to avoid being overwritten. client.headers = headers else: - from ._testclient import TestClient + # Set up an ASGI client. + # Because we have been handed an app, we can infer that + # starlette is available. + from starlette.testclient import TestClient # verify parameter is dropped, as there is no SSL in ASGI mode client = TestClient( app=app, - timeout=timeout, raise_server_exceptions=raise_server_exceptions, ) + client.timeout = timeout client.headers = headers # Do this in the setter to avoid being overwritten. client.follow_redirects = True diff --git a/tiled/client/decoders.py b/tiled/client/decoders.py index 5e480a2c0..28c9ed94e 100644 --- a/tiled/client/decoders.py +++ b/tiled/client/decoders.py @@ -33,11 +33,10 @@ def flush(self) -> bytes: class ZStandardDecoder: def __init__(self): - self.context = zstandard.ZstdDecompressor() - self.decompressobj = self.context.decompressobj() + self._context = zstandard.ZstdDecompressor() def decode(self, data: bytes) -> bytes: - return self.decompressobj.decompress(data) + return self._context.decompress(data) def flush(self) -> bytes: return b""