diff --git a/pywa/handlers.py b/pywa/handlers.py index faa2303..c9d6227 100644 --- a/pywa/handlers.py +++ b/pywa/handlers.py @@ -38,6 +38,7 @@ import abc import asyncio +import collections import dataclasses import functools import inspect @@ -61,7 +62,7 @@ FlowResponseError, FlowRequestActionType, FlowRequestCannotBeDecrypted, - FlowTokenNoLongerValid, + Screen, ) # noqa if TYPE_CHECKING: @@ -200,10 +201,9 @@ async def _get_factored_update( ), *handler.filters, ): - if inspect.iscoroutinefunction(f): - if not await f(wa, update): - return - elif not f(wa, update): + if not ( + await f(wa, update) if inspect.iscoroutinefunction(f) else f(wa, update) + ): return except AttributeError as e: if ( @@ -244,7 +244,7 @@ def _field_name(self) -> str | None: def __init__( self, callback: Callable[[WhatsApp, Any], Any], - *filters: Callable[[WhatsApp, Any], bool], + *filters: Callable[[WhatsApp, Any], bool | Awaitable[bool]], priority: int, ): """ @@ -255,17 +255,15 @@ def __init__( self.priority = priority async def handle(self, wa: WhatsApp, data: Any) -> bool: - for f in self.filters: - if inspect.iscoroutinefunction(f): - if not await f(wa, data): - return False - elif not f(wa, data): - return False + if not all( + await f(wa, data) if inspect.iscoroutinefunction(f) else f(wa, data) + for f in self.filters + ): + return False - if inspect.iscoroutinefunction(self.callback): - await self.callback(wa, data) - else: - self.callback(wa, data) + await self.callback(wa, data) if inspect.iscoroutinefunction( + self.callback + ) else self.callback(wa, data) return True @staticmethod @@ -353,10 +351,9 @@ def __init__( async def handle(self, wa: WhatsApp, data: Any) -> bool: update = await _get_factored_update(self, wa, data, self._data_field) if update is not None: - if inspect.iscoroutinefunction(self.callback): - await self.callback(wa, update) - else: - self.callback(wa, update) + await self.callback(wa, update) if inspect.iscoroutinefunction( + self.callback + ) else self.callback(wa, update) return True return False @@ -1100,14 +1097,17 @@ def __init__( self._endpoint = endpoint self._main_callback = callback self._error_callback: _FlowRequestHandlerT | None = None - self._on_callbacks = list[ - tuple[ - FlowRequestActionType | str, - str | None, - Callable[["WhatsApp", dict | None], bool] | None, - _FlowRequestHandlerT, + self._on_callbacks: dict[ + tuple[FlowRequestActionType | str, str | None], + list[ + tuple[ + Callable[["WhatsApp", dict | None], bool | Awaitable[bool]] | None, + _FlowRequestHandlerT, + ] ], - ]() # [(action, screen?, data_filter?, callback), ...] + ] = collections.defaultdict( + list + ) # {(action, screen?): [(data_filter?, callback)]} self._acknowledge_errors = acknowledge_errors self._handle_health_check = handle_health_check self._private_key = private_key or wa._private_key @@ -1148,7 +1148,7 @@ def on( self, *, action: FlowRequestActionType, - screen: str | None = None, + screen: Screen | str | None = None, data_filter: Callable[[WhatsApp, dict | None], bool] | None = None, ) -> Callable[[_FlowRequestHandlerT], _FlowRequestHandlerT]: """ @@ -1218,7 +1218,7 @@ def add_handler( *, callback: _FlowRequestHandlerT, action: FlowRequestActionType, - screen: str | None = None, + screen: Screen | str | None = None, data_filter: Callable[[WhatsApp, dict | None], bool] | None = None, ) -> FlowRequestCallbackWrapper: """ @@ -1248,7 +1248,9 @@ def add_handler( Returns: The current instance. """ - self._on_callbacks.append((action, screen, data_filter, callback)) + self._on_callbacks[ + (action, screen.id if isinstance(screen, Screen) else screen) + ].append((data_filter, callback)) return self def set_errors_handler( @@ -1283,7 +1285,34 @@ def set_errors_handler( self._error_callback = callback return self + async def _get_callback(self, req: FlowRequest) -> _FlowRequestHandlerT: + """Resolve the callback to use for the incoming request.""" + if req.has_error and self._error_callback: + return self._error_callback + for data_filter, callback in ( + *self._on_callbacks[(req.action, None)], # No screen priority + *self._on_callbacks[(req.action, req.screen)], + ): + if data_filter is None or ( + await data_filter(self._wa, req.data) + if asyncio.iscoroutinefunction(data_filter) + else data_filter(self._wa, req.data) + ): + return callback + return self._main_callback + async def __call__(self, payload: dict) -> tuple[str, int]: + """ + Handle the incoming request. + + - This method is called automatically by pywa, or manually when using custom server. + + Args: + payload: The incoming request payload. + + Returns: + A tuple containing the response data and the status code. + """ try: decrypted_request, aes_key, iv = self._request_decryptor( payload["encrypted_flow_data"], @@ -1314,9 +1343,7 @@ async def __call__(self, payload: dict) -> tuple[str, int]: iv, ), 200 try: - request = FlowRequest.from_dict( - data=decrypted_request, raw_encrypted=payload - ) + req = FlowRequest.from_dict(data=decrypted_request, raw_encrypted=payload) except Exception: _logger.exception( "Flow Endpoint ('%s'): Failed to construct FlowRequest from decrypted data: %s", @@ -1326,26 +1353,15 @@ async def __call__(self, payload: dict) -> tuple[str, int]: return "Failed to construct FlowRequest", 500 - if request.has_error and self._error_callback: - callback = self._error_callback - else: - for action, screen, data_filter, callback in self._on_callbacks: - if action == request.action and ( - screen is None or screen == request.screen - ): - if data_filter is None or data_filter(self._wa, request.data): - callback = callback - break - else: - callback = self._main_callback - + callback = await self._get_callback(req) try: - if asyncio.iscoroutinefunction(callback): - response = await callback(self._wa, request) - else: - response = callback(self._wa, request) - if isinstance(response, FlowResponseError): - raise response + res = ( + await callback(self._wa, req) + if asyncio.iscoroutinefunction(callback) + else callback(self._wa, req) + ) + if isinstance(res, FlowResponseError): + raise res except FlowResponseError as e: return ( self._response_encryptor( @@ -1363,10 +1379,10 @@ async def __call__(self, payload: dict) -> tuple[str, int]: ) return "An error occurred", 500 - if self._acknowledge_errors and request.has_error: + if self._acknowledge_errors and req.has_error: return self._response_encryptor( { - "version": request.version, + "version": req.version, "data": { "acknowledged": True, }, @@ -1374,13 +1390,13 @@ async def __call__(self, payload: dict) -> tuple[str, int]: aes_key, iv, ), 200 - if not isinstance(response, (FlowResponse | dict)): + if not isinstance(res, (FlowResponse | dict)): raise TypeError( f"Flow endpoint ('{self._endpoint}') callback ('{callback.__name__}') must return a `FlowResponse`" - f" or `dict`, not {type(response)}" + f" or `dict`, not {type(res)}" ) return self._response_encryptor( - response.to_dict() if isinstance(response, FlowResponse) else response, + res.to_dict() if isinstance(res, FlowResponse) else res, aes_key, iv, ), 200 diff --git a/pywa/server.py b/pywa/server.py index 55ecb1c..9f24825 100644 --- a/pywa/server.py +++ b/pywa/server.py @@ -100,9 +100,6 @@ def __init__( validate_updates: bool, ): self._server = server - if server is utils.MISSING: - return - self._server_type = utils.ServerType.from_app(server) self._verify_token = verify_token self._webhook_endpoint = webhook_endpoint self._private_key = business_private_key @@ -116,6 +113,10 @@ def __init__( self._skip_duplicate_updates = skip_duplicate_updates self._updates_ids_in_process = set[str]() + if server is utils.MISSING: + return + self._server_type = utils.ServerType.from_app(server) + if not verify_token: raise ValueError( "When listening for incoming updates, a verify token must be provided.\n>> The verify token can " @@ -580,9 +581,14 @@ def _register_flow_endpoint_callback( ) -> handlers.FlowRequestCallbackWrapper: """Internal function to register a flow endpoint callback.""" if self._server is None: + raise ValueError( + "When using a custom server, you must use the `get_flow_request_handler` method to get the flow " + "request handler and call it manually." + ) + elif self._server is utils.MISSING: raise ValueError( "You must initialize the WhatsApp client with an web server" - " (Flask or FastAPI) in order to handle incoming flow requests." + f" ({utils.ServerType.protocols_names()}) in order to handle incoming flow requests." ) callback_wrapper = self.get_flow_request_handler( diff --git a/pywa/types/message.py b/pywa/types/message.py index d5616e7..62a6795 100644 --- a/pywa/types/message.py +++ b/pywa/types/message.py @@ -8,7 +8,7 @@ import dataclasses import datetime -from typing import TYPE_CHECKING, Any, Callable, Iterable +from typing import TYPE_CHECKING, Iterable from ..errors import WhatsAppError @@ -32,21 +32,6 @@ from ..client import WhatsApp -_FIELDS_TO_OBJECTS_CONSTRUCTORS: dict[str, Callable[[dict, WhatsApp], Any]] = dict( - text=lambda m, _client: m["body"], - image=Image.from_dict, - video=Video.from_dict, - sticker=Sticker.from_dict, - document=Document.from_dict, - audio=Audio.from_dict, - reaction=Reaction.from_dict, - location=Location.from_dict, - contacts=lambda m, _client: tuple(Contact.from_dict(c) for c in m), - order=Order.from_dict, - system=System.from_dict, -) - - @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) class Message(BaseUserUpdate): """ diff --git a/tests/test_flows.py b/tests/test_flows.py index 1a382e4..90b31ac 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -1,7 +1,9 @@ +import dataclasses import json import pytest +from pywa import WhatsApp, handlers, utils from pywa.types.flows import ( FlowJSON, Screen, @@ -29,8 +31,11 @@ FormRef, ScreenData, FlowResponse, + FlowRequest, + FlowRequestActionType, ) from pywa.utils import Version +from tests import common FLOWS_VERSION = "2.1" @@ -1420,3 +1425,75 @@ def test_flow_response_with_data_sources(): data={"data_source": [DataSource(id="1", title="Example")]}, screen="TEST", ).to_dict()["data"]["data_source"] == [{"id": "1", "title": "Example"}] + + +@pytest.mark.asyncio +async def test_flow_callback_wrapper(): + wa = WhatsApp( + token="xxx", server=None, business_private_key="xxx", verify_token="fdfd" + ) + + def main_handler(_, __): ... + + req = FlowRequest( + version=..., + action=FlowRequestActionType.DATA_EXCHANGE, + flow_token="xyz", + screen="START", + data={}, + raw=..., + raw_encrypted=..., + ) + wrapper = wa.get_flow_request_handler(endpoint="/flow", callback=main_handler) + assert await wrapper._get_callback(req) is main_handler + + def data_exchange_start_screen_callback(_, __): ... + + wrapper.add_handler( + callback=data_exchange_start_screen_callback, + action=FlowRequestActionType.DATA_EXCHANGE, + screen="START", + ) + req = dataclasses.replace(req, screen="START") + assert await wrapper._get_callback(req) is data_exchange_start_screen_callback + + def data_exchange_callback_without_screen(_, __): ... + + wrapper.add_handler( + callback=data_exchange_callback_without_screen, + action=FlowRequestActionType.DATA_EXCHANGE, + screen=None, + ) + assert await wrapper._get_callback(req) is data_exchange_callback_without_screen + + def init_with_data_filter(_, __): ... + + wrapper._on_callbacks.clear() + wrapper.add_handler( + callback=init_with_data_filter, + action=FlowRequestActionType.INIT, + screen=None, + data_filter=lambda _, data: data.get("age") >= 20, + ) + req = dataclasses.replace(req, action=FlowRequestActionType.INIT, data={"age": 20}) + assert await wrapper._get_callback(req) is init_with_data_filter + + +def test_flows_server(): + with pytest.raises(ValueError, match="^When using a custom server.*"): + wa = WhatsApp(token=..., server=None, verify_token=...) + wa.add_flow_request_handler( + handlers.FlowRequestHandler( + callback=..., + endpoint=..., + ) + ) + + with pytest.raises(ValueError, match="^You must initialize the WhatsApp client.*"): + wa = WhatsApp(token=..., server=utils.MISSING) + wa.add_flow_request_handler( + handlers.FlowRequestHandler( + callback=..., + endpoint=..., + ) + )