diff --git a/poetry.lock b/poetry.lock index b96b046..22b0126 100644 --- a/poetry.lock +++ b/poetry.lock @@ -642,6 +642,27 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "fakts" +version = "0.3.48" +description = "asynchronous configuration provider ( tailored to support dynamic client-server relations)" +category = "dev" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "fakts-0.3.48-py3-none-any.whl", hash = "sha256:c735d34086742f79366cb347647045268ffa5041425330462e727c43e6a022ed"}, + {file = "fakts-0.3.48.tar.gz", hash = "sha256:77e2e74d73db3f7b6ccc6e8dccfe2261a34eb98124d6feadf776c5e97cae5c7d"}, +] + +[package.dependencies] +koil = ">=0.3.5" +pydantic = ">1.8.2" +PyYAML = ">=5.2" +QtPy = ">=2.0.1,<3.0.0" + +[package.extras] +remote = ["aiohttp (>=3.8.2,<4.0.0)", "certifi (>2021)"] + [[package]] name = "frozenlist" version = "1.3.3" @@ -864,14 +885,14 @@ typing-extensions = ">=3.7.4.3" [[package]] name = "koil" -version = "0.3.3" +version = "0.3.5" description = "Async for a sync world" category = "main" optional = false python-versions = ">=3.7,<4.0" files = [ - {file = "koil-0.3.3-py3-none-any.whl", hash = "sha256:2e489497589d36b40c667d60ab0fdfb7ffc2c754e365bcc7a7a464e7e3927e5a"}, - {file = "koil-0.3.3.tar.gz", hash = "sha256:9f23da2babd4b7005cac211862bf5d54e21cb514170bf23c93f844f48fea5074"}, + {file = "koil-0.3.5-py3-none-any.whl", hash = "sha256:ea121a76fb3c08d19e57bc232e59845339a115b48b658c0542d461d4eeaa7222"}, + {file = "koil-0.3.5.tar.gz", hash = "sha256:dd07226dff7732188e0b75b14aac9e7a3485924f6ad0b801b703f9eb54f9b4d6"}, ] [package.dependencies] @@ -1272,6 +1293,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -1279,8 +1301,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = 
"PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -1297,6 +1326,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -1304,6 +1334,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = 
"sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -1917,9 +1948,10 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [extras] aiohttp = ["aiohttp", "certifi"] httpx = ["httpx"] +signing = ["cryptography"] websockets = ["websockets"] [metadata] lock-version = "2.0" python-versions = "^3.7" -content-hash = "45288a4c2e2bf438efa192b7deda9880ef6123f9b4b027ed444dc9bed154c122" +content-hash = "a7dca388c56f8e798b342707f2cf04cf3836ecbd2b4e974d6adccdc3914f121b" diff --git a/pyproject.toml b/pyproject.toml index a06ade3..ee00b98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,11 +39,17 @@ pytest-cov = "^4.0.0" ruff = "^0.0.282" cryptography = "^41.0.3" pyjwt = "^2.8.0" +fakts = "^0.3.48" [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" +[tool.mypy] +exclude = ["venv/", "tests/"] +ignore_missing_imports = true + + [[tool.pydoc-markdown.loaders]] type = "python" search_path = ["rath"] diff --git a/rath/contrib/fakts/links/aiohttp.py b/rath/contrib/fakts/links/aiohttp.py index d9fd6e6..8d4594e 100644 --- a/rath/contrib/fakts/links/aiohttp.py +++ b/rath/contrib/fakts/links/aiohttp.py @@ -4,7 +4,7 @@ from fakts.fakt import Fakt from fakts.fakts import Fakts from rath.links.aiohttp import AIOHttpLink - +from rath.operation import Operation class AioHttpConfig(Fakt): """AioHttpConfig @@ -24,20 +24,21 @@ class FaktsAIOHttpLink(AIOHttpLink): """ fakts: Fakts - endpoint_url: Optional[str] + endpoint_url: Optional[str] # type: ignore fakts_group: str """ The fakts group within the fakts context to use for configuration """ - _old_fakt: Dict[str, Any] = None + _old_fakt: Optional[Dict[str, Any]] = None def configure(self, fakt: AioHttpConfig) -> None: """Configure the link with the given fakt""" self.endpoint_url = fakt.endpoint_url - async def aconnect(self, operation): + async def aconnect(self, operation: Operation): if self.fakts.has_changed(self._old_fakt, self.fakts_group): self._old_fakt = await self.fakts.aget(self.fakts_group) + assert self._old_fakt is not None, "Fakt should not be None" self.configure(AioHttpConfig(**self._old_fakt)) return await super().aconnect(operation) diff --git a/rath/contrib/fakts/links/graphql_ws.py b/rath/contrib/fakts/links/graphql_ws.py index 9fec903..a013115 100644 --- a/rath/contrib/fakts/links/graphql_ws.py +++ b/rath/contrib/fakts/links/graphql_ws.py @@ -14,7 +14,7 @@ class Config: class FaktsGraphQLWSLink(GraphQLWSLink): fakts: Fakts - ws_endpoint_url: Optional[str] + ws_endpoint_url: Optional[str] # type: ignore fakts_group: str = "websocket" _old_fakt: Dict[str, Any] = {} @@ -25,6 +25,7 @@ def configure(self, fakt: WebsocketHttpConfig) -> None: async def aconnect(self, operation: Any): if self.fakts.has_changed(self._old_fakt, self.fakts_group): self._old_fakt = await self.fakts.aget(self.fakts_group) + assert self._old_fakt is not None, "Fakt should not be None" self.configure(WebsocketHttpConfig(**self._old_fakt)) return await super().aconnect(operation) diff --git a/rath/contrib/fakts/links/httpx.py b/rath/contrib/fakts/links/httpx.py index 835f633..a642465 100644 --- a/rath/contrib/fakts/links/httpx.py +++ b/rath/contrib/fakts/links/httpx.py @@ -12,12 +12,12 @@ class Config: class FaktsHttpXLink(HttpxLink): - endpoint_url: Optional[str] + endpoint_url: Optional[str] # type: ignore fakts_group: str fakt: Optional[FaltsHttpXConfig] fakts: Fakts - _old_fakt: Dict[str, Any] = None + _old_fakt: Optional[Dict[str, Any]] = None def configure(self, fakt: 
FaltsHttpXConfig) -> None: self.endpoint_url = fakt.endpoint_url diff --git a/rath/contrib/fakts/links/subscription_transport_ws.py b/rath/contrib/fakts/links/subscription_transport_ws.py index 6c0bfe9..1414e01 100644 --- a/rath/contrib/fakts/links/subscription_transport_ws.py +++ b/rath/contrib/fakts/links/subscription_transport_ws.py @@ -14,10 +14,10 @@ class Config: class FaktsWebsocketLink(SubscriptionTransportWsLink): fakts: Fakts - ws_endpoint_url: Optional[str] + ws_endpoint_url: Optional[str] # type: ignore fakts_group: str = "websocket" - _old_fakt: Dict[str, Any] = {} + _old_fakt: Optional[Dict[str, Any]] = None def configure(self, fakt: WebsocketHttpConfig) -> None: self.ws_endpoint_url = fakt.ws_endpoint_url diff --git a/rath/errors.py b/rath/errors.py index b2fc64c..21cf8da 100644 --- a/rath/errors.py +++ b/rath/errors.py @@ -16,3 +16,10 @@ class NotEnteredError(RathException): to protected methods is attempted.""" pass + + +class NotComposedError(RathException): + """NotComposedError is raised when the Rath link chain is not composed and + the next link is accessed.""" + + pass \ No newline at end of file diff --git a/rath/links/aiohttp.py b/rath/links/aiohttp.py index 212c195..684272d 100644 --- a/rath/links/aiohttp.py +++ b/rath/links/aiohttp.py @@ -2,8 +2,8 @@ from http import HTTPStatus import json from ssl import SSLContext -from typing import Any, Dict, List, Type - +from typing import Any, Dict, List, Type, AsyncIterator +from rath.links.types import Payload import aiohttp from graphql import OperationType from pydantic import Field @@ -66,11 +66,11 @@ async def aconnect(self, operation: Operation): async def __aexit__(self, *args, **kwargs) -> None: pass - async def aexecute(self, operation: Operation) -> GraphQLResult: + async def aexecute(self, operation: Operation) -> AsyncIterator[GraphQLResult]: if not self._connected: await self.aconnect(operation) - payload = {"query": operation.document} + payload: Payload = {"query": operation.document} if operation.node.operation == OperationType.SUBSCRIPTION: raise NotImplementedError( diff --git a/rath/links/auth.py b/rath/links/auth.py index 1e63b6e..8a7890e 100644 --- a/rath/links/auth.py +++ b/rath/links/auth.py @@ -3,6 +3,7 @@ from rath.links.base import ContinuationLink from rath.operation import GraphQLResult, Operation from rath.links.errors import AuthenticationError +from rath.errors import NotComposedError async def fake_loader(): @@ -36,6 +37,9 @@ async def arefresh_token(self, operation: Operation): async def aexecute( self, operation: Operation, retry=0, **kwargs ) -> AsyncIterator[GraphQLResult]: + if not self.next: + raise NotComposedError("No next link set") + token = await self.aload_token(operation) operation.context.headers["Authorization"] = f"Bearer {token}" operation.context.initial_payload["token"] = token diff --git a/rath/links/base.py b/rath/links/base.py index 69cc5ae..44f1a67 100644 --- a/rath/links/base.py +++ b/rath/links/base.py @@ -1,7 +1,7 @@ from typing import AsyncIterator, Optional from koil.composition import KoiledModel from rath.operation import GraphQLResult, Operation - +from rath.errors import NotComposedError class Link(KoiledModel): """A Link is a class that can be used to send operations to a GraphQL API. 
@@ -14,7 +14,7 @@ class Link(KoiledModel): """ - async def aconnect(self): + async def aconnect(self, operation: Operation): """A coroutine that is called when the link is connected.""" pass @@ -28,7 +28,7 @@ async def __aenter__(self) -> None: async def __aexit__(self, *args, **kwargs) -> None: pass - def aexecute(self, operation: Operation, **kwargs) -> AsyncIterator[GraphQLResult]: + def aexecute(self, operation: Operation) -> AsyncIterator[GraphQLResult]: """A coroutine that takes an operation and returns an AsyncIterator of GraphQLResults. This method should be implemented by subclasses.""" raise NotImplementedError( @@ -56,7 +56,7 @@ class AsyncTerminatingLink(TerminatingLink): TerminatingLink (_type_): _description_ """ - async def aexecute(self, operation: Operation) -> AsyncIterator[GraphQLResult]: + def aexecute(self, operation: Operation) -> AsyncIterator[GraphQLResult]: raise NotImplementedError("Your Async Transport needs to overwrite this method") @@ -71,6 +71,9 @@ class ContinuationLink(Link): def set_next(self, next: Link): self.next = next - async def aexecute(self, operation: Operation, **kwargs) -> GraphQLResult: - async for x in self.next.aexecute(operation, **kwargs): + async def aexecute(self, operation: Operation) -> AsyncIterator[GraphQLResult]: + if not self.next: + raise NotComposedError("No next link set") + + async for x in self.next.aexecute(operation): yield x diff --git a/rath/links/compose.py b/rath/links/compose.py index 8871a0e..ece12f2 100644 --- a/rath/links/compose.py +++ b/rath/links/compose.py @@ -1,9 +1,9 @@ -from typing import List +from typing import List, Optional from pydantic import validator from rath.links.base import ContinuationLink, Link, TerminatingLink from rath.operation import Operation - +from rath.errors import NotComposedError class ComposedLink(TerminatingLink): """A composed link is a link that is composed of multiple links. The links @@ -69,7 +69,7 @@ class TypedComposedLink(TerminatingLink): automatically composed together. """ - _firstlink: Link = None + _firstlink: Optional[Link] = None async def __aenter__(self): current_link = None @@ -93,6 +93,10 @@ async def __aexit__(self, *args, **kwargs): await link.__aexit__(*args, **kwargs) async def aexecute(self, operation: Operation, **kwargs): + if not self._firstlink: + raise NotComposedError("Links need to be composed before they can be executed. (Through __aenter__)") + + async for result in self._firstlink.aexecute(operation): yield result diff --git a/rath/links/dictinglink.py b/rath/links/dictinglink.py index 9e3024b..8b0153e 100644 --- a/rath/links/dictinglink.py +++ b/rath/links/dictinglink.py @@ -9,6 +9,20 @@ def parse_variables( variables: Dict, by_alias: bool = True, ) -> Dict: + """Parse Variables + + Parse variables converts any pydantic models in the variables dict to dicts + by calling their .json() method with by_alias=True + + Args: + variables (Dict): variables to parse + by_alias (bool, optional): whether to use the alias names. Defaults to True.
+ + Returns: + Dict: the parsed variables + """ + + def recurse_extract(obj): """ recursively traverse obj, doing a deepcopy, but @@ -43,7 +57,7 @@ class DictingLink(ParsingLink): It traversed the variables dict, and converts any (nested) pydantic models to dicts by callind their .json() method.""" - by_alias = True + by_alias: bool = True """Converts pydantic models to dicts by calling their .json() method with by_alias=True""" async def aparse(self, operation: Operation) -> Operation: diff --git a/rath/links/errors.py b/rath/links/errors.py index e4d2fd7..ab71fa9 100644 --- a/rath/links/errors.py +++ b/rath/links/errors.py @@ -18,12 +18,18 @@ def __init__(self, message) -> None: class TerminatingLinkError(LinkError): - """Raised when a terminating link is called.""" + """Raised when a terminating link is called. + + This is a base class for all terminating link errors.""" class ContinuationLinkError(LinkError): + """Raised when a continuation link is called and errors. + + This is a base class for all continuation link errors.""" pass class AuthenticationError(TerminatingLinkError): + """Signals that the authentication failed.""" pass diff --git a/rath/links/file.py b/rath/links/file.py index 6190560..404287b 100644 --- a/rath/links/file.py +++ b/rath/links/file.py @@ -1,10 +1,13 @@ +from rath.errors import NotComposedError from rath.links.base import ContinuationLink +from rath.operation import GraphQLResult, Operation from rath.operation import Operation -from rath.operation import Operation +from typing import AsyncIterator import io import aiohttp from typing import AsyncGenerator +from pydantic import Field FILE_CLASSES = ( @@ -17,10 +20,25 @@ from typing import Any, Dict, Tuple, Type -def parse_variables( +def parse_nested_files( variables: Dict, file_classes: Tuple[Type[Any], ...] = FILE_CLASSES, ) -> Tuple[Dict, Dict]: + """Parse nested files + + Parameters + ---------- + variables : Dict + The variables to parse + file_classes : Tuple[Type[Any], ...], optional + File-like classes to extract, by default FILE_CLASSES + + Returns + ------- + Tuple[Dict, Dict] + The parsed variables and the extracted files + """ + files = {} def recurse_extract(path, obj): @@ -69,10 +87,17 @@ class FileExtraction(ContinuationLink): These can then be used by the FileUploadLink to upload the files to a remote server. or used through the multipart/form-data encoding in the terminating link (if supported). """ + file_classes: Tuple[Type[Any], ...]
= Field(default=FILE_CLASSES) + + async def aexecute(self, operation: Operation) -> AsyncIterator[GraphQLResult]: + if not self.next: + raise NotComposedError( + "FileExtractionLink must be composed with another link" + ) + - async def aexecute(self, operation: Operation) -> Operation: - operation.variables, operation.context.files = parse_variables( - operation.variables + operation.variables, operation.context.files = parse_nested_files( + operation.variables, file_classes=self.file_classes ) async for result in self.next.aexecute(operation): diff --git a/rath/links/foward.py b/rath/links/foward.py index 9298488..50c85fd 100644 --- a/rath/links/foward.py +++ b/rath/links/foward.py @@ -4,6 +4,7 @@ from rath.links.base import ContinuationLink from rath.operation import GraphQLResult, Operation from rath.links.errors import AuthenticationError +from rath.errors import NotComposedError class ForwardLink(ContinuationLink): @@ -17,6 +18,9 @@ class ForwardLink(ContinuationLink): async def aexecute( self, operation: Operation, **kwargs ) -> AsyncIterator[GraphQLResult]: + if not self.next: + raise NotComposedError("No next link set") + async for result in self.next.aexecute(operation, **kwargs): yield result diff --git a/rath/links/graphql_ws.py b/rath/links/graphql_ws.py index 778f994..95f4a47 100644 --- a/rath/links/graphql_ws.py +++ b/rath/links/graphql_ws.py @@ -1,5 +1,5 @@ from ssl import SSLContext -from typing import Awaitable, Callable, Dict, Optional +from typing import Awaitable, Callable, Dict, Optional, Any from graphql import OperationType from pydantic import Field import websockets @@ -64,6 +64,11 @@ async def default_pong_handler(payload): return payload +InitialConnectPayload = Dict[str, Any] +PongPayload = Dict[str, Any] + + + class GraphQLWSLink(AsyncTerminatingLink): """GraphQLWSLink is a terminating link that sends operations over websockets using websockets via the graphql-ws protocol. This is a @@ -86,11 +91,14 @@ class GraphQLWSLink(AsyncTerminatingLink): default_factory=lambda: ssl.create_default_context(cafile=certifi.where()) ) - on_connect: Optional[Callable[[], Awaitable[None]]] = Field(exclude=True) - on_pong: Optional[Callable[[], Awaitable[None]]] = Field( + on_connect: Optional[Callable[[InitialConnectPayload], Awaitable[None]]] = Field(exclude=True) + """ A function that is called before the connection is established. If an exception is raised, the connection is not established. Return is ignored.""" + + + on_pong: Optional[Callable[[PongPayload], Awaitable[None]]] = Field( default=default_pong_handler, exclude=True ) - """ A function that is called when the connection is established """ + """ A function that is called when a ping is received, before the pong reply is sent. If an exception is raised, the connection is not established.
Return is ignored.""" heartbeat_interval_ms: Optional[int] = None """ The heartbeat interval in milliseconds (None means no heartbeats are being send) """ @@ -98,8 +106,8 @@ class GraphQLWSLink(AsyncTerminatingLink): _connection_lock: Optional[asyncio.Lock] = None _connected: bool = False _alive: bool = False - _send_queue: asyncio.Queue = None - _connection_task: asyncio.Task = None + _send_queue: Optional[asyncio.Queue] = None + _connection_task: Optional[asyncio.Task] = None _ongoing_subscriptions: Optional[Dict[str, asyncio.Queue]] = None async def aforward(self, message): @@ -172,7 +180,13 @@ async def websocket_loop( task.cancel() for task in done: - raise task.exception() + exception = task.exception() + if exception: + raise exception + else: + raise CorrectableConnectionFail( + f"Websocket connection closed without exception: This is unexpected behaviours. Results ist {task.result()}" + ) except Exception as e: logger.warning("Websocket excepted. Trying to recover", exc_info=True) @@ -188,8 +202,8 @@ async def websocket_loop( await asyncio.sleep(self.time_between_retries) logger.info(f"Retrying to connect") - await self.broadcast({"type": WEBSOCKET_DEAD, "error": e}) - await self.websocket_loop(retry=retry + 1) + await self.broadcast({"type": WEBSOCKET_DEAD, "error": e}, initial_connection_future) + await self.websocket_loop(initiating_operation, initial_connection_future, retry=retry + 1) except DefiniteConnectionFail as e: logger.error("Websocket excepted closed definetely", exc_info=True) @@ -202,9 +216,9 @@ async def websocket_loop( send_task.cancel() receive_task.cancel() - cancellation = await asyncio.gather( + await asyncio.gather( send_task, receive_task, return_exceptions=True - ) + ) # wait for the tasks to finish raise e except Exception as e: @@ -221,6 +235,9 @@ async def sending(self, client, initiating_operation: Operation): try: while True: + if not self._send_queue: + raise LinkNotConnectedError("Link is not connected") + message = await self._send_queue.get() logger.debug("GraphQL Websocket: >>>>>> " + message) await client.send(message) @@ -257,9 +274,9 @@ async def broadcast(self, message: dict, initial_connection_future: asyncio.Futu if type == GQL_PING: if self.on_pong: - payload = await self.on_pong(message.get("payload", {})) - else: - payload = message.get("payload", {}) + await self.on_pong(message.get("payload", {})) + + payload = message.get("payload", {}) await self.aforward(json.dumps({"type": GQL_PONG, "payload": payload})) if type == GQL_CONNECTION_KEEP_ALIVE: @@ -267,11 +284,17 @@ async def broadcast(self, message: dict, initial_connection_future: asyncio.Futu if type == WEBSOCKET_DEAD: # notify all subscriptipns that the websocket is dead + if not self._ongoing_subscriptions: + self._ongoing_subscriptions = {} + for subscription in self._ongoing_subscriptions.values(): await subscription.put(message) return if type in [GQL_DATA, GQL_COMPLETE, GQL_ERROR]: + if not self._ongoing_subscriptions: + self._ongoing_subscriptions = {} + if "id" not in message: raise InvalidPayload(f"Protocol Violation. 
Expected 'id' in {message}") @@ -282,6 +305,9 @@ async def broadcast(self, message: dict, initial_connection_future: asyncio.Futu await self._ongoing_subscriptions[id].put(message) async def aexecute(self, operation: Operation): + if not self._connection_lock: + raise LinkNotConnectedError("Link is not connected") + async with self._connection_lock: if self._connection_task is None or self._connection_task.done(): # we need to start a new connection @@ -293,17 +319,20 @@ async def aexecute(self, operation: Operation): assert not operation.context.files, "We cannot send files through websockets" id = operation.id - subscribe_queue = asyncio.Queue() + subscribe_queue = asyncio.Queue() #type: ignore + if not self._ongoing_subscriptions: + self._ongoing_subscriptions = {} + self._ongoing_subscriptions[id] = subscribe_queue - payload = { + send_payload = { "headers": operation.context.headers, "query": operation.document, "variables": operation.variables, } try: - frame = {"id": id, "type": GQL_START, "payload": payload} + frame = {"id": id, "type": GQL_START, "payload": send_payload} await self.aforward(json.dumps(frame)) logger.debug(f"Subcription started {operation}") diff --git a/rath/links/httpx.py b/rath/links/httpx.py index 4cd9af3..16b3023 100644 --- a/rath/links/httpx.py +++ b/rath/links/httpx.py @@ -1,6 +1,6 @@ from http import HTTPStatus import json -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Type, AsyncIterator import httpx from graphql import OperationType from pydantic import Field @@ -8,10 +8,19 @@ from rath.links.base import AsyncTerminatingLink from rath.links.errors import AuthenticationError import logging - +from rath.links.types import Payload +from datetime import datetime logger = logging.getLogger(__name__) +class DateTimeEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, datetime): + return o.isoformat() + + return json.JSONEncoder.default(self, o) + + class HttpxLink(AsyncTerminatingLink): """HttpxLink is a terminating link that sends operations over HTTP using httpx""" @@ -22,17 +31,12 @@ class HttpxLink(AsyncTerminatingLink): default_factory=lambda: (HTTPStatus.FORBIDDEN,) ) """auth_errors is a list of HTTPStatus codes that indicate an authentication error.""" + json_encoder: Type[json.JSONEncoder] = Field(default=DateTimeEncoder, exclude=True) - _client = None - async def __aenter__(self) -> None: - self._client = await httpx.AsyncClient().__aenter__() + async def aexecute(self, operation: Operation) -> AsyncIterator[GraphQLResult]: - async def __aexit__(self, *args, **kwargs) -> None: - await self._client.__aexit__(*args, **kwargs) - - async def aexecute(self, operation: Operation) -> GraphQLResult: - payload = {"query": operation.document} + payload: Payload = {"query": operation.document} if operation.node.operation == OperationType.SUBSCRIPTION: raise NotImplementedError( @@ -64,27 +68,29 @@ async def aexecute(self, operation: Operation) -> GraphQLResult: payload["variables"] = operation.variables post_kwargs = {"json": payload} - response = await self._client.post( - self.endpoint_url, headers=operation.context.headers, **post_kwargs - ) + async with httpx.AsyncClient() as client: - if response.status_code in self.auth_errors: - raise AuthenticationError( - f"Token Expired Error {operation.context.headers}" + response = await client.post( + self.endpoint_url, headers=operation.context.headers, **post_kwargs ) - if response.status_code == HTTPStatus.OK: - json_response = response.json() - - if "errors" in 
json_response: - raise GraphQLException( - "\n".join([e["message"] for e in json_response["errors"]]) + if response.status_code in self.auth_errors: + raise AuthenticationError( + f"Token Expired Error {operation.context.headers}" ) - if "data" not in json_response: - raise Exception(f"Response does not contain data {json_response}") + if response.status_code == HTTPStatus.OK: + json_response = response.json() + + if "errors" in json_response: + raise GraphQLException( + "\n".join([e["message"] for e in json_response["errors"]]) + ) + + if "data" not in json_response: + raise Exception(f"Response does not contain data {json_response}") - yield GraphQLResult(data=json_response["data"]) + yield GraphQLResult(data=json_response["data"]) class Config: arbitrary_types_allowed = True diff --git a/rath/links/log.py b/rath/links/log.py index 0221024..d7b6776 100644 --- a/rath/links/log.py +++ b/rath/links/log.py @@ -2,7 +2,7 @@ from rath.links.base import ContinuationLink from rath.operation import GraphQLResult, Operation - +from rath.errors import NotComposedError async def just_print(operation: Operation): print(operation) @@ -19,6 +19,9 @@ class LogLink(ContinuationLink): async def aexecute( self, operation: Operation, **kwargs ) -> AsyncIterator[GraphQLResult]: + if not self.next: + raise NotComposedError("No next link set") + await self.log(operation) async for result in self.next.aexecute(operation, **kwargs): yield result diff --git a/rath/links/parsing.py b/rath/links/parsing.py index 3d77c86..5c57644 100644 --- a/rath/links/parsing.py +++ b/rath/links/parsing.py @@ -1,6 +1,7 @@ from rath.links.base import ContinuationLink -from rath.operation import Operation - +from rath.operation import GraphQLResult, Operation +from typing import AsyncIterator, Awaitable, Callable +from rath.errors import NotComposedError class ParsingLink(ContinuationLink): """ParsingLink is a link that parses operation and returns a new operation. @@ -10,7 +11,10 @@ class ParsingLink(ContinuationLink): async def aparse(self, operation: Operation) -> Operation: raise NotImplementedError("Please implement this method") - async def aexecute(self, operation: Operation, **kwargs) -> Operation: + async def aexecute(self, operation: Operation, **kwargs) -> AsyncIterator[GraphQLResult]: + if not self.next: + raise NotComposedError("No next link set") + operation = await self.aparse(operation) async for result in self.next.aexecute(operation, **kwargs): yield result diff --git a/rath/links/retry.py b/rath/links/retry.py index 9397c02..e41b69a 100644 --- a/rath/links/retry.py +++ b/rath/links/retry.py @@ -11,6 +11,7 @@ from rath.links.errors import AuthenticationError import logging import asyncio +from rath.errors import NotComposedError logger = logging.getLogger(__name__) @@ -26,10 +27,13 @@ class RetryLink(ContinuationLink): """The number of seconds to wait before retrying the operation.""" async def aexecute( - self, operation: Operation, retry=0, **kwargs + self, operation: Operation, retry=0 ) -> AsyncIterator[GraphQLResult]: + if not self.next: + raise NotComposedError("No next link set") + try: - async for result in self.next.aexecute(operation, **kwargs): + async for result in self.next.aexecute(operation): yield result except SubscriptionDisconnect as e: @@ -41,7 +45,7 @@ async def aexecute( await asyncio.sleep(self.sleep_interval) logger.info(f"Subscription {operation} disconnected. 
Retrying {retry}") - async for result in self.aexecute(operation, retry=retry + 1, **kwargs): + async for result in self.aexecute(operation, retry=retry + 1): yield result class Config: diff --git a/rath/links/sign_local_link.py b/rath/links/sign_local_link.py index 7174de1..f469ff5 100644 --- a/rath/links/sign_local_link.py +++ b/rath/links/sign_local_link.py @@ -10,6 +10,13 @@ class SignLocalLink(AuthTokenLink): + """SignLocalLink + + SignLocalLink is a link is a type of AuthTokenLink that + crated a JWT token using a local private key, + and sends it to the next link. + + """ private_key: rsa.RSAPrivateKey @validator("private_key", pre=True, always=True) @@ -47,6 +54,16 @@ async def arefresh_token(self): class ComposedSignTokenLink(SignLocalLink): + """ComposedSignTokenLink + + ComposedSignTokenLink is a SignLocalLink that + uses a payload retriever to retrieve the payload + to sign, enables the user to use a custom payload + retriever., without having to implement the entire + SignLocalLink. + """ + + payload_retriever: Callable[[Operation], Awaitable[Dict]] async def aretrieve_payload(self, operation: Operation): diff --git a/rath/links/split.py b/rath/links/split.py index fe2af3a..88af993 100644 --- a/rath/links/split.py +++ b/rath/links/split.py @@ -1,7 +1,7 @@ -from typing import Callable +from typing import Callable, AsyncIterator from pydantic import Field -from rath.operation import Operation +from rath.operation import Operation, GraphQLResult from rath.links.base import TerminatingLink @@ -19,7 +19,7 @@ class SplitLink(TerminatingLink): split: Callable[[Operation], bool] = Field(exclude=True) """The function used to split the operation. This function should return a boolean. If true, the operation is sent to the left path, otherwise to the right path.""" - async def aexecute(self, operation: Operation, **kwargs) -> Operation: + async def aexecute(self, operation: Operation, **kwargs) -> AsyncIterator[GraphQLResult]: iterator = ( self.left.aexecute(operation, **kwargs) if self.split(operation) diff --git a/rath/links/subscription_transport_ws.py b/rath/links/subscription_transport_ws.py index 5127c8f..123e970 100644 --- a/rath/links/subscription_transport_ws.py +++ b/rath/links/subscription_transport_ws.py @@ -9,7 +9,9 @@ import logging import ssl import certifi -from rath.links.errors import TerminatingLinkError +from rath.links.errors import LinkNotConnectedError, TerminatingLinkError + + from rath.operation import ( GraphQLException, @@ -83,11 +85,11 @@ class SubscriptionTransportWsLink(AsyncTerminatingLink): """Should the payload token be sent as a querystring instead (as connection params is not supported by all servers)""" - _connection_lock: asyncio.Lock = None + _connection_lock: Optional[asyncio.Lock] = None _connected: bool = False _alive: bool = False - _send_queue: asyncio.Queue = None - _connection_task: asyncio.Task = None + _send_queue: Optional[asyncio.Queue] = None + _connection_task: Optional[asyncio.Task] = None _ongoing_subscriptions: Optional[Dict[str, asyncio.Queue]] = None async def aforward(self, message): @@ -165,7 +167,13 @@ async def websocket_loop( task.cancel() for task in done: - raise task.exception() + exception = task.exception() + if exception: + raise exception + else: + raise CorrectableConnectionFail( + f"Websocket connection closed without exception: This is unexpected behaviours. Results ist {task.result()}" + ) except Exception as e: logger.warning("Websocket excepted. 
Trying to recover", exc_info=True) @@ -181,8 +189,8 @@ async def websocket_loop( await asyncio.sleep(self.time_between_retries) logger.info(f"Retrying to connect") - await self.broadcast({"type": WEBSOCKET_DEAD, "error": e}) - await self.websocket_loop(retry=retry + 1) + await self.broadcast({"type": WEBSOCKET_DEAD, "error": e}, connection_future) + await self.websocket_loop(initiating_operation, connection_future, retry=retry + 1) except DefiniteConnectionFail as e: logger.error("Websocket excepted closed definetely", exc_info=True) @@ -213,11 +221,15 @@ async def sending(self, client, initiating_operation: Operation): payload = { "type": GQL_CONNECTION_INIT, "payload": {"headers": initiating_operation.context.headers}, + } await client.send(json.dumps(payload)) try: while True: + if not self._send_queue: + raise LinkNotConnectedError("Link is not connected") + message = await self._send_queue.get() logger.debug("GraphQL Websocket: >>>>>> " + message) await client.send(message) @@ -260,6 +272,10 @@ async def broadcast(self, message: dict, connection_future: asyncio.Future): if type == WEBSOCKET_DEAD: # notify all subscriptipns that the websocket is dead + if not self._ongoing_subscriptions: + self._ongoing_subscriptions = {} + + for subscription in self._ongoing_subscriptions.values(): await subscription.put(message) return @@ -269,12 +285,19 @@ async def broadcast(self, message: dict, connection_future: asyncio.Future): raise InvalidPayload(f"Protocol Violation. Expected 'id' in {message}") id = message["id"] + if not self._ongoing_subscriptions: + self._ongoing_subscriptions = {} + + assert ( id in self._ongoing_subscriptions ), "Received Result for subscription that is no longer or was never active" await self._ongoing_subscriptions[id].put(message) async def aexecute(self, operation: Operation): + if not self._connection_lock: + raise Exception("WebsocketLink not entered yet. 
Please use this in an async context manager") + async with self._connection_lock: if self._connection_task is None or self._connection_task.done(): await self.aconnect(operation) @@ -285,17 +308,21 @@ async def aexecute(self, operation: Operation): assert not operation.context.files, "We cannot send files through websockets" id = operation.id - subscribe_queue = asyncio.Queue() + subscribe_queue = asyncio.Queue() # type: asyncio.Queue + + if not self._ongoing_subscriptions: + self._ongoing_subscriptions = {} + self._ongoing_subscriptions[id] = subscribe_queue - payload = { + send_payload = { "headers": operation.context.headers, "query": operation.document, "variables": operation.variables, } try: - frame = {"id": id, "type": GQL_START, "payload": payload} + frame = {"id": id, "type": GQL_START, "payload": send_payload} await self.aforward(json.dumps(frame)) logger.debug(f"Subcription started {operation}") diff --git a/rath/links/testing/mock.py b/rath/links/testing/mock.py index e16d5f4..b783fb4 100644 --- a/rath/links/testing/mock.py +++ b/rath/links/testing/mock.py @@ -1,5 +1,5 @@ import asyncio -from typing import AsyncIterator, Awaitable, Callable, Dict +from typing import AsyncIterator, Awaitable, Callable, Dict, AsyncGenerator from pydantic import Field, validator from rath.links.base import AsyncTerminatingLink @@ -28,7 +28,7 @@ class AsyncMockLink(AsyncTerminatingLink): mutation_resolver: Dict[str, Callable[[Operation], Awaitable[Dict]]] = Field( default_factory=dict, exclude=True ) - subscription_resolver: Dict[str, Callable[[Operation], Awaitable[Dict]]] = Field( + subscription_resolver: Dict[str, Callable[[Operation], AsyncIterator[Dict]]] = Field( default_factory=dict, exclude=True ) resolver: Dict[str, Callable[[Operation], Awaitable[Dict]]] = Field( @@ -54,20 +54,21 @@ async def aexecute(self, operation: Operation) -> AsyncIterator[GraphQLResult]: futures = [] for op in operation.node.selection_set.selections: - if op.name.value in self.query_resolver: - futures.append(self.query_resolver[op.name.value](operation)) - elif op.name.value in self.resolver: - futures.append(self.resolver[op.name.value](operation)) - else: - raise NotImplementedError( - f"Mocked Resolver for Query '{op.name.value}' not in resolvers: {self.query_resolver}, {self.resolver} for AsyncMockLink" - ) + if isinstance(op, FieldNode): + if op.name.value in self.query_resolver: + futures.append(self.query_resolver[op.name.value](operation)) + elif op.name.value in self.resolver: + futures.append(self.resolver[op.name.value](operation)) + else: + raise NotImplementedError( + f"Mocked Resolver for Query '{op.name.value}' not in resolvers: {self.query_resolver}, {self.resolver} for AsyncMockLink" + ) resolved = await asyncio.gather(*futures) yield GraphQLResult( data={ target_from_node(op): resolved[i] - for i, op in enumerate(operation.node.selection_set.selections) + for i, op in enumerate(operation.node.selection_set.selections) if isinstance(op, FieldNode) } ) @@ -75,20 +76,21 @@ async def aexecute(self, operation: Operation) -> AsyncIterator[GraphQLResult]: futures = [] for op in operation.node.selection_set.selections: - if op.name.value in self.mutation_resolver: - futures.append(self.mutation_resolver[op.name.value](operation)) - elif op.name.value in self.resolver: - futures.append(self.resolver[op.name.value](operation)) - else: - raise NotImplementedError( - f"Mocked Resolver for Query '{op.name.value}' not in resolvers: {self.mutation_resolver}, {self.resolver} for AsyncMockLink" - ) + if
isinstance(op, FieldNode): + if op.name.value in self.mutation_resolver: + futures.append(self.mutation_resolver[op.name.value](operation)) + elif op.name.value in self.resolver: + futures.append(self.resolver[op.name.value](operation)) + else: + raise NotImplementedError( + f"Mocked Resolver for Query '{op.name.value}' not in resolvers: {self.mutation_resolver}, {self.resolver} for AsyncMockLink" + ) resolved = await asyncio.gather(*futures) yield GraphQLResult( data={ target_from_node(op): resolved[i] - for i, op in enumerate(operation.node.selection_set.selections) + for i, op in enumerate(operation.node.selection_set.selections) if isinstance(op, FieldNode) } ) @@ -99,17 +101,22 @@ async def aexecute(self, operation: Operation) -> AsyncIterator[GraphQLResult]: ), "Only one Subscription at a time possible" op = operation.node.selection_set.selections[0] - if op.name.value in self.subscription_resolver: - iterator = self.subscription_resolver[op.name.value](operation) - elif op.name.value in self.resolver: - iterator = self.resolver[op.name.value](operation) - else: - raise NotImplementedError( - f"Mocked Resolver for Query '{op.name.value}' not in resolvers: {self.subscription_resolver}, {self.resolver} for AsyncMockLink" - ) + if isinstance(op, FieldNode): + if op.name.value in self.subscription_resolver: + iterator = self.subscription_resolver[op.name.value](operation) + + async for event in iterator: + if isinstance(op, FieldNode): + yield GraphQLResult(data={target_from_node(op): event}) + else: + raise NotImplementedError( + f"Mocked Resolver for Query '{op.name.value}' not in resolvers: {self.subscription_resolver}, {self.resolver} for AsyncMockLink" + ) + async for event in iterator: - yield GraphQLResult(data={target_from_node(op): event}) + if isinstance(op, FieldNode): + yield GraphQLResult(data={target_from_node(op): event}) else: raise NotImplementedError("Only subscription are mocked") diff --git a/rath/links/transpile.py b/rath/links/transpile.py index 4a33786..d909232 100644 --- a/rath/links/transpile.py +++ b/rath/links/transpile.py @@ -6,6 +6,7 @@ NonNullTypeNode, OperationDefinitionNode, VariableNode, + TypeNode ) from pydantic import BaseModel, Field from rath.links.parsing import ParsingLink @@ -157,7 +158,7 @@ def decorator(func): def recurse_transpile( key, - var: VariableNode, + var: TypeNode, value: Any, registry: TranspileRegistry, in_list=0, @@ -219,9 +220,12 @@ def recurse_transpile( else: if var.name.value in registry.item_handlers: - for k, handler in registry.item_handlers[var.name.value].items(): + type_handlers = registry.item_handlers[var.name.value] + + + for key, item_handler in type_handlers.items(): try: - predicate = handler.predicate(value) + predicate = item_handler.predicate(value) except Exception as e: if strict: raise Exception(f"Handler {handler} failed with {e}") @@ -230,7 +234,7 @@ def recurse_transpile( ) continue if predicate: - parsed_value = [handler.parser(value) for value in value] + parsed_value = [item_handler.parser(value) for value in value] assert ( parsed_value is not None ), f"Handler {handler} failed on parsing {value}. 
Please check your parser for edge cases" @@ -273,7 +277,7 @@ def transpile( transpiled_variables = { key: recurse_transpile(key, variable, variables[key], registry, strict=strict) - for key, variable in variable_nodes.items() + for key, variable in variable_nodes.items() if isinstance(variable, TypeNode) } return transpiled_variables diff --git a/rath/links/types.py b/rath/links/types.py new file mode 100644 index 0000000..a9851b8 --- /dev/null +++ b/rath/links/types.py @@ -0,0 +1,4 @@ +from typing import Dict, Any, Optional, Union, TypeAlias + + +Payload = Dict[str, Any] \ No newline at end of file diff --git a/rath/links/utils.py b/rath/links/utils.py index 606de0a..8e0d82f 100644 --- a/rath/links/utils.py +++ b/rath/links/utils.py @@ -1,11 +1,11 @@ -from typing import Dict, Any, Callable +from typing import Dict, Any, Callable, Optional from rath.operation import Operation def recurse_parse_variables( variables: Dict, - predicate: Callable[[str, Any], bool], + predicate: Callable[[Optional[str], Any], bool], apply: Callable[[Any], Any], ) -> Dict: """Parse Variables @@ -22,7 +22,7 @@ def recurse_parse_variables( Dict: _description_ """ - def recurse_extract(obj, path: str = None): + def recurse_extract(obj, path: Optional[str] = None): """ recursively traverse obj, doing a deepcopy, but replacing any file-like objects with nulls and @@ -30,18 +30,18 @@ def recurse_extract(obj, path: str = None): """ if isinstance(obj, list): - nulled_obj = [] + nulled_list = [] for key, value in enumerate(obj): value = recurse_extract( value, - f"{path}.{key}" if path else key, + path=f"{path}.{key}" if path else str(key), ) - nulled_obj.append(value) - return nulled_obj + nulled_list.append(value) + return nulled_list elif isinstance(obj, dict): nulled_obj = {} for key, value in obj.items(): - value = recurse_extract(value, f"{path}.{key}" if path else key) + value = recurse_extract(value, f"{path}.{key}" if path else str(key)) nulled_obj[key] = value return nulled_obj elif predicate(path, obj): @@ -57,7 +57,7 @@ def recurse_extract(obj, path: str = None): def recurse_parse_variables_with_operation( variables: Dict, operation: Operation, - predicate: Callable[[str, Any], bool], + predicate: Callable[[Optional[str], Any], bool], apply: Callable[[Any], Any], ) -> Dict: """Parse Variables @@ -74,7 +74,7 @@ def recurse_parse_variables_with_operation( Dict: _description_ """ - def recurse_extract(obj, path: str = None): + def recurse_extract(obj, path: Optional[str] = None): """ recursively traverse obj, doing a deepcopy, but replacing any file-like objects with nulls and @@ -82,18 +82,18 @@ def recurse_extract(obj, path: str = None): """ if isinstance(obj, list): - nulled_obj = [] + nulled_list = [] for key, value in enumerate(obj): value = recurse_extract( value, - f"{path}.{key}" if path else key, + path=f"{path}.{key}" if path else str(key), ) - nulled_obj.append(value) - return nulled_obj + nulled_list.append(value) + return nulled_list elif isinstance(obj, dict): nulled_obj = {} for key, value in obj.items(): - value = recurse_extract(value, f"{path}.{key}" if path else key) + value = recurse_extract(value, f"{path}.{key}" if path else str(key)) nulled_obj[key] = value return nulled_obj elif predicate(path, obj): diff --git a/rath/links/validate.py b/rath/links/validate.py index a81fdbc..ff60ca6 100644 --- a/rath/links/validate.py +++ b/rath/links/validate.py @@ -1,10 +1,11 @@ -from typing import AsyncIterator, Optional +from typing import AsyncIterator, Optional, cast from graphql import ( 
GraphQLSchema, build_ast_schema, build_client_schema, get_introspection_query, validate, + IntrospectionQuery ) from graphql.language.parser import parse from pydantic import root_validator @@ -72,33 +73,38 @@ def check_schema_dsl_or_schema_glob(cls, values): ) return values - - async def aload_schema(self, operation: Operation) -> None: - assert self.allow_introspection, "Introspection is not allowed" + + async def introspect(self, starting_operation: Operation) -> GraphQLSchema: #type: ignore + if not self.next: + raise ContinuationLinkError("No next link set") + introspect_operation = opify(get_introspection_query()) - introspect_operation.context = operation.context - introspect_operation.extensions = operation.extensions + introspect_operation.context = starting_operation.context + introspect_operation.extensions = starting_operation.extensions - async for e in self.next.aexecute(introspect_operation): - self.graphql_schema = build_client_schema(e.data) - return + async for result in self.next.aexecute(introspect_operation): + return build_client_schema(cast(IntrospectionQuery, result.data)) + - def validate(self, operation: Operation): - errors = validate(self.graphql_schema, operation.document_node) + async def aexecute( + self, operation: Operation, **kwargs + ) -> AsyncIterator[GraphQLResult]: + if not self.next: + raise ContinuationLinkError("No next link set") + + if not self.graphql_schema: + assert self.allow_introspection, "Introspection is not allowed" + self.graphql_schema = await self.introspect(operation) + + errors = validate(self.graphql_schema, operation.document_node) if len(errors) > 0: raise ValidationError( f"{operation} does not comply with the schema!\n Errors: \n\n" + "\n".join([e.message for e in errors]) ) - async def aexecute( - self, operation: Operation, **kwargs - ) -> AsyncIterator[GraphQLResult]: - if not self.graphql_schema: - await self.aload_schema(operation) - self.validate(operation) async for result in self.next.aexecute(operation, **kwargs): yield result diff --git a/rath/operation.py b/rath/operation.py index ada8003..8b3b5b8 100644 --- a/rath/operation.py +++ b/rath/operation.py @@ -1,9 +1,10 @@ from typing import Optional, Dict, Any, Union -from graphql.language import OperationDefinitionNode, parse, OperationType +from graphql.language import OperationDefinitionNode, parse, OperationType, print_ast from graphql import ( DocumentNode, get_operation_ast, parse, + ) from pydantic import BaseModel, Field import uuid @@ -12,10 +13,10 @@ class Context(BaseModel): """Context provides a way to pass arbitrary data to resolvers on the context""" - headers: Optional[Dict[str, str]] = Field(default_factory=dict) - files: Optional[Dict[str, Any]] = Field(default_factory=dict) - initial_payload: Optional[Dict[str, Any]] = Field(default_factory=dict) - kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict) + headers: Dict[str, str] = Field(default_factory=dict) + files: Dict[str, Any] = Field(default_factory=dict) + initial_payload: Dict[str, Any]= Field(default_factory=dict) + kwargs: Dict[str, Any] = Field(default_factory=dict) class Extensions(BaseModel): @@ -65,8 +66,8 @@ class SubscriptionDisconnect(GraphQLException): def opify( query: Union[str, DocumentNode], - variables: Dict[str, Any] = None, - headers: Dict[str, Any] = None, + variables: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, **kwargs, ) -> Operation: @@ -82,15 +83,15 @@ def opify( Operation: A GraphQL operation """ 
- document = parse(query) if query and isinstance(query, str) else query + document = parse(query) if isinstance(query, str) else query op = get_operation_ast(document, operation_name) assert op, f"No operation named {operation_name}" return Operation( node=op, - document=query, + document=print_ast(document), document_node=document, variables=variables or {}, operation_name=operation_name, - extensions={}, + extensions=Extensions(), context=Context(headers=headers or {}, kwargs=kwargs), ) diff --git a/rath/rath.py b/rath/rath.py index 80802ce..d7b0640 100644 --- a/rath/rath.py +++ b/rath/rath.py @@ -18,7 +18,7 @@ from koil import unkoil_gen, unkoil -current_rath = ContextVar("current_rath_unpicklable") +current_rath: ContextVar["Rath"] = ContextVar("current_rath_unpicklable") class Rath(KoiledModel): @@ -49,7 +49,7 @@ class Rath(KoiledModel): """ - link: Optional[TerminatingLink] = None + link: TerminatingLink = Field(..., description="The terminating link used to send operations to the server. Can be a composed link chain.") """The terminating link used to send operations to the server. Can be a composed link chain.""" _connected = False @@ -60,9 +60,9 @@ class Rath(KoiledModel): async def aquery( self, query: Union[str, DocumentNode], - variables: Dict[str, Any] = None, - headers: Dict[str, Any] = None, - operation_name: str = None, + variables: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, **kwargs, ) -> GraphQLResult: """Query the GraphQL API. @@ -86,15 +86,28 @@ async def aquery( """ op = opify(query, variables, headers, operation_name, **kwargs) + result = None + async for data in self.link.aexecute(op): - return data + result = data + break + + if not result: + raise NotConnectedError("Could not retrieve data from the server.") + # This is to account for the fact that mypy apparently doesn't + # understand that a return statement in a generator is valid. + # This is a workaround to make mypy happy. + + return result + + def query( self, query: Union[str, DocumentNode], - variables: Dict[str, Any] = None, - headers: Dict[str, Any] = None, - operation_name: str = None, + variables: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, **kwargs, ) -> GraphQLResult: """Query the GraphQL API. @@ -121,9 +134,9 @@ def query( def subscribe( self, query: str, - variables: Dict[str, Any] = None, - headers: Dict[str, Any] = None, - operation_name: str = None, + variables: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, **kwargs, ) -> Iterator[GraphQLResult]: """Subscripe to a GraphQL API. @@ -151,9 +164,9 @@ def subscribe( async def asubscribe( self, query: str, - variables: Dict[str, Any] = None, - headers: Dict[str, Any] = {}, - operation_name=None, + variables: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, Any]] = None, + operation_name: Optional[str]=None, **kwargs, ) -> AsyncIterator[GraphQLResult]: """Subscripe to a GraphQL API. 
diff --git a/rath/turms/funcs.py b/rath/turms/funcs.py index 03818bc..50ba9f9 100644 --- a/rath/turms/funcs.py +++ b/rath/turms/funcs.py @@ -1,25 +1,26 @@ from rath.rath import Rath, current_rath +from typing import Optional -def execute(operation, variables, rath: Rath = None): +def execute(operation, variables, rath: Optional[Rath] = None): rath = rath or current_rath.get() return operation(**rath.query(operation.Meta.document, variables).data) -async def aexecute(operation, variables, rath: Rath = None): +async def aexecute(operation, variables, rath: Optional[Rath] = None): rath = rath or current_rath.get() x = await rath.aquery(operation.Meta.document, variables) return operation(**x.data) -def subscribe(operation, variables, rath: Rath = None): +def subscribe(operation, variables, rath: Optional[Rath] = None): rath = rath or current_rath.get() for event in rath.subscribe(operation.Meta.document, variables): yield operation(**event.data) -async def asubscribe(operation, variables, rath: Rath = None): +async def asubscribe(operation, variables, rath: Optional[Rath] = None): rath = rath or current_rath.get() async for event in rath.asubscribe(operation.Meta.document, variables): yield operation(**event.data)
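
Usage sketch (not part of the patch): the snippet below illustrates how the changed surface might be exercised — HttpxLink now opens a fresh httpx.AsyncClient per request, Rath.aquery raises instead of silently returning None when the link chain yields nothing, and continuation links raise NotComposedError if executed before set_next was called. The endpoint URL and query are made-up placeholders, and it assumes Rath can be entered as an async context manager via its KoiledModel base.

```python
# Illustrative sketch only -- not taken from the diff. Endpoint URL and the
# "hello" query are hypothetical; adjust to a real schema before running.
import asyncio

from rath.rath import Rath
from rath.links.httpx import HttpxLink


async def main() -> None:
    # HttpxLink now creates a short-lived httpx.AsyncClient per request, so the
    # link itself no longer needs to be entered to manage a long-lived client.
    link = HttpxLink(endpoint_url="https://example.com/graphql")

    # Entering Rath composes/enters the link chain; a ContinuationLink executed
    # outside a composed chain would raise NotComposedError instead.
    async with Rath(link=link) as rath:
        # aquery returns the first GraphQLResult from the chain; with this
        # change it raises rather than returning None when nothing is yielded.
        result = await rath.aquery("query { hello }")
        print(result.data)


asyncio.run(main())
```

The same pattern applies to the Fakts-backed links in rath/contrib/fakts, which now re-resolve their endpoint configuration on connect and assert that the fetched fakt is not None before configuring themselves.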