From dc421efae4ff8ec9fb156ccdc014f8af35f3a9c5 Mon Sep 17 00:00:00 2001 From: Casper van der Wel Date: Wed, 4 Oct 2023 15:40:36 +0200 Subject: [PATCH] Add async equivalents for API interfacing (#20) --- CHANGES.md | 4 +- clean_python/api_client/__init__.py | 1 + clean_python/api_client/api_gateway.py | 75 ++++++++- clean_python/api_client/api_provider.py | 98 +++++++++--- clean_python/api_client/response.py | 12 ++ clean_python/api_client/sync_api_provider.py | 115 ++++++++++++++ integration_tests/test_int_api_gateway.py | 62 ++++++++ integration_tests/test_int_api_provider.py | 103 +++++++++++++ ...ateway.py => test_int_sync_api_gateway.py} | 2 + ...vider.py => test_int_sync_api_provider.py} | 10 ++ pyproject.toml | 2 +- tests/api_client/test_api_gateway.py | 131 ++++++++++++++++ tests/api_client/test_api_provider.py | 144 ++++++++++++++++++ tests/api_client/test_sync_api_gateway.py | 4 +- tests/api_client/test_sync_api_provider.py | 4 +- 15 files changed, 735 insertions(+), 32 deletions(-) create mode 100644 clean_python/api_client/response.py create mode 100644 clean_python/api_client/sync_api_provider.py create mode 100644 integration_tests/test_int_api_gateway.py create mode 100644 integration_tests/test_int_api_provider.py rename integration_tests/{test_api_gateway.py => test_int_sync_api_gateway.py} (96%) rename integration_tests/{test_api_provider.py => test_int_sync_api_provider.py} (89%) create mode 100644 tests/api_client/test_api_gateway.py create mode 100644 tests/api_client/test_api_provider.py diff --git a/CHANGES.md b/CHANGES.md index f7abd11..82fe041 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -4,7 +4,9 @@ 0.6.5 (unreleased) ------------------ -- Nothing changed yet. +- Added async `ApiProvider` and `ApiGateway`. + +- Added `request_raw` to `ApiProvider` for handling arbitrary responses. 0.6.4 (2023-10-03) diff --git a/clean_python/api_client/__init__.py b/clean_python/api_client/__init__.py index 9b174b9..3052831 100644 --- a/clean_python/api_client/__init__.py +++ b/clean_python/api_client/__init__.py @@ -2,3 +2,4 @@ from .api_provider import * # NOQA from .exceptions import * # NOQA from .files import * # NOQA +from .sync_api_provider import * # NOQA diff --git a/clean_python/api_client/api_gateway.py b/clean_python/api_client/api_gateway.py index b05fb8b..39ccc86 100644 --- a/clean_python/api_client/api_gateway.py +++ b/clean_python/api_client/api_gateway.py @@ -5,15 +5,84 @@ import inject from clean_python import DoesNotExist +from clean_python import Gateway from clean_python import Id from clean_python import Json from clean_python import Mapper +from clean_python import SyncGateway -from .. import SyncGateway -from .api_provider import SyncApiProvider +from .api_provider import ApiProvider from .exceptions import ApiException +from .sync_api_provider import SyncApiProvider -__all__ = ["SyncApiGateway"] +__all__ = ["ApiGateway", "SyncApiGateway"] + + +class ApiGateway(Gateway): + path: str + mapper = Mapper() + + def __init__(self, provider_override: Optional[ApiProvider] = None): + self.provider_override = provider_override + + def __init_subclass__(cls, path: str) -> None: + assert not path.startswith("/") + assert "{id}" in path + cls.path = path + super().__init_subclass__() + + @property + def provider(self) -> ApiProvider: + return self.provider_override or inject.instance(ApiProvider) + + async def get(self, id: Id) -> Optional[Json]: + try: + result = await self.provider.request("GET", self.path.format(id=id)) + assert result is not None + return self.mapper.to_internal(result) + except ApiException as e: + if e.status is HTTPStatus.NOT_FOUND: + return None + raise e + + async def add(self, item: Json) -> Json: + item = self.mapper.to_external(item) + result = await self.provider.request("POST", self.path.format(id=""), json=item) + assert result is not None + return self.mapper.to_internal(result) + + async def remove(self, id: Id) -> bool: + try: + await self.provider.request("DELETE", self.path.format(id=id)) is not None + except ApiException as e: + if e.status is HTTPStatus.NOT_FOUND: + return False + raise e + else: + return True + + async def update( + self, item: Json, if_unmodified_since: Optional[datetime] = None + ) -> Json: + if if_unmodified_since is not None: + raise NotImplementedError("if_unmodified_since not implemented") + item = self.mapper.to_external(item) + id_ = item.pop("id", None) + if id_ is None: + raise DoesNotExist("resource", id_) + try: + result = await self.provider.request( + "PATCH", self.path.format(id=id_), json=item + ) + assert result is not None + return self.mapper.to_internal(result) + except ApiException as e: + if e.status is HTTPStatus.NOT_FOUND: + raise DoesNotExist("resource", id_) + raise e + + +# This is a copy-paste of ApiGateway: class SyncApiGateway(SyncGateway): diff --git a/clean_python/api_client/api_provider.py b/clean_python/api_client/api_provider.py index d002572..538b1ea 100644 --- a/clean_python/api_client/api_provider.py +++ b/clean_python/api_client/api_provider.py @@ -1,4 +1,4 @@ -import json as json_lib +import asyncio import re from http import HTTPStatus from typing import Callable @@ -7,16 +7,21 @@ from urllib.parse import urlencode from urllib.parse import urljoin +import aiohttp +from aiohttp import ClientResponse +from aiohttp import ClientSession from pydantic import AnyHttpUrl -from urllib3 import PoolManager -from urllib3 import Retry from clean_python import ctx from clean_python import Json from .exceptions import ApiException +from .response import Response -__all__ = ["SyncApiProvider"] +__all__ = ["ApiProvider"] + + +RETRY_STATUSES = frozenset({413, 429, 503}) # like in urllib3 def is_success(status: HTTPStatus) -> bool: @@ -49,7 +54,7 @@ def add_query_params(url: str, params: Optional[Json]) -> str: return url + "?" + urlencode(params, doseq=True) -class SyncApiProvider: +class ApiProvider: """Basic JSON API provider with retry policy and bearer tokens. The default retry policy has 3 retries with 1, 2, 4 second intervals. @@ -64,43 +69,70 @@ class SyncApiProvider: def __init__( self, url: AnyHttpUrl, - fetch_token: Callable[[PoolManager, int], Optional[str]], + fetch_token: Callable[[ClientSession, int], Optional[str]], retries: int = 3, backoff_factor: float = 1.0, ): self._url = str(url) assert self._url.endswith("/") self._fetch_token = fetch_token - self._pool = PoolManager(retries=Retry(retries, backoff_factor=backoff_factor)) + assert retries > 0 + self._retries = retries + self._backoff_factor = backoff_factor + self._session = ClientSession() - def request( + async def _request_with_retry( self, method: str, path: str, - params: Optional[Json] = None, - json: Optional[Json] = None, - fields: Optional[Json] = None, - timeout: float = 5.0, - ) -> Optional[Json]: + params: Optional[Json], + json: Optional[Json], + fields: Optional[Json], + timeout: float, + ) -> ClientResponse: assert ctx.tenant is not None headers = {} request_kwargs = { "method": method, "url": add_query_params(join(self._url, quote(path)), params), "timeout": timeout, + "json": json, + "data": fields, } - # for urllib3<2, we dump json ourselves - if json is not None and fields is not None: - raise ValueError("Cannot both specify 'json' and 'fields'") - elif json is not None: - request_kwargs["body"] = json_lib.dumps(json).encode() - headers["Content-Type"] = "application/json" - elif fields is not None: - request_kwargs["fields"] = fields - token = self._fetch_token(self._pool, ctx.tenant.id) + token = self._fetch_token(self._session, ctx.tenant.id) if token is not None: headers["Authorization"] = f"Bearer {token}" - response = self._pool.request(headers=headers, **request_kwargs) + for attempt in range(self._retries): + if attempt > 0: + backoff = self._backoff_factor * 2 ** (attempt - 1) + await asyncio.sleep(backoff) + + try: + response = await self._session.request( + headers=headers, **request_kwargs + ) + await response.read() + except (aiohttp.ClientError, asyncio.exceptions.TimeoutError): + if attempt == self._retries - 1: + raise # propagate ClientError in case no retries left + else: + if response.status not in RETRY_STATUSES: + return response # on all non-retry statuses: return response + + return response # retries exceeded; return the (possibly error) response + + async def request( + self, + method: str, + path: str, + params: Optional[Json] = None, + json: Optional[Json] = None, + fields: Optional[Json] = None, + timeout: float = 5.0, + ) -> Optional[Json]: + response = await self._request_with_retry( + method, path, params, json, fields, timeout + ) status = HTTPStatus(response.status) content_type = response.headers.get("Content-Type") if status is HTTPStatus.NO_CONTENT: @@ -109,8 +141,26 @@ def request( raise ApiException( f"Unexpected content type '{content_type}'", status=status ) - body = json_lib.loads(response.data.decode()) + body = await response.json() if is_success(status): return body else: raise ApiException(body, status=status) + + async def request_raw( + self, + method: str, + path: str, + params: Optional[Json] = None, + json: Optional[Json] = None, + fields: Optional[Json] = None, + timeout: float = 5.0, + ) -> Response: + response = await self._request_with_retry( + method, path, params, json, fields, timeout + ) + return Response( + status=response.status, + data=await response.read(), + content_type=response.headers.get("Content-Type"), + ) diff --git a/clean_python/api_client/response.py b/clean_python/api_client/response.py new file mode 100644 index 0000000..4c6cba9 --- /dev/null +++ b/clean_python/api_client/response.py @@ -0,0 +1,12 @@ +from http import HTTPStatus +from typing import Optional + +from clean_python import ValueObject + +__all__ = ["Response"] + + +class Response(ValueObject): + status: HTTPStatus + data: bytes + content_type: Optional[str] diff --git a/clean_python/api_client/sync_api_provider.py b/clean_python/api_client/sync_api_provider.py new file mode 100644 index 0000000..6e25cc0 --- /dev/null +++ b/clean_python/api_client/sync_api_provider.py @@ -0,0 +1,115 @@ +import json as json_lib +from http import HTTPStatus +from typing import Callable +from typing import Optional +from urllib.parse import quote + +from pydantic import AnyHttpUrl +from urllib3 import PoolManager +from urllib3 import Retry + +from clean_python import ctx +from clean_python import Json + +from .api_provider import add_query_params +from .api_provider import is_json_content_type +from .api_provider import is_success +from .api_provider import join +from .exceptions import ApiException +from .response import Response + +__all__ = ["SyncApiProvider"] + + +class SyncApiProvider: + """Basic JSON API provider with retry policy and bearer tokens. + + The default retry policy has 3 retries with 1, 2, 4 second intervals. + + Args: + url: The url of the API (with trailing slash) + fetch_token: Callable that returns a token for a tenant id + retries: Total number of retries per request + backoff_factor: Multiplier for retry delay times (1, 2, 4, ...) + """ + + def __init__( + self, + url: AnyHttpUrl, + fetch_token: Callable[[PoolManager, int], Optional[str]], + retries: int = 3, + backoff_factor: float = 1.0, + ): + self._url = str(url) + assert self._url.endswith("/") + self._fetch_token = fetch_token + self._pool = PoolManager(retries=Retry(retries, backoff_factor=backoff_factor)) + + def _request( + self, + method: str, + path: str, + params: Optional[Json], + json: Optional[Json], + fields: Optional[Json], + timeout: float, + ): + assert ctx.tenant is not None + headers = {} + request_kwargs = { + "method": method, + "url": add_query_params(join(self._url, quote(path)), params), + "timeout": timeout, + } + # for urllib3<2, we dump json ourselves + if json is not None and fields is not None: + raise ValueError("Cannot both specify 'json' and 'fields'") + elif json is not None: + request_kwargs["body"] = json_lib.dumps(json).encode() + headers["Content-Type"] = "application/json" + elif fields is not None: + request_kwargs["fields"] = fields + token = self._fetch_token(self._pool, ctx.tenant.id) + if token is not None: + headers["Authorization"] = f"Bearer {token}" + return self._pool.request(headers=headers, **request_kwargs) + + def request( + self, + method: str, + path: str, + params: Optional[Json] = None, + json: Optional[Json] = None, + fields: Optional[Json] = None, + timeout: float = 5.0, + ) -> Optional[Json]: + response = self._request(method, path, params, json, fields, timeout) + status = HTTPStatus(response.status) + content_type = response.headers.get("Content-Type") + if status is HTTPStatus.NO_CONTENT: + return None + if not is_json_content_type(content_type): + raise ApiException( + f"Unexpected content type '{content_type}'", status=status + ) + body = json_lib.loads(response.data.decode()) + if is_success(status): + return body + else: + raise ApiException(body, status=status) + + def request_raw( + self, + method: str, + path: str, + params: Optional[Json] = None, + json: Optional[Json] = None, + fields: Optional[Json] = None, + timeout: float = 5.0, + ) -> Response: + response = self._request(method, path, params, json, fields, timeout) + return Response( + status=response.status, + data=response.data, + content_type=response.headers.get("Content-Type"), + ) diff --git a/integration_tests/test_int_api_gateway.py b/integration_tests/test_int_api_gateway.py new file mode 100644 index 0000000..09d75d4 --- /dev/null +++ b/integration_tests/test_int_api_gateway.py @@ -0,0 +1,62 @@ +import pytest + +from clean_python import ctx +from clean_python import DoesNotExist +from clean_python import Json +from clean_python import Tenant +from clean_python.api_client import ApiGateway +from clean_python.api_client import ApiProvider + + +class BooksGateway(ApiGateway, path="v1/books/{id}"): + pass + + +@pytest.fixture +def provider(fastapi_example_app) -> ApiProvider: + ctx.tenant = Tenant(id=2, name="") + yield ApiProvider(fastapi_example_app + "/", lambda a, b: "token") + ctx.tenant = None + + +@pytest.fixture +def gateway(provider) -> ApiGateway: + return BooksGateway(provider) + + +@pytest.fixture +async def book(gateway: ApiGateway): + return await gateway.add({"title": "fixture", "author": {"name": "foo"}}) + + +async def test_add(gateway: ApiGateway): + response = await gateway.add({"title": "test_add", "author": {"name": "foo"}}) + assert isinstance(response["id"], int) + assert response["title"] == "test_add" + assert response["author"] == {"name": "foo"} + assert response["created_at"] == response["updated_at"] + + +async def test_get(gateway: ApiGateway, book: Json): + response = await gateway.get(book["id"]) + assert response == book + + +async def test_remove_and_404(gateway: ApiGateway, book: Json): + assert await gateway.remove(book["id"]) is True + assert await gateway.get(book["id"]) is None + assert await gateway.remove(book["id"]) is False + + +async def test_update(gateway: ApiGateway, book: Json): + response = await gateway.update({"id": book["id"], "title": "test_update"}) + + assert response["id"] == book["id"] + assert response["title"] == "test_update" + assert response["author"] == {"name": "foo"} + assert response["created_at"] != response["updated_at"] + + +async def test_update_404(gateway: ApiGateway): + with pytest.raises(DoesNotExist): + await gateway.update({"id": 123456, "title": "test_update_404"}) diff --git a/integration_tests/test_int_api_provider.py b/integration_tests/test_int_api_provider.py new file mode 100644 index 0000000..b9d48e6 --- /dev/null +++ b/integration_tests/test_int_api_provider.py @@ -0,0 +1,103 @@ +from http import HTTPStatus + +import pytest + +from clean_python import ctx +from clean_python import Tenant +from clean_python.api_client import ApiException +from clean_python.api_client import ApiProvider + + +@pytest.fixture +def provider(fastapi_example_app) -> ApiProvider: + ctx.tenant = Tenant(id=2, name="") + yield ApiProvider(fastapi_example_app + "/", lambda a, b: "token") + ctx.tenant = None + + +async def test_request_params(provider: ApiProvider): + response = await provider.request( + "GET", "v1/books", params={"limit": 10, "offset": 2} + ) + + assert isinstance(response, dict) + + assert response["limit"] == 10 + assert response["offset"] == 2 + + +async def test_request_json_body(provider: ApiProvider): + response = await provider.request( + "POST", "v1/books", json={"title": "test_body", "author": {"name": "foo"}} + ) + + assert isinstance(response, dict) + assert response["title"] == "test_body" + assert response["author"] == {"name": "foo"} + + +async def test_request_form_body(provider: ApiProvider): + response = await provider.request("POST", "v1/form", fields={"name": "foo"}) + + assert isinstance(response, dict) + assert response["name"] == "foo" + + +# files are not supported (yet) +# +# async def test_request_form_file(provider: ApiProvider): +# response = await provider.request("POST", "v1/file", fields={"file": ("x.txt", b"foo")}) + +# assert isinstance(response, dict) +# assert response["x.txt"] == "foo" + + +@pytest.fixture +async def book(provider: ApiProvider): + return await provider.request( + "POST", "v1/books", json={"title": "fixture", "author": {"name": "foo"}} + ) + + +async def test_no_content(provider: ApiProvider, book): + response = await provider.request("DELETE", f"v1/books/{book['id']}") + + assert response is None + + +async def test_not_found(provider: ApiProvider): + with pytest.raises(ApiException) as e: + await provider.request("GET", "v1/book") + + assert e.value.status is HTTPStatus.NOT_FOUND + assert e.value.args[0] == {"detail": "Not Found"} + + +async def test_bad_request(provider: ApiProvider): + with pytest.raises(ApiException) as e: + await provider.request("GET", "v1/books", params={"limit": "foo"}) + + assert e.value.status is HTTPStatus.BAD_REQUEST + assert e.value.args[0]["detail"][0]["loc"] == ["query", "limit"] + + +async def test_no_json_response(provider: ApiProvider): + with pytest.raises(ApiException) as e: + await provider.request("GET", "v1/text") + + assert e.value.args[0] == "Unexpected content type 'text/plain; charset=utf-8'" + + +async def test_urlencode(provider: ApiProvider): + response = await provider.request("PUT", "v1/urlencode/x?") + + assert isinstance(response, dict) + assert response["name"] == "x?" + + +async def test_request_raw(provider: ApiProvider, book): + response = await provider.request_raw("GET", f"v1/books/{book['id']}") + + assert response.status is HTTPStatus.OK + assert len(response.data) > 0 + assert response.content_type == "application/json" diff --git a/integration_tests/test_api_gateway.py b/integration_tests/test_int_sync_api_gateway.py similarity index 96% rename from integration_tests/test_api_gateway.py rename to integration_tests/test_int_sync_api_gateway.py index 94ea701..c920aec 100644 --- a/integration_tests/test_api_gateway.py +++ b/integration_tests/test_int_sync_api_gateway.py @@ -1,3 +1,5 @@ +# This module is a copy paste of test_int_api_gateway.py + import pytest from clean_python import ctx diff --git a/integration_tests/test_api_provider.py b/integration_tests/test_int_sync_api_provider.py similarity index 89% rename from integration_tests/test_api_provider.py rename to integration_tests/test_int_sync_api_provider.py index 674d0c7..5e73932 100644 --- a/integration_tests/test_api_provider.py +++ b/integration_tests/test_int_sync_api_provider.py @@ -1,3 +1,5 @@ +# This module is a copy paste of test_int_api_provider.py + from http import HTTPStatus import pytest @@ -89,3 +91,11 @@ def test_urlencode(provider: SyncApiProvider): assert isinstance(response, dict) assert response["name"] == "x?" + + +def test_request_raw(provider: SyncApiProvider, book): + response = provider.request_raw("GET", f"v1/books/{book['id']}") + + assert response.status is HTTPStatus.OK + assert len(response.data) > 0 + assert response.content_type == "application/json" diff --git a/pyproject.toml b/pyproject.toml index 5e58c52..bb9a0f4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ celery = ["pika"] fluentbit = ["fluent-logger"] sql = ["sqlalchemy==2.*", "asyncpg"] s3 = ["aioboto3", "boto3"] -api_client = ["urllib3"] +api_client = ["aiohttp", "urllib3"] profiler = ["yappi"] debugger = ["debugpy"] diff --git a/tests/api_client/test_api_gateway.py b/tests/api_client/test_api_gateway.py new file mode 100644 index 0000000..e212f72 --- /dev/null +++ b/tests/api_client/test_api_gateway.py @@ -0,0 +1,131 @@ +from http import HTTPStatus +from unittest import mock + +import pytest + +from clean_python import DoesNotExist +from clean_python import Json +from clean_python import Mapper +from clean_python.api_client import ApiException +from clean_python.api_client import ApiGateway +from clean_python.api_client import ApiProvider + + +class TstApiGateway(ApiGateway, path="foo/{id}"): + pass + + +@pytest.fixture +def api_provider(): + return mock.MagicMock(spec_set=ApiProvider) + + +@pytest.fixture +def api_gateway(api_provider) -> ApiGateway: + return TstApiGateway(api_provider) + + +async def test_get(api_gateway: ApiGateway): + actual = await api_gateway.get(14) + + api_gateway.provider.request.assert_called_once_with("GET", "foo/14") + assert actual is api_gateway.provider.request.return_value + + +async def test_add(api_gateway: ApiGateway): + actual = await api_gateway.add({"foo": 2}) + + api_gateway.provider.request.assert_called_once_with( + "POST", "foo/", json={"foo": 2} + ) + assert actual is api_gateway.provider.request.return_value + + +async def test_remove(api_gateway: ApiGateway): + actual = await api_gateway.remove(2) + + api_gateway.provider.request.assert_called_once_with("DELETE", "foo/2") + assert actual is True + + +async def test_remove_does_not_exist(api_gateway: ApiGateway): + api_gateway.provider.request.side_effect = ApiException( + {}, status=HTTPStatus.NOT_FOUND + ) + actual = await api_gateway.remove(2) + assert actual is False + + +async def test_update(api_gateway: ApiGateway): + actual = await api_gateway.update({"id": 2, "foo": "bar"}) + + api_gateway.provider.request.assert_called_once_with( + "PATCH", "foo/2", json={"foo": "bar"} + ) + assert actual is api_gateway.provider.request.return_value + + +async def test_update_no_id(api_gateway: ApiGateway): + with pytest.raises(DoesNotExist): + await api_gateway.update({"foo": "bar"}) + + assert not api_gateway.provider.request.called + + +async def test_update_does_not_exist(api_gateway: ApiGateway): + api_gateway.provider.request.side_effect = ApiException( + {}, status=HTTPStatus.NOT_FOUND + ) + with pytest.raises(DoesNotExist): + await api_gateway.update({"id": 2, "foo": "bar"}) + + +class TstMapper(Mapper): + def to_external(self, internal: Json) -> Json: + result = {} + if internal.get("id") is not None: + result["id"] = internal["id"] + if internal.get("name") is not None: + result["name"] = internal["name"].upper() + return result + + def to_internal(self, external: Json) -> Json: + return {"id": external["id"], "name": external["name"].lower()} + + +class TstMappedApiGateway(ApiGateway, path="foo/{id}"): + mapper = TstMapper() + + +@pytest.fixture +def mapped_api_gateway(api_provider) -> ApiGateway: + return TstMappedApiGateway(api_provider) + + +async def test_get_with_mapper(mapped_api_gateway: ApiGateway): + mapped_api_gateway.provider.request.return_value = {"id": 14, "name": "FOO"} + + assert await mapped_api_gateway.get(14) == {"id": 14, "name": "foo"} + + +async def test_add_with_mapper(mapped_api_gateway: ApiGateway): + mapped_api_gateway.provider.request.return_value = {"id": 3, "name": "FOO"} + + assert await mapped_api_gateway.add({"name": "foo"}) == {"id": 3, "name": "foo"} + + mapped_api_gateway.provider.request.assert_called_once_with( + "POST", "foo/", json={"name": "FOO"} + ) + + +async def test_update_with_mapper(mapped_api_gateway: ApiGateway): + mapped_api_gateway.provider.request.return_value = {"id": 2, "name": "BAR"} + + assert await mapped_api_gateway.update({"id": 2, "name": "bar"}) == { + "id": 2, + "name": "bar", + } + + mapped_api_gateway.provider.request.assert_called_once_with( + "PATCH", "foo/2", json={"name": "BAR"} + ) diff --git a/tests/api_client/test_api_provider.py b/tests/api_client/test_api_provider.py new file mode 100644 index 0000000..18beea9 --- /dev/null +++ b/tests/api_client/test_api_provider.py @@ -0,0 +1,144 @@ +from http import HTTPStatus +from unittest import mock + +import pytest +from aiohttp import ClientSession + +from clean_python import ctx +from clean_python import Tenant +from clean_python.api_client import ApiException +from clean_python.api_client import ApiProvider + +MODULE = "clean_python.api_client.api_provider" + + +@pytest.fixture +def tenant() -> Tenant: + ctx.tenant = Tenant(id=2, name="") + yield ctx.tenant + ctx.tenant = None + + +@pytest.fixture +def response(): + # this mocks the aiohttp.ClientResponse: + response = mock.Mock() + response.status = int(HTTPStatus.OK) + response.headers = {"Content-Type": "application/json"} + response.json = mock.AsyncMock(return_value={"foo": 2}) + response.read = mock.AsyncMock() + return response + + +@pytest.fixture +def api_provider(tenant, response) -> ApiProvider: + request = mock.AsyncMock() + with mock.patch.object(ClientSession, "request", new=request): + api_provider = ApiProvider( + url="http://testserver/foo/", + fetch_token=lambda a, b: f"tenant-{b}", + ) + api_provider._session.request.return_value = response + yield api_provider + + +async def test_get(api_provider: ApiProvider, response): + actual = await api_provider.request("GET", "") + + assert api_provider._session.request.call_count == 1 + assert api_provider._session.request.call_args[1] == dict( + method="GET", + url="http://testserver/foo", + headers={"Authorization": "Bearer tenant-2"}, + timeout=5.0, + data=None, + json=None, + ) + assert actual == {"foo": 2} + + +async def test_post_json(api_provider: ApiProvider, response): + response.status == int(HTTPStatus.CREATED) + api_provider._session.request.return_value = response + actual = await api_provider.request("POST", "bar", json={"foo": 2}) + + assert api_provider._session.request.call_count == 1 + + assert api_provider._session.request.call_args[1] == dict( + method="POST", + url="http://testserver/foo/bar", + data=None, + json={"foo": 2}, + headers={ + "Authorization": "Bearer tenant-2", + }, + timeout=5.0, + ) + assert actual == {"foo": 2} + + +@pytest.mark.parametrize( + "path,params,expected_url", + [ + ("", None, "http://testserver/foo"), + ("bar", None, "http://testserver/foo/bar"), + ("bar/", None, "http://testserver/foo/bar"), + ("", {"a": 2}, "http://testserver/foo?a=2"), + ("bar", {"a": 2}, "http://testserver/foo/bar?a=2"), + ("bar/", {"a": 2}, "http://testserver/foo/bar?a=2"), + ("", {"a": [1, 2]}, "http://testserver/foo?a=1&a=2"), + ("", {"a": 1, "b": "foo"}, "http://testserver/foo?a=1&b=foo"), + ], +) +async def test_url(api_provider: ApiProvider, path, params, expected_url): + await api_provider.request("GET", path, params=params) + assert api_provider._session.request.call_args[1]["url"] == expected_url + + +async def test_timeout(api_provider: ApiProvider): + await api_provider.request("POST", "bar", timeout=2.1) + assert api_provider._session.request.call_args[1]["timeout"] == 2.1 + + +@pytest.mark.parametrize( + "status", [HTTPStatus.OK, HTTPStatus.NOT_FOUND, HTTPStatus.INTERNAL_SERVER_ERROR] +) +async def test_unexpected_content_type(api_provider: ApiProvider, response, status): + response.status = int(status) + response.headers["Content-Type"] = "text/plain" + with pytest.raises(ApiException) as e: + await api_provider.request("GET", "bar") + + assert e.value.status is status + assert str(e.value) == f"{status}: Unexpected content type 'text/plain'" + + +async def test_json_variant_content_type(api_provider: ApiProvider, response): + response.headers["Content-Type"] = "application/something+json" + actual = await api_provider.request("GET", "bar") + assert actual == {"foo": 2} + + +async def test_no_content(api_provider: ApiProvider, response): + response.status = int(HTTPStatus.NO_CONTENT) + response.headers = {} + + actual = await api_provider.request("DELETE", "bar/2") + assert actual is None + + +@pytest.mark.parametrize("status", [HTTPStatus.BAD_REQUEST, HTTPStatus.NOT_FOUND]) +async def test_error_response(api_provider: ApiProvider, response, status): + response.status = int(status) + + with pytest.raises(ApiException) as e: + await api_provider.request("GET", "bar") + + assert e.value.status is status + assert str(e.value) == str(int(status)) + ": {'foo': 2}" + + +async def test_no_token(api_provider: ApiProvider): + api_provider._fetch_token = lambda a, b: None + await api_provider.request("GET", "") + assert api_provider._session.request.call_args[1]["headers"] == {} diff --git a/tests/api_client/test_sync_api_gateway.py b/tests/api_client/test_sync_api_gateway.py index 1d35644..76954b8 100644 --- a/tests/api_client/test_sync_api_gateway.py +++ b/tests/api_client/test_sync_api_gateway.py @@ -1,3 +1,5 @@ +# This module is a copy paste of test_api_gateway.py + from http import HTTPStatus from unittest import mock @@ -10,8 +12,6 @@ from clean_python.api_client import SyncApiGateway from clean_python.api_client import SyncApiProvider -MODULE = "clean_python.api_client.api_provider" - class TstSyncApiGateway(SyncApiGateway, path="foo/{id}"): pass diff --git a/tests/api_client/test_sync_api_provider.py b/tests/api_client/test_sync_api_provider.py index aeb3d5a..f782a21 100644 --- a/tests/api_client/test_sync_api_provider.py +++ b/tests/api_client/test_sync_api_provider.py @@ -1,3 +1,5 @@ +# This module is a copy paste of test_api_provider.py + from http import HTTPStatus from unittest import mock @@ -8,7 +10,7 @@ from clean_python.api_client import ApiException from clean_python.api_client import SyncApiProvider -MODULE = "clean_python.api_client.api_provider" +MODULE = "clean_python.api_client.sync_api_provider" @pytest.fixture