Skip to content

Commit

Permalink
fix(commands): fix application command checks' typing (#1048)
Browse files Browse the repository at this point in the history
Signed-off-by: lena <[email protected]>
Co-authored-by: shiftinv <[email protected]>
  • Loading branch information
elenakrittik and shiftinv authored Nov 4, 2023
1 parent 25d1d7a commit 59b101f
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 31 deletions.
1 change: 1 addition & 0 deletions changelog/1045.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
|commands| Fix incorrect typings of :meth:`~disnake.ext.commands.InvokableApplicationCommand.add_check`, :meth:`~disnake.ext.commands.InvokableApplicationCommand.remove_check`, :meth:`~disnake.ext.commands.InteractionBotBase.add_app_command_check` and :meth:`~disnake.ext.commands.InteractionBotBase.remove_app_command_check`.
1 change: 1 addition & 0 deletions changelog/1045.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
|commands| Implement :func:`~disnake.ext.commands.app_check` and :func:`~disnake.ext.commands.app_check_any` decorators.
6 changes: 6 additions & 0 deletions disnake/ext/commands/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import TYPE_CHECKING, Any, Callable, Coroutine, TypeVar, Union

if TYPE_CHECKING:
from disnake import ApplicationCommandInteraction

from .cog import Cog
from .context import Context
from .errors import CommandError
Expand All @@ -16,6 +18,10 @@
Check = Union[
Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]
]
AppCheck = Union[
Callable[["Cog", "ApplicationCommandInteraction"], MaybeCoro[bool]],
Callable[["ApplicationCommandInteraction"], MaybeCoro[bool]],
]
Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]]
Error = Union[
Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]],
Expand Down
10 changes: 5 additions & 5 deletions disnake/ext/commands/base_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from disnake.interactions import ApplicationCommandInteraction

from ._types import Check, Coro, Error, Hook
from ._types import AppCheck, Coro, Error, Hook
from .cog import Cog

ApplicationCommandInteractionT = TypeVar(
Expand Down Expand Up @@ -155,7 +155,7 @@ def __init__(self, func: CommandCallback, *, name: Optional[str] = None, **kwarg
except AttributeError:
checks = kwargs.get("checks", [])

self.checks: List[Check] = checks
self.checks: List[AppCheck] = checks

try:
cooldown = func.__commands_cooldown__
Expand Down Expand Up @@ -253,10 +253,10 @@ def default_member_permissions(self) -> Optional[Permissions]:
def callback(self) -> CommandCallback:
return self._callback

def add_check(self, func: Check) -> None:
def add_check(self, func: AppCheck) -> None:
"""Adds a check to the application command.
This is the non-decorator interface to :func:`.check`.
This is the non-decorator interface to :func:`.app_check`.
Parameters
----------
Expand All @@ -265,7 +265,7 @@ def add_check(self, func: Check) -> None:
"""
self.checks.append(func)

def remove_check(self, func: Check) -> None:
def remove_check(self, func: AppCheck) -> None:
"""Removes a check from the application command.
This function is idempotent and will not raise an exception
Expand Down
24 changes: 12 additions & 12 deletions disnake/ext/commands/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,30 +789,30 @@ def _inject(self, bot: AnyBot) -> Self:

# Add application command checks
if cls.bot_slash_command_check is not Cog.bot_slash_command_check:
bot.add_app_command_check(self.bot_slash_command_check, slash_commands=True) # type: ignore
bot.add_app_command_check(self.bot_slash_command_check, slash_commands=True)

if cls.bot_user_command_check is not Cog.bot_user_command_check:
bot.add_app_command_check(self.bot_user_command_check, user_commands=True) # type: ignore
bot.add_app_command_check(self.bot_user_command_check, user_commands=True)

if cls.bot_message_command_check is not Cog.bot_message_command_check:
bot.add_app_command_check(self.bot_message_command_check, message_commands=True) # type: ignore
bot.add_app_command_check(self.bot_message_command_check, message_commands=True)

# Add app command one-off checks
if cls.bot_slash_command_check_once is not Cog.bot_slash_command_check_once:
bot.add_app_command_check(
self.bot_slash_command_check_once, # type: ignore
self.bot_slash_command_check_once,
call_once=True,
slash_commands=True,
)

if cls.bot_user_command_check_once is not Cog.bot_user_command_check_once:
bot.add_app_command_check(
self.bot_user_command_check_once, call_once=True, user_commands=True # type: ignore
self.bot_user_command_check_once, call_once=True, user_commands=True
)

if cls.bot_message_command_check_once is not Cog.bot_message_command_check_once:
bot.add_app_command_check(
self.bot_message_command_check_once, # type: ignore
self.bot_message_command_check_once,
call_once=True,
message_commands=True,
)
Expand Down Expand Up @@ -859,32 +859,32 @@ def _eject(self, bot: AnyBot) -> None:

# Remove application command checks
if cls.bot_slash_command_check is not Cog.bot_slash_command_check:
bot.remove_app_command_check(self.bot_slash_command_check, slash_commands=True) # type: ignore
bot.remove_app_command_check(self.bot_slash_command_check, slash_commands=True)

if cls.bot_user_command_check is not Cog.bot_user_command_check:
bot.remove_app_command_check(self.bot_user_command_check, user_commands=True) # type: ignore
bot.remove_app_command_check(self.bot_user_command_check, user_commands=True)

if cls.bot_message_command_check is not Cog.bot_message_command_check:
bot.remove_app_command_check(self.bot_message_command_check, message_commands=True) # type: ignore
bot.remove_app_command_check(self.bot_message_command_check, message_commands=True)

# Remove app command one-off checks
if cls.bot_slash_command_check_once is not Cog.bot_slash_command_check_once:
bot.remove_app_command_check(
self.bot_slash_command_check_once, # type: ignore
self.bot_slash_command_check_once,
call_once=True,
slash_commands=True,
)

if cls.bot_user_command_check_once is not Cog.bot_user_command_check_once:
bot.remove_app_command_check(
self.bot_user_command_check_once, # type: ignore
self.bot_user_command_check_once,
call_once=True,
user_commands=True,
)

if cls.bot_message_command_check_once is not Cog.bot_message_command_check_once:
bot.remove_app_command_check(
self.bot_message_command_check_once, # type: ignore
self.bot_message_command_check_once,
call_once=True,
message_commands=True,
)
Expand Down
50 changes: 49 additions & 1 deletion disnake/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@

from disnake.message import Message

from ._types import Check, Coro, CoroFunc, Error, Hook
from ._types import AppCheck, Check, Coro, CoroFunc, Error, Hook


__all__ = (
Expand All @@ -81,6 +81,8 @@
"has_any_role",
"check",
"check_any",
"app_check",
"app_check_any",
"before_invoke",
"after_invoke",
"bot_has_role",
Expand Down Expand Up @@ -1695,6 +1697,9 @@ async def extended_check(ctx):
The function returned by ``predicate`` is **always** a coroutine,
even if the original function was not a coroutine.
.. note::
See :func:`.app_check` for this function's application command counterpart.
.. versionchanged:: 1.3
The ``predicate`` attribute was added.
Expand Down Expand Up @@ -1767,6 +1772,9 @@ def check_any(*checks: Check) -> Callable[[T], T]:
The ``predicate`` attribute for this function **is** a coroutine.
.. note::
See :func:`.app_check_any` for this function's application command counterpart.
.. versionadded:: 1.3
Parameters
Expand Down Expand Up @@ -1823,6 +1831,46 @@ async def predicate(ctx: AnyContext) -> bool:
return check(predicate)


def app_check(predicate: AppCheck) -> Callable[[T], T]:
"""Same as :func:`.check`, but for app commands.
.. versionadded:: 2.10
Parameters
----------
predicate: Callable[[:class:`disnake.ApplicationCommandInteraction`], :class:`bool`]
The predicate to check if the command should be invoked.
"""
return check(predicate) # type: ignore # impl is the same, typings are different


def app_check_any(*checks: AppCheck) -> Callable[[T], T]:
"""Same as :func:`.check_any`, but for app commands.
.. note::
See :func:`.check_any` for this function's prefix command counterpart.
.. versionadded:: 2.10
Parameters
----------
*checks: Callable[[:class:`disnake.ApplicationCommandInteraction`], :class:`bool`]
An argument list of checks that have been decorated with
the :func:`app_check` decorator.
Raises
------
TypeError
A check passed has not been decorated with the :func:`app_check`
decorator.
"""
try:
return check_any(*checks) # type: ignore # impl is the same, typings are different
except TypeError as e:
msg = str(e).replace("commands.check", "commands.app_check") # fix err message
raise TypeError(msg) from None


def has_role(item: Union[int, str]) -> Callable[[T], T]:
"""A :func:`.check` that is added that checks if the member invoking the
command has the role specified via the name or ID specified.
Expand Down
15 changes: 9 additions & 6 deletions disnake/ext/commands/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from disnake.threads import Thread
from disnake.types.snowflake import Snowflake, SnowflakeList

from .context import Context
from .context import AnyContext
from .cooldowns import BucketType, Cooldown
from .flag_converter import Flag

Expand Down Expand Up @@ -181,7 +181,8 @@ class BadArgument(UserInputError):


class CheckFailure(CommandError):
"""Exception raised when the predicates in :attr:`.Command.checks` have failed.
"""Exception raised when the predicates in :attr:`.Command.checks` or
:attr:`.InvokableApplicationCommand.checks` have failed.
This inherits from :exc:`CommandError`
"""
Expand All @@ -190,7 +191,7 @@ class CheckFailure(CommandError):


class CheckAnyFailure(CheckFailure):
"""Exception raised when all predicates in :func:`check_any` fail.
"""Exception raised when all predicates in :func:`check_any` or :func:`app_check_any` fail.
This inherits from :exc:`CheckFailure`.
Expand All @@ -200,13 +201,15 @@ class CheckAnyFailure(CheckFailure):
----------
errors: List[:class:`CheckFailure`]
A list of errors that were caught during execution.
checks: List[Callable[[:class:`Context`], :class:`bool`]]
checks: List[Callable[[Union[:class:`Context`, :class:`disnake.ApplicationCommandInteraction`]], :class:`bool`]]
A list of check predicates that failed.
"""

def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None:
def __init__(
self, checks: List[CheckFailure], errors: List[Callable[[AnyContext], bool]]
) -> None:
self.checks: List[CheckFailure] = checks
self.errors: List[Callable[[Context], bool]] = errors
self.errors: List[Callable[[AnyContext], bool]] = errors
super().__init__("You do not have permission to run this command.")


Expand Down
14 changes: 7 additions & 7 deletions disnake/ext/commands/interaction_bot_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
)
from disnake.permissions import Permissions

from ._types import Check, CoroFunc
from ._types import AppCheck, CoroFunc
from .base_core import CogT, CommandCallback, InteractionCommandCallback

P = ParamSpec("P")
Expand Down Expand Up @@ -991,7 +991,7 @@ async def on_message_command_error(

def add_app_command_check(
self,
func: Check,
func: AppCheck,
*,
call_once: bool = False,
slash_commands: bool = False,
Expand All @@ -1000,8 +1000,8 @@ def add_app_command_check(
) -> None:
"""Adds a global application command check to the bot.
This is the non-decorator interface to :meth:`.check`,
:meth:`.check_once`, :meth:`.slash_command_check` and etc.
This is the non-decorator interface to :func:`.app_check`,
:meth:`.slash_command_check` and etc.
You must specify at least one of the bool parameters, otherwise
the check won't be added.
Expand Down Expand Up @@ -1039,7 +1039,7 @@ def add_app_command_check(

def remove_app_command_check(
self,
func: Check,
func: AppCheck,
*,
call_once: bool = False,
slash_commands: bool = False,
Expand All @@ -1060,7 +1060,7 @@ def remove_app_command_check(
The function to remove from the global checks.
call_once: :class:`bool`
Whether the function was added with ``call_once=True`` in
the :meth:`.Bot.add_check` call or using :meth:`.check_once`.
the :meth:`.Bot.add_app_command_check` call.
slash_commands: :class:`bool`
Whether this check was for slash commands.
user_commands: :class:`bool`
Expand Down Expand Up @@ -1179,7 +1179,7 @@ def decorator(
) -> Callable[[ApplicationCommandInteraction], Any]:
# T was used instead of Check to ensure the type matches on return
self.add_app_command_check(
func, # type: ignore
func,
call_once=call_once,
slash_commands=slash_commands,
user_commands=user_commands,
Expand Down
6 changes: 6 additions & 0 deletions docs/ext/commands/api/checks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ Functions
.. autofunction:: check_any(*checks)
:decorator:

.. autofunction:: app_check(predicate)
:decorator:

.. autofunction:: app_check_any(*checks)
:decorator:

.. autofunction:: has_role(item)
:decorator:

Expand Down

0 comments on commit 59b101f

Please sign in to comment.