Skip to content

Commit

Permalink
[handlers] faster flowrequest handler resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
david-lev committed Aug 20, 2024
1 parent acd0a84 commit 66a4e89
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 77 deletions.
130 changes: 73 additions & 57 deletions pywa/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import abc
import asyncio
import collections
import dataclasses
import functools
import inspect
Expand All @@ -61,7 +62,7 @@
FlowResponseError,
FlowRequestActionType,
FlowRequestCannotBeDecrypted,
FlowTokenNoLongerValid,
Screen,
) # noqa

if TYPE_CHECKING:
Expand Down Expand Up @@ -200,10 +201,9 @@ async def _get_factored_update(
),
*handler.filters,
):
if inspect.iscoroutinefunction(f):
if not await f(wa, update):
return
elif not f(wa, update):
if not (
await f(wa, update) if inspect.iscoroutinefunction(f) else f(wa, update)
):
return
except AttributeError as e:
if (
Expand Down Expand Up @@ -244,7 +244,7 @@ def _field_name(self) -> str | None:
def __init__(
self,
callback: Callable[[WhatsApp, Any], Any],
*filters: Callable[[WhatsApp, Any], bool],
*filters: Callable[[WhatsApp, Any], bool | Awaitable[bool]],
priority: int,
):
"""
Expand All @@ -255,17 +255,15 @@ def __init__(
self.priority = priority

async def handle(self, wa: WhatsApp, data: Any) -> bool:
for f in self.filters:
if inspect.iscoroutinefunction(f):
if not await f(wa, data):
return False
elif not f(wa, data):
return False
if not all(
await f(wa, data) if inspect.iscoroutinefunction(f) else f(wa, data)
for f in self.filters
):
return False

if inspect.iscoroutinefunction(self.callback):
await self.callback(wa, data)
else:
self.callback(wa, data)
await self.callback(wa, data) if inspect.iscoroutinefunction(
self.callback
) else self.callback(wa, data)
return True

@staticmethod
Expand Down Expand Up @@ -353,10 +351,9 @@ def __init__(
async def handle(self, wa: WhatsApp, data: Any) -> bool:
update = await _get_factored_update(self, wa, data, self._data_field)
if update is not None:
if inspect.iscoroutinefunction(self.callback):
await self.callback(wa, update)
else:
self.callback(wa, update)
await self.callback(wa, update) if inspect.iscoroutinefunction(
self.callback
) else self.callback(wa, update)
return True
return False

Expand Down Expand Up @@ -1100,14 +1097,17 @@ def __init__(
self._endpoint = endpoint
self._main_callback = callback
self._error_callback: _FlowRequestHandlerT | None = None
self._on_callbacks = list[
tuple[
FlowRequestActionType | str,
str | None,
Callable[["WhatsApp", dict | None], bool] | None,
_FlowRequestHandlerT,
self._on_callbacks: dict[
tuple[FlowRequestActionType | str, str | None],
list[
tuple[
Callable[["WhatsApp", dict | None], bool | Awaitable[bool]] | None,
_FlowRequestHandlerT,
]
],
]() # [(action, screen?, data_filter?, callback), ...]
] = collections.defaultdict(
list
) # {(action, screen?): [(data_filter?, callback)]}
self._acknowledge_errors = acknowledge_errors
self._handle_health_check = handle_health_check
self._private_key = private_key or wa._private_key
Expand Down Expand Up @@ -1148,7 +1148,7 @@ def on(
self,
*,
action: FlowRequestActionType,
screen: str | None = None,
screen: Screen | str | None = None,
data_filter: Callable[[WhatsApp, dict | None], bool] | None = None,
) -> Callable[[_FlowRequestHandlerT], _FlowRequestHandlerT]:
"""
Expand Down Expand Up @@ -1218,7 +1218,7 @@ def add_handler(
*,
callback: _FlowRequestHandlerT,
action: FlowRequestActionType,
screen: str | None = None,
screen: Screen | str | None = None,
data_filter: Callable[[WhatsApp, dict | None], bool] | None = None,
) -> FlowRequestCallbackWrapper:
"""
Expand Down Expand Up @@ -1248,7 +1248,9 @@ def add_handler(
Returns:
The current instance.
"""
self._on_callbacks.append((action, screen, data_filter, callback))
self._on_callbacks[
(action, screen.id if isinstance(screen, Screen) else screen)
].append((data_filter, callback))
return self

def set_errors_handler(
Expand Down Expand Up @@ -1283,7 +1285,34 @@ def set_errors_handler(
self._error_callback = callback
return self

async def _get_callback(self, req: FlowRequest) -> _FlowRequestHandlerT:
"""Resolve the callback to use for the incoming request."""
if req.has_error and self._error_callback:
return self._error_callback
for data_filter, callback in (
*self._on_callbacks[(req.action, None)], # No screen priority
*self._on_callbacks[(req.action, req.screen)],
):
if data_filter is None or (
await data_filter(self._wa, req.data)
if asyncio.iscoroutinefunction(data_filter)
else data_filter(self._wa, req.data)
):
return callback
return self._main_callback

async def __call__(self, payload: dict) -> tuple[str, int]:
"""
Handle the incoming request.
- This method is called automatically by pywa, or manually when using custom server.
Args:
payload: The incoming request payload.
Returns:
A tuple containing the response data and the status code.
"""
try:
decrypted_request, aes_key, iv = self._request_decryptor(
payload["encrypted_flow_data"],
Expand Down Expand Up @@ -1314,9 +1343,7 @@ async def __call__(self, payload: dict) -> tuple[str, int]:
iv,
), 200
try:
request = FlowRequest.from_dict(
data=decrypted_request, raw_encrypted=payload
)
req = FlowRequest.from_dict(data=decrypted_request, raw_encrypted=payload)
except Exception:
_logger.exception(
"Flow Endpoint ('%s'): Failed to construct FlowRequest from decrypted data: %s",
Expand All @@ -1326,26 +1353,15 @@ async def __call__(self, payload: dict) -> tuple[str, int]:

return "Failed to construct FlowRequest", 500

if request.has_error and self._error_callback:
callback = self._error_callback
else:
for action, screen, data_filter, callback in self._on_callbacks:
if action == request.action and (
screen is None or screen == request.screen
):
if data_filter is None or data_filter(self._wa, request.data):
callback = callback
break
else:
callback = self._main_callback

callback = await self._get_callback(req)
try:
if asyncio.iscoroutinefunction(callback):
response = await callback(self._wa, request)
else:
response = callback(self._wa, request)
if isinstance(response, FlowResponseError):
raise response
res = (
await callback(self._wa, req)
if asyncio.iscoroutinefunction(callback)
else callback(self._wa, req)
)
if isinstance(res, FlowResponseError):
raise res
except FlowResponseError as e:
return (
self._response_encryptor(
Expand All @@ -1363,24 +1379,24 @@ async def __call__(self, payload: dict) -> tuple[str, int]:
)
return "An error occurred", 500

if self._acknowledge_errors and request.has_error:
if self._acknowledge_errors and req.has_error:
return self._response_encryptor(
{
"version": request.version,
"version": req.version,
"data": {
"acknowledged": True,
},
},
aes_key,
iv,
), 200
if not isinstance(response, (FlowResponse | dict)):
if not isinstance(res, (FlowResponse | dict)):
raise TypeError(
f"Flow endpoint ('{self._endpoint}') callback ('{callback.__name__}') must return a `FlowResponse`"
f" or `dict`, not {type(response)}"
f" or `dict`, not {type(res)}"
)
return self._response_encryptor(
response.to_dict() if isinstance(response, FlowResponse) else response,
res.to_dict() if isinstance(res, FlowResponse) else res,
aes_key,
iv,
), 200
14 changes: 10 additions & 4 deletions pywa/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,6 @@ def __init__(
validate_updates: bool,
):
self._server = server
if server is utils.MISSING:
return
self._server_type = utils.ServerType.from_app(server)
self._verify_token = verify_token
self._webhook_endpoint = webhook_endpoint
self._private_key = business_private_key
Expand All @@ -116,6 +113,10 @@ def __init__(
self._skip_duplicate_updates = skip_duplicate_updates
self._updates_ids_in_process = set[str]()

if server is utils.MISSING:
return
self._server_type = utils.ServerType.from_app(server)

if not verify_token:
raise ValueError(
"When listening for incoming updates, a verify token must be provided.\n>> The verify token can "
Expand Down Expand Up @@ -580,9 +581,14 @@ def _register_flow_endpoint_callback(
) -> handlers.FlowRequestCallbackWrapper:
"""Internal function to register a flow endpoint callback."""
if self._server is None:
raise ValueError(
"When using a custom server, you must use the `get_flow_request_handler` method to get the flow "
"request handler and call it manually."
)
elif self._server is utils.MISSING:
raise ValueError(
"You must initialize the WhatsApp client with an web server"
" (Flask or FastAPI) in order to handle incoming flow requests."
f" ({utils.ServerType.protocols_names()}) in order to handle incoming flow requests."
)

callback_wrapper = self.get_flow_request_handler(
Expand Down
17 changes: 1 addition & 16 deletions pywa/types/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import dataclasses
import datetime
from typing import TYPE_CHECKING, Any, Callable, Iterable
from typing import TYPE_CHECKING, Iterable

from ..errors import WhatsAppError

Expand All @@ -32,21 +32,6 @@
from ..client import WhatsApp


_FIELDS_TO_OBJECTS_CONSTRUCTORS: dict[str, Callable[[dict, WhatsApp], Any]] = dict(
text=lambda m, _client: m["body"],
image=Image.from_dict,
video=Video.from_dict,
sticker=Sticker.from_dict,
document=Document.from_dict,
audio=Audio.from_dict,
reaction=Reaction.from_dict,
location=Location.from_dict,
contacts=lambda m, _client: tuple(Contact.from_dict(c) for c in m),
order=Order.from_dict,
system=System.from_dict,
)


@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class Message(BaseUserUpdate):
"""
Expand Down
Loading

0 comments on commit 66a4e89

Please sign in to comment.