Skip to content

Commit

Permalink
[handlers] remove duplicate logic
Browse files Browse the repository at this point in the history
  • Loading branch information
david-lev committed May 30, 2024
1 parent 1211f29 commit f8b9737
Showing 1 changed file with 62 additions and 73 deletions.
135 changes: 62 additions & 73 deletions pywa/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def factory_filter(


async def _get_factored_update(
handler: CallbackButtonHandler | CallbackSelectionHandler | MessageStatusHandler,
handler: _FactoryHandler,
wa: WhatsApp,
update: CallbackButton | CallbackSelection | MessageStatus,
field_name: str,
Expand Down Expand Up @@ -319,7 +319,44 @@ def __init__(
super().__init__(callback, *filters)


class CallbackButtonHandler(Handler):
class _FactoryHandler(Handler):
"""Base class for handlers that use a factory to construct the callback data."""

_field_name = "messages"
_data_field: str

def __init__(
self,
callback: Callable[[WhatsApp, Any], Any | Awaitable[Any]],
*filters: Callable[[WhatsApp, Any], bool | Awaitable[bool]],
factory: _CallbackDataFactoryT = str,
factory_before_filters: bool = False,
):
(
self.factory,
self.factory_filter,
) = _resolve_factory(factory, self._data_field)
self.factory_before_filters = factory_before_filters
super().__init__(callback, *filters)

async def handle(self, wa: WhatsApp, data: Any):
update = await _get_factored_update(self, wa, data, self._data_field)
if update is not None:
if inspect.iscoroutinefunction(self.callback):
await self.callback(wa, update)
else:
await wa._loop.run_in_executor(
wa._executor,
self.callback,
wa,
update,
)

def __str__(self) -> str:
return f"{self.__class__.__name__}(callback={self.callback!r}, filters={self.filters!r}, factory={self.factory!r})"


class CallbackButtonHandler(_FactoryHandler):
"""
Handler for callback buttons (User clicks on a :class:`pywa.types.Button`).
Expand All @@ -342,7 +379,7 @@ class CallbackButtonHandler(Handler):
filters will get the callback data after the factory is applied).
"""

_field_name = "messages"
_data_field = "data"

def __init__(
self,
Expand All @@ -351,31 +388,15 @@ def __init__(
factory: _CallbackDataFactoryT = str,
factory_before_filters: bool = False,
):
(
self.factory,
self.factory_filter,
) = _resolve_factory(factory, "data")
self.factory_before_filters = factory_before_filters
super().__init__(callback, *filters)

async def handle(self, wa: WhatsApp, clb: CallbackButton):
update = await _get_factored_update(self, wa, clb, "data")
if update is not None:
if inspect.iscoroutinefunction(self.callback):
await self.callback(wa, update)
else:
await wa._loop.run_in_executor(
wa._executor,
self.callback,
wa,
update,
)

def __str__(self) -> str:
return f"{self.__class__.__name__}(callback={self.callback!r}, filters={self.filters!r}, factory={self.factory!r})"
super().__init__(
callback,
*filters,
factory=factory,
factory_before_filters=factory_before_filters,
)


class CallbackSelectionHandler(Handler):
class CallbackSelectionHandler(_FactoryHandler):
"""
Handler for callback selections (User selects an option from :class:`pywa.types.SectionList`).
Expand All @@ -399,7 +420,7 @@ class CallbackSelectionHandler(Handler):
filters will get the callback data after the factory is applied).
"""

_field_name = "messages"
_data_field = "data"

def __init__(
self,
Expand All @@ -408,31 +429,15 @@ def __init__(
factory: _CallbackDataFactoryT = str,
factory_before_filters: bool = False,
):
(
self.factory,
self.factory_filter,
) = _resolve_factory(factory, "data")
self.factory_before_filters = factory_before_filters
super().__init__(callback, *filters)

async def handle(self, wa: WhatsApp, sel: CallbackSelection):
update = await _get_factored_update(self, wa, sel, "data")
if update is not None:
if inspect.iscoroutinefunction(self.callback):
await self.callback(wa, update)
else:
await wa._loop.run_in_executor(
wa._executor,
self.callback,
wa,
update,
)

def __str__(self) -> str:
return f"{self.__class__.__name__}(callback={self.callback!r}, filters={self.filters!r}, factory={self.factory!r})"
super().__init__(
callback,
*filters,
factory=factory,
factory_before_filters=factory_before_filters,
)


class MessageStatusHandler(Handler):
class MessageStatusHandler(_FactoryHandler):
"""
Handler for :class:`pywa.types.MessageStatus` updates (Message is sent, delivered, read, failed, etc...).
Expand All @@ -458,7 +463,7 @@ class MessageStatusHandler(Handler):
filters will get the tracker data after the factory is applied).
"""

_field_name = "messages"
_data_field = "tracker"

def __init__(
self,
Expand All @@ -467,28 +472,12 @@ def __init__(
factory: _CallbackDataFactoryT = str,
factory_before_filters: bool = False,
):
(
self.factory,
self.factory_filter,
) = _resolve_factory(factory, "tracker")
self.factory_before_filters = factory_before_filters
super().__init__(callback, *filters)

async def handle(self, wa: WhatsApp, status: MessageStatus):
update = await _get_factored_update(self, wa, status, "tracker")
if update is not None:
if inspect.iscoroutinefunction(self.callback):
await self.callback(wa, update)
else:
await wa._loop.run_in_executor(
wa._executor,
self.callback,
wa,
update,
)

def __str__(self) -> str:
return f"{self.__class__.__name__}(callback={self.callback!r}, filters={self.filters!r}, factory={self.factory!r})"
super().__init__(
callback,
*filters,
factory=factory,
factory_before_filters=factory_before_filters,
)


class ChatOpenedHandler(Handler):
Expand Down

0 comments on commit f8b9737

Please sign in to comment.