diff --git a/docs/src/API Reference/models/Naff/hybrid_commands.md b/docs/src/API Reference/models/Naff/hybrid_commands.md new file mode 100644 index 000000000..190b56a94 --- /dev/null +++ b/docs/src/API Reference/models/Naff/hybrid_commands.md @@ -0,0 +1 @@ +::: naff.models.naff.hybrid_commands diff --git a/docs/src/API Reference/models/Naff/index.md b/docs/src/API Reference/models/Naff/index.md index d5f3021b6..f44981453 100644 --- a/docs/src/API Reference/models/Naff/index.md +++ b/docs/src/API Reference/models/Naff/index.md @@ -9,6 +9,7 @@ - [Converters](converters) - [Cooldowns](cooldowns) - [Extension](extension) +- [Hybrid Commands](hybrid_commands) - [Listeners](listener) - [Localisation](localisation) - [Prefixed Commands](prefixed_commands) diff --git a/docs/src/Guides/03 Creating Commands.md b/docs/src/Guides/03 Creating Commands.md index 1607c5ca0..3abc832cb 100644 --- a/docs/src/Guides/03 Creating Commands.md +++ b/docs/src/Guides/03 Creating Commands.md @@ -472,3 +472,22 @@ There also is `on_command` which you can overwrite too. That fires on every inte If your bot is complex enough, you might find yourself wanting to use custom models in your commands. To do this, you'll want to use a string option, and define a converter. Information on how to use converters can be found [on the converter page](/Guides/08 Converters). + +## I Want To Make A Prefixed Command Too + +You're in luck! You can use a hybrid command, which is a slash command that also gets converted to an equivalent prefixed command under the hood. + +To use it, simply replace `@slash_command` with `@hybrid_command`, and `InteractionContext` with `HybridContext`, like so: + +```python +@hybrid_command(name="my_command", description="My hybrid command!") +async def my_command_function(ctx: HybridContext): + await ctx.send("Hello World") +``` + +Suggesting you are using the default mention settings for your bot, you should be able to run this command by `@BotPing my_command`. + +As you can see, the only difference between hybrid commands and slash commands, from a developer perspective, is that they use `HybridContext`, which attempts +to seamlessly allow using the same context for slash and prefixed commands. You can always get the underlying context via `inner_context`, though. + +There are only two limitations with them: they only support one attachment option, and they do not support autocomplete. diff --git a/docs/src/Guides/24 Error Tracking.md b/docs/src/Guides/24 Error Tracking.md new file mode 100644 index 000000000..63e9470b8 --- /dev/null +++ b/docs/src/Guides/24 Error Tracking.md @@ -0,0 +1,30 @@ +# Error Tracking + +So, you've finally got your bot running on a server somewhere. Chances are, you're not checking the console output 24/7, looking for exceptions. + +You're going to want to have some way of tracking if errors occur. + +# The simple and dirty method + +!!! Please don't actually do this. + +The most obvious solution is to think "Well, I'm writing a Discord Bot. Why not send my errors to a discord channel?" + +```python + +@listen() +async def on_error(error): + await bot.get_channel(LOGGING_CHANNEL_ID).send(f"```\n{error.source}\n{error.error}\n```) +``` + +And this is great when debugging. But it consumes your rate limit, can run into the 2000 character message limit, and won't work on shards that don't contain your personal server. It's also very hard to notice patterns and can be noisy. + +# So what should I do instead? + +NAFF contains built-in support for Sentry.io, a cloud error tracking platform. + +To enable it, call `bot.load_extension('naff.ext.sentry', token=SENTRY_TOKEN)` as early as possible in your startup. (Load it before your own extensions, so it can catch intitialization errors in those extensions) + +# What does this do that vanilla Sentry doesn't? + +We add some [tags](https://docs.sentry.io/platforms/python/enriching-events/tags/) and [contexts](https://docs.sentry.io/platforms/python/enriching-events/context/) that might be useful, and filter out some internal-errors that you probably don't want to see. diff --git a/naff/api/events/internal.py b/naff/api/events/internal.py index 4c16aa5e3..a0fec6e02 100644 --- a/naff/api/events/internal.py +++ b/naff/api/events/internal.py @@ -21,11 +21,12 @@ def on_guild_join(event): """ import re -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Callable, Coroutine from naff.client.const import MISSING from naff.models.discord.snowflake import to_snowflake from naff.client.utils.attr_utils import define, field, docs +import naff.models as models __all__ = ( "BaseEvent", @@ -70,6 +71,32 @@ def resolved_name(self) -> str: name = self.override_name or self.__class__.__name__ return _event_reg.sub("_", name).lower() + @classmethod + def listen(cls, coro: Callable[..., Coroutine], client: "Client") -> "models.Listener": + """ + A shortcut for creating a listener for this event + + Args: + coro: The coroutine to call when the event is triggered. + client: The client instance to listen to. + + + ??? Hint "Example Usage:" + ```python + class SomeClass: + def __init__(self, bot: Client): + Ready.listen(self.some_func, bot) + + async def some_func(self, event): + print(f"{event.resolved_name} triggered") + ``` + Returns: + A listener object. + """ + listener = models.Listener.create(cls().resolved_name)(coro) + client.add_listener(listener) + return listener + @define(slots=False, kw_only=False) class GuildEvent: diff --git a/naff/api/events/processors/guild_events.py b/naff/api/events/processors/guild_events.py index 8e8345253..1519ce945 100644 --- a/naff/api/events/processors/guild_events.py +++ b/naff/api/events/processors/guild_events.py @@ -41,7 +41,7 @@ async def _on_raw_guild_create(self, event: "RawGatewayEvent") -> None: if self.fetch_members: # noqa # delays events until chunking has completed - await guild.chunk_guild(presences=True) + await guild.chunk() self.dispatch(events.GuildJoin(guild)) diff --git a/naff/api/http/http_client.py b/naff/api/http/http_client.py index 80defcdc2..408c04abc 100644 --- a/naff/api/http/http_client.py +++ b/naff/api/http/http_client.py @@ -35,7 +35,7 @@ ) from naff.client.errors import DiscordError, Forbidden, GatewayNotFound, HTTPException, NotFound, LoginError from naff.client.utils.input_utils import response_decode, OverriddenJson -from naff.client.utils.serializer import dict_filter_missing +from naff.client.utils.serializer import dict_filter from naff.models import CooldownSystem from naff.models.discord.file import UPLOADABLE_TYPE from .route import Route @@ -212,9 +212,9 @@ def _process_payload(payload: dict | list[dict], files: Absent[list[UPLOADABLE_T return None if isinstance(payload, dict): - payload = dict_filter_missing(payload) + payload = dict_filter(payload) else: - payload = [dict_filter_missing(x) if isinstance(x, dict) else x for x in payload] + payload = [dict_filter(x) if isinstance(x, dict) else x for x in payload] if not files: return payload @@ -262,7 +262,7 @@ async def request( if isinstance(payload, (list, dict)) and not files: kwargs["headers"]["Content-Type"] = "application/json" if isinstance(params, dict): - kwargs["params"] = dict_filter_missing(params) + kwargs["params"] = dict_filter(params) lock = self.get_ratelimit(route) # this gets a BucketLock for this route. diff --git a/naff/api/http/http_requests/guild.py b/naff/api/http/http_requests/guild.py index 692890cd4..9475f3a75 100644 --- a/naff/api/http/http_requests/guild.py +++ b/naff/api/http/http_requests/guild.py @@ -3,7 +3,7 @@ import discord_typings from naff.client.const import Absent, MISSING -from naff.client.utils.serializer import dict_filter_missing, dict_filter_none +from naff.client.utils.serializer import dict_filter, dict_filter_none from ..route import Route @@ -661,7 +661,7 @@ async def create_guild( ) -> dict: return await self.request( Route("POST", "/guilds"), - payload=dict_filter_missing( + payload=dict_filter( { "name": name, "icon": icon, diff --git a/naff/api/http/http_requests/interactions.py b/naff/api/http/http_requests/interactions.py index 6da932571..08a2bcd44 100644 --- a/naff/api/http/http_requests/interactions.py +++ b/naff/api/http/http_requests/interactions.py @@ -35,7 +35,7 @@ async def delete_application_command( ) async def get_application_commands( - self, application_id: "Snowflake_Type", guild_id: "Snowflake_Type" + self, application_id: "Snowflake_Type", guild_id: "Snowflake_Type", with_localisations: bool = True ) -> List[discord_typings.ApplicationCommandData]: """ Get all application commands for this application from discord. @@ -43,14 +43,21 @@ async def get_application_commands( Args: application_id: the what application to query guild_id: specify a guild to get commands from + with_localisations: whether to include all localisations in the response Returns: Application command data """ if guild_id == GLOBAL_SCOPE: - return await self.request(Route("GET", f"/applications/{application_id}/commands")) - return await self.request(Route("GET", f"/applications/{application_id}/guilds/{guild_id}/commands")) + return await self.request( + Route("GET", f"/applications/{application_id}/commands"), + params={"with_localizations": int(with_localisations)}, + ) + return await self.request( + Route("GET", f"/applications/{application_id}/guilds/{guild_id}/commands"), + params={"with_localizations": int(with_localisations)}, + ) async def overwrite_application_commands( self, app_id: "Snowflake_Type", data: List[Dict], guild_id: "Snowflake_Type" = None diff --git a/naff/client/client.py b/naff/client/client.py index e427b8177..292e3a651 100644 --- a/naff/client/client.py +++ b/naff/client/client.py @@ -69,6 +69,7 @@ InteractionCommand, SlashCommand, OptionTypes, + HybridCommand, PrefixedCommand, BaseCommand, to_snowflake, @@ -78,6 +79,7 @@ ModalContext, PrefixedContext, AutocompleteContext, + HybridContext, ComponentCommand, Context, application_commands_to_dict, @@ -95,6 +97,7 @@ from naff.models.naff.active_voice_state import ActiveVoiceState from naff.models.naff.application_commands import ModalCommand from naff.models.naff.auto_defer import AutoDefer +from naff.models.naff.hybrid_commands import _prefixed_from_slash, _base_subcommand_generator from naff.models.naff.listener import Listener from naff.models.naff.tasks import Task @@ -215,6 +218,7 @@ class Client( component_context: Type[ComponentContext]: The object to instantiate for Component Context autocomplete_context: Type[AutocompleteContext]: The object to instantiate for Autocomplete Context modal_context: Type[ModalContext]: The object to instantiate for Modal Context + hybrid_context: Type[HybridContext]: The object to instantiate for Hybrid Context global_pre_run_callback: Callable[..., Coroutine]: A coroutine to run before every command is executed global_post_run_callback: Callable[..., Coroutine]: A coroutine to run after every command is executed @@ -259,6 +263,7 @@ def __init__( owner_ids: Iterable["Snowflake_Type"] = (), modal_context: Type[ModalContext] = ModalContext, prefixed_context: Type[PrefixedContext] = PrefixedContext, + hybrid_context: Type[HybridContext] = HybridContext, send_command_tracebacks: bool = True, shard_id: int = 0, status: Status = Status.ONLINE, @@ -317,6 +322,8 @@ def __init__( """The object to instantiate for Autocomplete Context""" self.modal_context: Type[ModalContext] = modal_context """The object to instantiate for Modal Context""" + self.hybrid_context: Type[HybridContext] = hybrid_context + """The object to instantiate for Hybrid Context""" # flags self._ready = asyncio.Event() @@ -488,6 +495,7 @@ def _sanity_check(self) -> None: self.component_context: ComponentContext, self.autocomplete_context: AutocompleteContext, self.modal_context: ModalContext, + self.hybrid_context: HybridContext, } for obj, expected in contexts.items(): if not issubclass(obj, expected): @@ -729,13 +737,6 @@ async def _on_websocket_ready(self, event: events.RawGatewayEvent) -> None: while True: try: # wait to let guilds cache await asyncio.wait_for(self._guild_event.wait(), self.guild_event_timeout) - if self.fetch_members: - # ensure all guilds have completed chunking - for guild in self.guilds: - if guild and not guild.chunked.is_set(): - logger.debug(f"Waiting for {guild.id} to chunk") - await guild.chunked.wait() - except asyncio.TimeoutError: logger.warning("Timeout waiting for guilds cache: Not all guilds will be in cache") break @@ -1105,6 +1106,43 @@ def add_interaction(self, command: InteractionCommand) -> bool: return True + def add_hybrid_command(self, command: HybridCommand) -> bool: + if self.debug_scope: + command.scopes = [self.debug_scope] + + if command.callback is None: + return False + + if command.is_subcommand: + prefixed_base = self.prefixed_commands.get(str(command.name)) + if not prefixed_base: + prefixed_base = _base_subcommand_generator( + str(command.name), list(command.name.to_locale_dict().values()), str(command.description) + ) + self.add_prefixed_command(prefixed_base) + + if command.group_name: # if this is a group command + _prefixed_cmd = prefixed_base + prefixed_base = prefixed_base.subcommands.get(str(command.group_name)) + + if not prefixed_base: + prefixed_base = _base_subcommand_generator( + str(command.group_name), + list(command.group_name.to_locale_dict().values()), + str(command.group_description), + group=True, + ) + _prefixed_cmd.add_command(prefixed_base) + + new_command = _prefixed_from_slash(command) + new_command._parse_parameters() + prefixed_base.add_command(new_command) + else: + new_command = _prefixed_from_slash(command) + self.add_prefixed_command(new_command) + + return self.add_interaction(command) + def add_prefixed_command(self, command: PrefixedCommand) -> None: """ Add a prefixed command to the client. @@ -1175,6 +1213,8 @@ def process(_cmds) -> None: self.add_modal_callback(func) elif isinstance(func, ComponentCommand): self.add_component_callback(func) + elif isinstance(func, HybridCommand): + self.add_hybrid_command(func) elif isinstance(func, InteractionCommand): self.add_interaction(func) elif ( diff --git a/naff/client/utils/serializer.py b/naff/client/utils/serializer.py index c68325e7c..eb2c499bf 100644 --- a/naff/client/utils/serializer.py +++ b/naff/client/utils/serializer.py @@ -9,7 +9,7 @@ from naff.client.const import MISSING, T from naff.models.discord.file import UPLOADABLE_TYPE, File -__all__ = ("no_export_meta", "export_converter", "to_dict", "dict_filter_none", "dict_filter_missing", "to_image_data") +__all__ = ("no_export_meta", "export_converter", "to_dict", "dict_filter_none", "dict_filter", "to_image_data") no_export_meta = {"no_export": True} @@ -95,9 +95,9 @@ def dict_filter_none(data: dict) -> dict: return {k: v for k, v in data.items() if v is not None} -def dict_filter_missing(data: dict) -> dict: +def dict_filter(data: dict) -> dict: """ - Filters out all values that are MISSING sentinel. + Filters out all values that are MISSING sentinel and converts all sets to lists. Args: data: The dict data to filter. @@ -106,7 +106,13 @@ def dict_filter_missing(data: dict) -> dict: The filtered dict data. """ - return {k: v for k, v in data.items() if v is not MISSING} + filtered = data.copy() + for k, v in data.items(): + if v is MISSING: + filtered.pop(k) + elif isinstance(v, set): + filtered[k] = list(v) + return filtered def to_image_data(imagefile: Optional[UPLOADABLE_TYPE]) -> Optional[str]: diff --git a/naff/ext/sentry.py b/naff/ext/sentry.py new file mode 100644 index 000000000..d426866d8 --- /dev/null +++ b/naff/ext/sentry.py @@ -0,0 +1,82 @@ +""" +Sets up a Sentry Logger + +And then call `bot.load_extension('naff.ext.sentry', token=SENTRY_TOKEN)` +Optionally takes a filter function that will be called before sending the event to Sentry. +""" +import logging +from typing import Any, Callable, Optional + +from naff.api.events.internal import Error +from naff.client.const import logger + +try: + import sentry_sdk +except ModuleNotFoundError: + logger.error("sentry-sdk not installed, cannot enable sentry integration. Install with `pip install naff[sentry]`") + raise + +from naff import Extension, Client, listen + + +__all__ = ("setup", "default_sentry_filter") + + +def default_sentry_filter(event: dict[str, Any], hint: dict[str, Any]) -> Optional[dict[str, Any]]: + if "log_record" in hint: + record: logging.LogRecord = hint["log_record"] + if "naff" in record.name: + # There are some logging messages that are not worth sending to sentry. + if ": 403" in record.message: + return None + if record.message.startswith("Ignoring exception in "): + return None + + if "exc_info" in hint: + exc_type, exc_value, tb = hint["exc_info"] + if isinstance(exc_value, KeyboardInterrupt): + # We don't need to report a ctrl+c + return None + return event + + +class SentryExtension(Extension): + @listen() + async def on_startup(self) -> None: + sentry_sdk.set_context( + "bot", + { + "name": str(self.bot.user), + "intents": repr(self.bot.intents), + }, + ) + sentry_sdk.set_tag("bot_name", str(self.bot.user)) + + @listen() + async def on_error(self, event: Error) -> None: + with sentry_sdk.configure_scope() as scope: + scope.set_tag("source", event.source) + if event.ctx: + scope.set_context( + type(event.ctx).__name__, + { + "args": event.ctx.args, + "kwargs": event.ctx.kwargs, + "message": event.ctx.message, + }, + ) + sentry_sdk.capture_exception(event.error) + + +def setup( + bot: Client, + token: str = None, + filter: Optional[Callable[[dict[str, Any], dict[str, Any]], Optional[dict[str, Any]]]] = None, +) -> None: + if not token: + logger.error("Cannot enable sentry integration, no token provided") + return + if filter is None: + filter = default_sentry_filter + sentry_sdk.init(token, before_send=filter) + SentryExtension(bot) diff --git a/naff/models/discord/auto_mod.py b/naff/models/discord/auto_mod.py index 2577ca567..e9d2973bc 100644 --- a/naff/models/discord/auto_mod.py +++ b/naff/models/discord/auto_mod.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Union, Any, Optional +from typing import Any, Optional, TYPE_CHECKING import attrs @@ -286,9 +286,9 @@ class AutoModerationAction(ClientObject): matched_content: Optional[str] = field(default=None) content: Optional[str] = field(default=None) - _message_id: Union["Snowflake_Type", None] = field(default=None) - _alert_system_message_id: "Snowflake_Type" = field() - _channel_id: "Snowflake_Type" = field() + _message_id: Optional["Snowflake_Type"] = field(default=None) + _alert_system_message_id: Optional["Snowflake_Type"] = field(default=None) + _channel_id: Optional["Snowflake_Type"] = field(default=None) _guild_id: "Snowflake_Type" = field() @classmethod @@ -302,11 +302,11 @@ def guild(self) -> "Guild": return self._client.get_guild(self._guild_id) @property - def channel(self) -> "GuildText": + def channel(self) -> "Optional[GuildText]": return self._client.get_channel(self._channel_id) @property - def message(self) -> "Message": + def message(self) -> "Optional[Message]": return self._client.cache.get_message(self._channel_id, self._message_id) diff --git a/naff/models/discord/guild.py b/naff/models/discord/guild.py index 8e8d9e23d..8c2e8f8b2 100644 --- a/naff/models/discord/guild.py +++ b/naff/models/discord/guild.py @@ -3,7 +3,7 @@ from collections import namedtuple from functools import cmp_to_key from typing import List, Optional, Union, Set, Dict, Any, TYPE_CHECKING - +from warnings import warn import naff.models as models from naff.client.const import MISSING, PREMIUM_GUILD_LIMITS, logger, Absent @@ -498,20 +498,71 @@ async def edit_nickname(self, new_nickname: Absent[str] = MISSING, reason: Absen """ await self.me.edit_nickname(new_nickname, reason=reason) - async def chunk_guild(self, wait=True, presences=False) -> None: + async def http_chunk(self) -> None: + """Populates all members of this guild using the REST API.""" + start_time = time.perf_counter() + members = [] + + # request all guild members + after = MISSING + while True: + if members: + after = members[-1]["user"]["id"] + rcv: list = await self._client.http.list_members(self.id, limit=1000, after=after) + members.extend(rcv) + if len(rcv) < 1000: + # we're done + break + + # process all members + s = time.monotonic() + for member in members: + self._client.cache.place_member_data(self.id, member) + if (time.monotonic() - s) > 0.05: + # look, i get this *could* be a thread, but because it needs to modify data in the main thread, + # it is still blocking. So by periodically yielding to the event loop, we can avoid blocking, and still + # process this data properly + await asyncio.sleep(0) + s = time.monotonic() + self.chunked.set() + logger.info(f"Cached {len(members)} members for {self.id} in {time.perf_counter() - start_time:.2f} seconds") + + async def gateway_chunk(self, wait=True, presences=True) -> None: """ Trigger a gateway `get_members` event, populating this object with members. Args: wait: Wait for chunking to be completed before continuing presences: Do you need presence data for members? - """ ws = self._client.get_guild_websocket(self.id) await ws.request_member_chunks(self.id, limit=0, presences=presences) if wait: await self.chunked.wait() + async def chunk(self) -> None: + """Populates all members of this guild using the REST API.""" + await self.http_chunk() + + async def chunk_guild(self, wait=True, presences=True) -> None: + """ + Trigger a gateway `get_members` event, populating this object with members. + + !!! warning "Depreciation Warning" + Gateway chunking is deprecated and replaced by http chunking. Use `guild.gateway_chunk` if you need gateway chunking. + + Args: + wait: Wait for chunking to be completed before continuing + presences: Do you need presence data for members? + + """ + warn( + "Gateway chunking is deprecated and replaced by http chunking. Use `guild.gateway_chunk` if you need gateway chunking.", + DeprecationWarning, + stacklevel=2, + ) + await self.gateway_chunk(wait=wait, presences=presences) + async def process_member_chunk(self, chunk: dict) -> None: """ Receive and either cache or process the chunks of members from gateway. diff --git a/naff/models/discord/role.py b/naff/models/discord/role.py index 0aad7b0e4..8349571a5 100644 --- a/naff/models/discord/role.py +++ b/naff/models/discord/role.py @@ -6,7 +6,7 @@ from naff.client.const import MISSING, Absent, T from naff.client.utils.attr_utils import define, field from naff.client.utils.attr_converters import optional as optional_c -from naff.client.utils.serializer import dict_filter_missing +from naff.client.utils.serializer import dict_filter from naff.models.discord.asset import Asset from naff.models.discord.emoji import PartialEmoji from naff.models.discord.color import Color @@ -190,7 +190,7 @@ async def edit( if isinstance(color, Color): color = color.value - payload = dict_filter_missing( + payload = dict_filter( {"name": name, "permissions": permissions, "color": color, "hoist": hoist, "mentionable": mentionable} ) diff --git a/naff/models/naff/__init__.py b/naff/models/naff/__init__.py index 8c0ed8907..71af4f0a7 100644 --- a/naff/models/naff/__init__.py +++ b/naff/models/naff/__init__.py @@ -12,5 +12,6 @@ from .prefixed_commands import * from .protocols import * from .extension import * +from .hybrid_commands import * from .wait import * from .tasks import * diff --git a/naff/models/naff/application_commands.py b/naff/models/naff/application_commands.py index aa674aad3..9a54f51d5 100644 --- a/naff/models/naff/application_commands.py +++ b/naff/models/naff/application_commands.py @@ -1051,8 +1051,8 @@ def _compare_commands(local_cmd: dict, remote_cmd: dict) -> bool: "description": ("description", ""), "default_member_permissions": ("default_member_permissions", None), "dm_permission": ("dm_permission", True), - "name_localized": ("name_localizations", {}), - "description_localized": ("description_localizations", {}), + "name_localized": ("name_localizations", None), + "description_localized": ("description_localizations", None), } for local_name, comparison_data in lookup.items(): @@ -1068,12 +1068,15 @@ def _compare_options(local_opt_list: dict, remote_opt_list: dict) -> bool: "description": ("description", ""), "required": ("required", False), "autocomplete": ("autocomplete", False), - "name_localized": ("name_localizations", {}), - "description_localized": ("description_localizations", {}), + "name_localized": ("name_localizations", None), + "description_localized": ("description_localizations", None), "choices": ("choices", []), "max_value": ("max_value", None), "min_value": ("min_value", None), } + post_process: Dict[str, Callable] = { + "choices": lambda l: [d | {"name_localizations": {}} if len(d) == 2 else d for d in l], + } if local_opt_list != remote_opt_list: if len(local_opt_list) != len(remote_opt_list): @@ -1084,14 +1087,16 @@ def _compare_options(local_opt_list: dict, remote_opt_list: dict) -> bool: if local_option["type"] == remote_option["type"]: if local_option["type"] in (OptionTypes.SUB_COMMAND_GROUP, OptionTypes.SUB_COMMAND): - if not _compare_commands(local_option, remote_option) or _compare_options( + if not _compare_commands(local_option, remote_option) or not _compare_options( local_option.get("options", []), remote_option.get("options", []) ): return False else: for local_name, comparison_data in options_lookup.items(): remote_name, default_value = comparison_data - if local_option.get(local_name, default_value) != remote_option.get(remote_name, default_value): + if local_option.get(local_name, default_value) != post_process.get(remote_name, lambda l: l)( + remote_option.get(remote_name, default_value) + ): return False else: diff --git a/naff/models/naff/command.py b/naff/models/naff/command.py index 31f5ef887..99ba87478 100644 --- a/naff/models/naff/command.py +++ b/naff/models/naff/command.py @@ -122,7 +122,7 @@ async def __call__(self, context: "Context", *args, **kwargs) -> None: if self.error_callback: await self.error_callback(e, context, *args, **kwargs) elif self.extension and self.extension.extension_error: - await self.extension.extension_error(context, *args, **kwargs) + await self.extension.extension_error(e, context, *args, **kwargs) else: raise finally: diff --git a/naff/models/naff/context.py b/naff/models/naff/context.py index dc2d6c04e..7fa012fb9 100644 --- a/naff/models/naff/context.py +++ b/naff/models/naff/context.py @@ -44,6 +44,7 @@ "AutocompleteContext", "ModalContext", "PrefixedContext", + "HybridContext", "SendableContext", ) @@ -664,6 +665,203 @@ async def _send_http_request( return await self._client.http.create_message(message_payload, self.channel.id, files=files) +@define +class HybridContext(Context): + """ + Represents the context for hybrid commands, a slash command that can also be used as a prefixed command. + + This attempts to create a compatibility layer to allow contexts for an interaction or a message to be used seamlessly. + """ + + deferred: bool = field(default=False, metadata=docs("Is this context deferred?")) + responded: bool = field(default=False, metadata=docs("Have we responded to this?")) + app_permissions: Permissions = field( + default=0, converter=Permissions, metadata=docs("The permissions this context has") + ) + + _interaction_context: Optional[InteractionContext] = field(default=None) + _prefixed_context: Optional[PrefixedContext] = field(default=None) + + @classmethod + def from_interaction_context(cls, context: InteractionContext) -> "HybridContext": + return cls( + client=context._client, # type: ignore + interaction_context=context, # type: ignore + invoke_target=context.invoke_target, + command=context.command, + args=context.args, + kwargs=context.kwargs, + author=context.author, + channel=context.channel, + guild_id=context.guild_id, + deferred=context.deferred, + responded=context.responded, + app_permissions=context.app_permissions, + ) + + @classmethod + def from_prefixed_context(cls, context: PrefixedContext) -> "HybridContext": + # this is a "best guess" on what the permissions are + # this may or may not be totally accurate + if hasattr(context.channel, "permissions_for"): + app_permissions = context.channel.permissions_for(context.guild.me) # type: ignore + elif context.channel.type in {10, 11, 12}: # it's a thread + app_permissions = context.channel.parent_channel.permissions_for(context.guild.me) # type: ignore + else: + # this is what happens with interaction contexts in dms + app_permissions = 0 + + return cls( + client=context._client, # type: ignore + prefixed_context=context, # type: ignore + invoke_target=context.invoke_target, + command=context.command, + args=context.args, + kwargs=context.kwargs, # this is usually empty + author=context.author, + channel=context.channel, + guild_id=context.guild_id, + message=context.message, + app_permissions=app_permissions, + ) + + @property + def inner_context(self) -> InteractionContext | PrefixedContext: + """ + Returns the context powering the current hybrid context. + + This can be used for scope-specific actions, like sending modals in an interaction. + """ + return self._interaction_context or self._prefixed_context # type: ignore + + @property + def ephemeral(self) -> bool: + """Returns if responses to this interaction are ephemeral, if this is an interaction. Otherwise, returns False.""" + return self._interaction_context.ephemeral if self._interaction_context else False + + @property + def expires_at(self) -> Optional[Timestamp]: + """The timestamp the context is expected to expire at, or None if the context never expires.""" + if not self._interaction_context: + return None + + if self.responded: + return Timestamp.from_snowflake(self._interaction_context.interaction_id) + datetime.timedelta(minutes=15) + return Timestamp.from_snowflake(self._interaction_context.interaction_id) + datetime.timedelta(seconds=3) + + @property + def expired(self) -> bool: + """Has the context expired yet?""" + return Timestamp.utcnow() >= self.expires_at if self.expires_at else False + + @property + def invoked_name(self) -> str: + return ( + self.command.get_localised_name(self._interaction_context.locale) + if self._interaction_context + else self.invoke_target + ) + + async def defer(self, ephemeral: bool = False) -> None: + """ + Either defers the response (if used in an interaction) or triggers a typing indicator for 10 seconds (if used for messages). + + Args: + ephemeral: Should the response be ephemeral? Only applies to responses for interactions. + + """ + if self._interaction_context: + await self._interaction_context.defer(ephemeral=ephemeral) + else: + await self.channel.trigger_typing() + + self.deferred = True + + async def reply( + self, + content: Optional[str] = None, + embeds: Optional[Union[List[Union["Embed", dict]], Union["Embed", dict]]] = None, + embed: Optional[Union["Embed", dict]] = None, + **kwargs, + ) -> "Message": + """ + Reply to this message, takes all the same attributes as `send`. + + For interactions, this functions the same as `send`. + """ + kwargs = locals() + kwargs.pop("self") + extra_kwargs = kwargs.pop("kwargs") + kwargs |= extra_kwargs + + if self._interaction_context: + result = await self._interaction_context.send(**kwargs) + else: + kwargs.pop("ephemeral", None) + result = await self._prefixed_context.reply(**kwargs) # type: ignore + + self.responded = True + return result + + async def send( + self, + content: Optional[str] = None, + embeds: Optional[Union[List[Union["Embed", dict]], Union["Embed", dict]]] = None, + embed: Optional[Union["Embed", dict]] = None, + components: Optional[ + Union[ + List[List[Union["BaseComponent", dict]]], + List[Union["BaseComponent", dict]], + "BaseComponent", + dict, + ] + ] = None, + stickers: Optional[Union[List[Union["Sticker", "Snowflake_Type"]], "Sticker", "Snowflake_Type"]] = None, + allowed_mentions: Optional[Union["AllowedMentions", dict]] = None, + reply_to: Optional[Union["MessageReference", "Message", dict, "Snowflake_Type"]] = None, + file: Optional[Union["File", "IOBase", "Path", str]] = None, + tts: bool = False, + flags: Optional[Union[int, "MessageFlags"]] = None, + ephemeral: bool = False, + **kwargs, + ) -> "Message": + """ + Send a message. + + Args: + content: Message text content. + embeds: Embedded rich content (up to 6000 characters). + embed: Embedded rich content (up to 6000 characters). + components: The components to include with the message. + stickers: IDs of up to 3 stickers in the server to send in the message. + allowed_mentions: Allowed mentions for the message. + reply_to: Message to reference, must be from the same channel. + files: Files to send, the path, bytes or File() instance, defaults to None. You may have up to 10 files. + file: Files to send, the path, bytes or File() instance, defaults to None. You may have up to 10 files. + tts: Should this message use Text To Speech. + suppress_embeds: Should embeds be suppressed on this send + flags: Message flags to apply. + ephemeral: Should this message be sent as ephemeral (hidden) - only works with interactions. + + Returns: + New message object that was sent. + + """ + kwargs = locals() + kwargs.pop("self") + extra_kwargs = kwargs.pop("kwargs") + kwargs |= extra_kwargs + + if self._interaction_context: + result = await self._interaction_context.send(**kwargs) + else: + kwargs.pop("ephemeral", None) + result = await self._prefixed_context.send(**kwargs) # type: ignore + + self.responded = True + return result + + @runtime_checkable class SendableContext(Protocol): """ diff --git a/naff/models/naff/extension.py b/naff/models/naff/extension.py index 3d5efa433..f83dc8201 100644 --- a/naff/models/naff/extension.py +++ b/naff/models/naff/extension.py @@ -86,6 +86,8 @@ def __new__(cls, bot: "Client", *args, **kwargs) -> "Extension": bot.add_modal_callback(val) elif isinstance(val, naff.ComponentCommand): bot.add_component_callback(val) + elif isinstance(val, naff.HybridCommand): + bot.add_hybrid_command(val) elif isinstance(val, naff.InteractionCommand): bot.add_interaction(val) else: @@ -143,6 +145,49 @@ def drop(self) -> None: for scope in func.scopes: if self.bot.interactions.get(scope): self.bot.interactions[scope].pop(func.resolved_name, []) + + if isinstance(func, naff.HybridCommand): + # here's where things get complicated - we need to unload the prefixed command + # by necessity, there's a lot of logic here to determine what needs to be unloaded + if not func.callback: # not like it was added + return + + if func.is_subcommand: + prefixed_base = self.bot.prefixed_commands.get(str(func.name)) + _base_cmd = prefixed_base + if not prefixed_base: + # if something weird happened here, here's a safeguard + continue + + if func.group_name: + prefixed_base = prefixed_base.subcommands.get(str(func.group_name)) + if not prefixed_base: + continue + + prefixed_base.remove_command(str(func.sub_cmd_name)) + + if not prefixed_base.subcommands: + # the base cmd is now empty, delete it + if func.group_name: + _base_cmd.remove_command(str(func.group_name)) # type: ignore + + # and now the base command is empty + if not _base_cmd.subcommands: # type: ignore + # in case you're curious, i did try to put the below behavior + # in a function here, but then it turns out a weird python + # bug can happen if i did that + if cmd := self.bot.prefixed_commands.pop(str(func.name), None): + for alias in cmd.aliases: + self.bot.prefixed_commands.pop(alias, None) + + elif cmd := self.bot.prefixed_commands.pop(str(func.name), None): + for alias in cmd.aliases: + self.bot.prefixed_commands.pop(alias, None) + + elif cmd := self.bot.prefixed_commands.pop(str(func.name), None): + for alias in cmd.aliases: + self.bot.prefixed_commands.pop(alias, None) + elif isinstance(func, naff.PrefixedCommand): if not func.is_subcommand: self.bot.prefixed_commands.pop(func.name, None) diff --git a/naff/models/naff/hybrid_commands.py b/naff/models/naff/hybrid_commands.py new file mode 100644 index 000000000..cc2b5ecb1 --- /dev/null +++ b/naff/models/naff/hybrid_commands.py @@ -0,0 +1,565 @@ +import inspect +import functools +import asyncio + +from typing import Any, Callable, Coroutine, TYPE_CHECKING, Optional, TypeGuard + + +from naff.client.const import Absent, GLOBAL_SCOPE, MISSING, T +from naff.client.errors import BadArgument +from naff.client.utils.attr_utils import define, field +from naff.client.utils.misc_utils import get_object_name, maybe_coroutine +from naff.models.naff.command import BaseCommand +from naff.models.naff.application_commands import ( + SlashCommand, + LocalisedName, + LocalisedDesc, + SlashCommandOption, + SlashCommandChoice, + OptionTypes, +) +from naff.models.naff.prefixed_commands import _convert_to_bool, PrefixedCommand +from naff.models.naff.protocols import Converter +from naff.models.naff.converters import ( + _LiteralConverter, + NoArgumentConverter, + MemberConverter, + UserConverter, + RoleConverter, + BaseChannelConverter, +) +from naff.models.naff.context import HybridContext, InteractionContext, PrefixedContext + +if TYPE_CHECKING: + from naff.models.naff.checks import TYPE_CHECK_FUNCTION + from naff.models.discord.channel import BaseChannel + from naff.models.discord.message import Attachment + from naff.models.discord.enums import Permissions, ChannelTypes + from naff.models.discord.snowflake import Snowflake_Type + +__all__ = ("HybridCommand", "hybrid_command", "hybrid_subcommand") + +_get_converter_function = BaseCommand._get_converter_function + + +def _check_if_annotation(param_annotation: Any, type_: type[T]) -> TypeGuard[T]: + return ( + isinstance(param_annotation, type_) or inspect.isclass(param_annotation) and issubclass(param_annotation, type_) + ) + + +def _create_subcmd_func(group: bool = False) -> Callable: + async def _subcommand_base(*args, **kwargs) -> None: + if group: + raise BadArgument("Cannot run this subcommand group without a valid subcommand.") + else: + raise BadArgument("Cannot run this command without a valid subcommand.") + + return _subcommand_base + + +def _generate_permission_check(permissions: "Permissions") -> "TYPE_CHECK_FUNCTION": + async def _permission_check(ctx: HybridContext) -> bool: + return ctx.author.has_permission(*permissions) if ctx.guild_id else True # type: ignore + + return _permission_check # type: ignore + + +def _generate_scope_check(_scopes: list["Snowflake_Type"]) -> "TYPE_CHECK_FUNCTION": + scopes = frozenset(int(s) for s in _scopes) + + async def _scope_check(ctx: HybridContext) -> bool: + return int(ctx.guild_id) in scopes + + return _scope_check # type: ignore + + +async def _guild_check(ctx: HybridContext) -> bool: + return bool(ctx.guild_id) + + +def _match_option_type(option_type: int) -> Callable[[HybridContext, Any], Any]: + if option_type == 3: + return lambda ctx, arg: str(arg) + if option_type == 4: + return lambda ctx, arg: int(arg) + if option_type == 5: + return lambda ctx, arg: _convert_to_bool(arg) + if option_type == 6: + return _get_converter_function(_UnionConverter(MemberConverter, UserConverter), "") + if option_type == 7: + return _get_converter_function(BaseChannelConverter, "") + if option_type == 8: + return _get_converter_function(RoleConverter, "") + if option_type == 9: + return _get_converter_function(_UnionConverter(MemberConverter, UserConverter, RoleConverter), "") + if option_type == 10: + return lambda ctx, arg: float(arg) + if option_type == 11: + return _BasicAttachmentConverter # type: ignore + + raise ValueError(f"{option_type} is an unsupported option type right now.") + + +class _UnionConverter(Converter): + def __init__(self, *converters: type[Converter]) -> None: + self._converters = converters + + async def convert(self, ctx: HybridContext, arg: Any) -> Any: + for converter in self._converters: + try: + return await converter().convert(ctx, arg) + except Exception: # noqa + continue + + union_names = tuple(get_object_name(t).removesuffix("Converter") for t in self._converters) + union_types_str = ", ".join(union_names[:-1]) + f", or {union_names[-1]}" + raise BadArgument(f'Could not convert "{arg}" into {union_types_str}.') + + +class _BasicAttachmentConverter(NoArgumentConverter): + def convert(self, ctx: HybridContext, _: Any) -> "Attachment": + try: + return ctx.message.attachments[0] + except (AttributeError, IndexError): + raise BadArgument("No attachment found.") from None + + +class _ChoicesConverter(_LiteralConverter): + values: dict + choice_values: dict + + def __init__(self, choices: list[SlashCommandChoice | dict]) -> None: + standardized_choices = tuple((SlashCommandChoice(**o) if isinstance(o, dict) else o) for o in choices) + + names = tuple(c.name for c in standardized_choices) + self.values = {str(arg): str for arg in names} + self.choice_values = {str(c.name): c.value for c in standardized_choices} + + async def convert(self, ctx: HybridContext, argument: str) -> Any: + val = await super().convert(ctx, argument) + return self.choice_values[val] + + +class _StringLengthConverter(Converter[str]): + def __init__(self, min_length: Optional[int], max_length: Optional[int]) -> None: + self.min_length = min_length + self.max_length = max_length + + async def convert(self, ctx: HybridContext, argument: str) -> str: + if self.min_length and len(argument) < self.min_length: + raise BadArgument(f'The string "{argument}" is shorter than {self.min_length} character(s).') + elif self.max_length and len(argument) > self.max_length: + raise BadArgument(f'The string "{argument}" is longer than {self.max_length} character(s).') + + return argument + + +class _RangeConverter(Converter[float | int]): + def __init__( + self, + number_convert: Callable[[HybridContext, Any], Any], + number_type: int, + min_value: Optional[float | int], + max_value: Optional[float | int], + ) -> None: + self.number_convert = number_convert + self.number_type = number_type + self.min_value = min_value + self.max_value = max_value + + async def convert(self, ctx: HybridContext, argument: str) -> float | int: + try: + converted: float | int = await maybe_coroutine(self.number_convert, ctx, argument) + + if self.min_value and converted < self.min_value: + raise BadArgument(f'Value "{argument}" is less than {self.min_value}.') + if self.max_value and converted > self.max_value: + raise BadArgument(f'Value "{argument}" is greater than {self.max_value}.') + + return converted + except ValueError: + type_name = "number" if self.number_type == OptionTypes.NUMBER else "integer" + + if type_name.startswith("i"): + raise BadArgument(f'Argument "{argument}" is not an {type_name}.') from None + else: + raise BadArgument(f'Argument "{argument}" is not a {type_name}.') from None + except BadArgument: + raise + + +class _NarrowedChannelConverter(BaseChannelConverter): + def __init__(self, channel_types: "list[ChannelTypes | int]") -> None: + self.channel_types = channel_types + + async def convert(self, ctx: HybridContext, argument: str) -> "BaseChannel": + channel = await super().convert(ctx, argument) + if channel.type not in self.channel_types: + raise BadArgument(f'Channel "{channel.mention}" is not an allowed channel type.') + return channel + + +class _StackedConverter(Converter): + def __init__( + self, + ori_converter_func: Callable[[HybridContext, Any], Any], + additional_converter_func: Callable[[HybridContext, Any], Any], + ) -> None: + self._ori_converter_func = ori_converter_func + self._additional_converter_func = additional_converter_func + + async def convert(self, ctx: HybridContext, argument: Any) -> Any: + part_one = await maybe_coroutine(self._ori_converter_func, ctx, argument) + return await maybe_coroutine(self._additional_converter_func, ctx, part_one) + + +class _StackedNoArgConverter(NoArgumentConverter): + def __init__( + self, + ori_converter_func: Callable[[HybridContext, Any], Any], + additional_converter_func: Callable[[HybridContext, Any], Any], + ) -> None: + self._ori_converter_func = ori_converter_func + self._additional_converter_func = additional_converter_func + + async def convert(self, ctx: HybridContext, _: Any) -> Any: + part_one = await maybe_coroutine(self._ori_converter_func, ctx, _) + return await maybe_coroutine(self._additional_converter_func, ctx, part_one) + + +@define() +class HybridCommand(SlashCommand): + """A subclass of SlashCommand that handles the logic for hybrid commands.""" + + async def __call__(self, context: InteractionContext, *args, **kwargs) -> None: + new_ctx = context.bot.hybrid_context.from_interaction_context(context) + return await super().__call__(new_ctx, *args, **kwargs) + + def group(self, name: str = None, description: str = "No Description Set") -> "HybridCommand": + return HybridCommand( + name=self.name, + description=self.description, + group_name=name, + group_description=description, + scopes=self.scopes, + ) + + def subcommand( + self, + sub_cmd_name: LocalisedName | str, + group_name: LocalisedName | str = None, + sub_cmd_description: Absent[LocalisedDesc | str] = MISSING, + group_description: Absent[LocalisedDesc | str] = MISSING, + options: list[SlashCommandOption | dict] = None, + nsfw: bool = False, + ) -> Callable[..., "HybridCommand"]: + def wrapper(call: Callable[..., Coroutine]) -> "HybridCommand": + nonlocal sub_cmd_description + + if not asyncio.iscoroutinefunction(call): + raise TypeError("Subcommand must be coroutine") + + if sub_cmd_description is MISSING: + sub_cmd_description = call.__doc__ or "No Description Set" + + return HybridCommand( + name=self.name, + description=self.description, + group_name=group_name or self.group_name, + group_description=group_description or self.group_description, + sub_cmd_name=sub_cmd_name, + sub_cmd_description=sub_cmd_description, + default_member_permissions=self.default_member_permissions, + dm_permission=self.dm_permission, + options=options, + callback=call, + scopes=self.scopes, + nsfw=nsfw, + ) + + return wrapper + + +@define() +class _HybridPrefixedCommand(PrefixedCommand): + _uses_subcommand_func: bool = field(default=False) + + async def __call__(self, context: PrefixedContext, *args, **kwargs) -> None: + new_ctx = context.bot.hybrid_context.from_prefixed_context(context) + return await super().__call__(new_ctx, *args, **kwargs) + + def add_command(self, cmd: "_HybridPrefixedCommand") -> None: + super().add_command(cmd) + + if not self._uses_subcommand_func: + self.callback = _create_subcmd_func(group=self.is_subcommand) + self.parameters = [] + self.ignore_extra = False + self._inspect_signature = inspect.Signature(None) + self._uses_subcommand_func = True + + +def _base_subcommand_generator( + name: str, aliases: list[str], description: str, group: bool = False +) -> _HybridPrefixedCommand: + return _HybridPrefixedCommand( + callback=_create_subcmd_func(group=group), + name=name, + aliases=aliases, + help=description, + ignore_extra=False, + inspect_signature=inspect.Signature(None), # type: ignore + ) + + +def _prefixed_from_slash(cmd: SlashCommand) -> _HybridPrefixedCommand: + new_parameters: list[inspect.Parameter] = [] + + if cmd.options: + if cmd.has_binding: + old_func = functools.partial(cmd.callback, None, None) + else: + old_func = functools.partial(cmd.callback, None) + + old_params = dict(inspect.signature(old_func).parameters) + attachment_option = False + + standardized_options = ((SlashCommandOption(**o) if isinstance(o, dict) else o) for o in cmd.options) + for option in standardized_options: + annotation = _match_option_type(option.type) + + if option.type == OptionTypes.ATTACHMENT: + if attachment_option: + raise ValueError("Cannot have multiple attachment options.") + else: + attachment_option = True + + if option.autocomplete: + # there isn't much we can do here + raise ValueError("Cannot use autocomplete in hybrid commands.") + + if option.type in {OptionTypes.STRING, OptionTypes.INTEGER, OptionTypes.NUMBER} and option.choices: + annotation = _ChoicesConverter(option.choices).convert + elif option.type in {OptionTypes.INTEGER, OptionTypes.NUMBER} and ( + option.min_value is not None or option.max_value is not None + ): + annotation = _RangeConverter(annotation, option.type, option.min_value, option.max_value).convert + elif option.type == OptionTypes.STRING and (option.min_length is not None or option.max_length is not None): + annotation = _StringLengthConverter(option.min_length, option.max_length).convert + elif option.type == OptionTypes.CHANNEL and option.channel_types: + annotation = _NarrowedChannelConverter(option.channel_types).convert + + if ori_param := old_params.pop(str(option.name), None): + if ori_param.annotation != inspect._empty and _check_if_annotation(ori_param.annotation, Converter): + if option.type != OptionTypes.ATTACHMENT: + annotation = _StackedConverter( + annotation, _get_converter_function(ori_param.annotation, str(option.name)) # type: ignore + ) + else: + annotation = _StackedNoArgConverter( + _get_converter_function(annotation, ""), _get_converter_function(ori_param.annotation, str(option.name)) # type: ignore + ) + + if not option.required and ori_param.default == inspect._empty: + # prefixed commands would automatically fill this in, slash commands don't + # technically, there would be an error for this, but it would + # be unclear + raise ValueError("Optional options must have a default value.") + + default = inspect._empty if option.required else ori_param.default + else: + # in case they use something like **kwargs, though this isn't a perfect solution + default = inspect._empty if option.required else None + + new_parameters.append( + inspect.Parameter( + str(option.name), + inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=default, + annotation=annotation, + ) + ) + + new_parameters.extend( + inspect.Parameter( + str(remaining_param.name), + inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=remaining_param.default, + annotation=remaining_param.annotation, + ) + for remaining_param in old_params.values() + if _check_if_annotation(remaining_param.annotation, NoArgumentConverter) + ) + + prefixed_cmd = _HybridPrefixedCommand( + name=str(cmd.sub_cmd_name) if cmd.is_subcommand else str(cmd.name), + aliases=list(cmd.sub_cmd_name.to_locale_dict().values()) + if cmd.is_subcommand + else list(cmd.name.to_locale_dict().values()), + help=str(cmd.sub_cmd_description) if cmd.is_subcommand else str(cmd.description), + callback=cmd.callback, + extension=cmd.extension, + inspect_signature=inspect.Signature(new_parameters), # type: ignore + ) + + if cmd.has_binding: + prefixed_cmd._binding = cmd._binding + + if not cmd.is_subcommand: + # these mean nothing in subcommands + if cmd.scopes != [GLOBAL_SCOPE]: + prefixed_cmd.checks.append(_generate_scope_check(cmd.scopes)) + if cmd.default_member_permissions: + prefixed_cmd.checks.append(_generate_permission_check(cmd.default_member_permissions)) + if cmd.dm_permission is False: + prefixed_cmd.checks.append(_guild_check) + + return prefixed_cmd + + +def hybrid_command( + name: str | LocalisedName, + *, + description: Absent[str | LocalisedDesc] = MISSING, + scopes: Absent[list["Snowflake_Type"]] = MISSING, + options: Optional[list[SlashCommandOption | dict]] = None, + default_member_permissions: Optional["Permissions"] = None, + dm_permission: bool = True, + sub_cmd_name: str | LocalisedName = None, + group_name: str | LocalisedName = None, + sub_cmd_description: str | LocalisedDesc = "No Description Set", + group_description: str | LocalisedDesc = "No Description Set", + nsfw: bool = False, +) -> Callable[[Callable[..., Coroutine]], HybridCommand]: + """ + A decorator to declare a coroutine as a hybrid command. + + Hybrid commands are a slash command that can also function as a prefixed command. + These use a HybridContext instead of an InteractionContext, but otherwise are mostly identical to normal slash commands. + + Note that hybrid commands do not support autocompletes. + They also only partially support attachments, allowing one attachment option for a command. + + note: + While the base and group descriptions aren't visible in the discord client, currently. + We strongly advise defining them anyway, if you're using subcommands, as Discord has said they will be visible in + one of the future ui updates. + They are also visible as the description for their prefixed command counterparts. + + Args: + name: 1-32 character name of the command + description: 1-100 character description of the command + scopes: The scope this command exists within + options: The parameters for the command, max 25 + default_member_permissions: What permissions members need to have by default to use this command. + dm_permission: Should this command be available in DMs. + sub_cmd_name: 1-32 character name of the subcommand + sub_cmd_description: 1-100 character description of the subcommand + group_name: 1-32 character name of the group + group_description: 1-100 character description of the group + nsfw: This command should only work in NSFW channels + + Returns: + HybridCommand Object + + """ + + def wrapper(func: Callable[..., Coroutine]) -> HybridCommand: + if not asyncio.iscoroutinefunction(func): + raise ValueError("Commands must be coroutines") + + perm = default_member_permissions + if hasattr(func, "default_member_permissions"): + if perm: + perm = perm | func.default_member_permissions + else: + perm = func.default_member_permissions + + _description = description + if _description is MISSING: + _description = func.__doc__ or "No Description Set" + + return HybridCommand( + name=name, + group_name=group_name, + group_description=group_description, + sub_cmd_name=sub_cmd_name, + sub_cmd_description=sub_cmd_description, + description=_description, + scopes=scopes or [GLOBAL_SCOPE], + default_member_permissions=perm, + dm_permission=dm_permission, + callback=func, + options=options, + nsfw=nsfw, + ) + + return wrapper + + +def hybrid_subcommand( + base: str | LocalisedName, + *, + subcommand_group: Optional[str | LocalisedName] = None, + name: Optional[str | LocalisedName] = None, + description: Absent[str | LocalisedDesc] = MISSING, + base_description: Optional[str | LocalisedDesc] = None, + base_desc: Optional[str | LocalisedDesc] = None, + base_default_member_permissions: Optional["Permissions"] = None, + base_dm_permission: bool = True, + subcommand_group_description: Optional[str | LocalisedDesc] = None, + sub_group_desc: Optional[str | LocalisedDesc] = None, + scopes: list["Snowflake_Type"] = None, + options: list[dict] = None, + nsfw: bool = False, +) -> Callable[[Coroutine], HybridCommand]: + """ + A decorator specifically tailored for creating hybrid subcommands. + + See the hybrid_command decorator for more information. + + Args: + base: The name of the base command + subcommand_group: The name of the subcommand group, if any. + name: The name of the subcommand, defaults to the name of the coroutine. + description: The description of the subcommand + base_description: The description of the base command + base_desc: An alias of `base_description` + base_default_member_permissions: What permissions members need to have by default to use this command. + base_dm_permission: Should this command be available in DMs. + subcommand_group_description: Description of the subcommand group + sub_group_desc: An alias for `subcommand_group_description` + scopes: The scopes of which this command is available, defaults to GLOBAL_SCOPE + options: The options for this command + nsfw: This command should only work in NSFW channels + + Returns: + A HybridCommand object + + """ + + def wrapper(func) -> HybridCommand: + if not asyncio.iscoroutinefunction(func): + raise ValueError("Commands must be coroutines") + + _description = description + if _description is MISSING: + _description = func.__doc__ or "No Description Set" + + return HybridCommand( + name=base, + description=(base_description or base_desc) or "No Description Set", + group_name=subcommand_group, + group_description=(subcommand_group_description or sub_group_desc) or "No Description Set", + sub_cmd_name=name, + sub_cmd_description=_description, + default_member_permissions=base_default_member_permissions, + dm_permission=base_dm_permission, + scopes=scopes or [GLOBAL_SCOPE], + callback=func, + options=options, + nsfw=nsfw, + ) + + return wrapper diff --git a/naff/models/naff/localisation.py b/naff/models/naff/localisation.py index 1326e92df..f6d0e7d98 100644 --- a/naff/models/naff/localisation.py +++ b/naff/models/naff/localisation.py @@ -120,6 +120,9 @@ def to_locale_dict(self) -> dict: if "locale-code" in attr.metadata: if val := getattr(self, attr.name): data[attr.metadata["locale-code"]] = val + + if not data: + data = None # handle discord being stupid return data diff --git a/naff/models/naff/prefixed_commands.py b/naff/models/naff/prefixed_commands.py index 4cae09670..b3c114665 100644 --- a/naff/models/naff/prefixed_commands.py +++ b/naff/models/naff/prefixed_commands.py @@ -41,6 +41,8 @@ class PrefixedCommandParameter: type: Type = attrs.field( default=None, metadata=docs("The type of the parameter.") ) # yes i can use type here, mkdocs doesnt like that + kind: inspect._ParameterKind = attrs.field(default=inspect.Parameter.POSITIONAL_OR_KEYWORD) + """The kind of parameter this is as related to the function.""" converters: list[Callable[["PrefixedContext", str], Any]] = attrs.field( factory=list, metadata=docs("A list of the converter functions for the parameter that convert to its type.") ) @@ -288,6 +290,7 @@ class PrefixedCommand(BaseCommand): metadata=docs("A dict of all subcommands for the command."), factory=dict ) _usage: Optional[str] = field(default=None) + _inspect_signature: Optional[inspect.Signature] = field(default=None) def __attrs_post_init__(self) -> None: super().__attrs_post_init__() # we want checks to work @@ -410,13 +413,17 @@ def _parse_parameters(self) -> None: # clear out old parameters just in case self.parameters = [] - # we don't care about the ctx or self variables - if self.has_binding: - callback = functools.partial(self.callback, None, None) - else: - callback = functools.partial(self.callback, None) + if not self._inspect_signature: + # we don't care about the ctx or self variables + if self.has_binding: + callback = functools.partial(self.callback, None, None) + else: + callback = functools.partial(self.callback, None) + + self._inspect_signature = inspect.signature(callback) + + params = self._inspect_signature.parameters - params = inspect.signature(callback).parameters # this is used by keyword-only and variable args to make sure there isn't more than one of either # mind you, we also don't want one keyword-only and one variable arg either finished_params = False @@ -427,6 +434,7 @@ def _parse_parameters(self) -> None: cmd_param = PrefixedCommandParameter() cmd_param.name = name + cmd_param.kind = param.kind cmd_param.default = param.default if param.default is not param.empty else MISSING cmd_param.type = anno = param.annotation @@ -646,7 +654,10 @@ async def call_callback(self, callback: Callable, ctx: "PrefixedContext") -> Non break converted, used_default = await _convert(param, ctx, arg) - if not param.consume_rest: + if param.kind in { + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.VAR_POSITIONAL, + }: new_args.append(converted) else: kwargs[param.name] = converted @@ -669,7 +680,10 @@ async def call_callback(self, callback: Callable, ctx: "PrefixedContext") -> Non if not param.optional: raise BadArgument(f"{param.name} is a required argument that is missing.") else: - if not param.consume_rest: + if param.kind in { + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.VAR_POSITIONAL, + }: new_args.append(param.default) else: kwargs[param.name] = param.default diff --git a/naff/models/naff/tasks/task.py b/naff/models/naff/tasks/task.py index 8bc77c7af..12d802d58 100644 --- a/naff/models/naff/tasks/task.py +++ b/naff/models/naff/tasks/task.py @@ -59,9 +59,12 @@ def on_error(self, error: Exception) -> None: async def __call__(self) -> None: try: if inspect.iscoroutinefunction(self.callback): - await self.callback() + val = await self.callback() else: - self.callback() + val = self.callback() + + if isinstance(val, BaseTrigger): + self.trigger = val except Exception as e: self.on_error(e) diff --git a/pyproject.toml b/pyproject.toml index cbcc684e5..72bb814ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "naff" -version = "1.8.1" +version = "1.9.0" description = "Not another freaking fork" authors = ["LordOfPolls "] diff --git a/setup.py b/setup.py index cff88fc1c..ecb16f8b6 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,11 @@ with open("pyproject.toml", "rb") as f: pyproject = tomli.load(f) -extras_require = {"voice": ["PyNaCl>=1.5.0,<1.6"], "speedup": ["cchardet", "aiodns", "orjson", "Brotli"]} +extras_require = { + "voice": ["PyNaCl>=1.5.0,<1.6"], + "speedup": ["cchardet", "aiodns", "orjson", "Brotli"], + "sentry": ["sentry-sdk"], +} extras_require["all"] = list(itertools.chain.from_iterable(extras_require.values())) extras_require["docs"] = extras_require["all"] + [ "pytkdocs @ git+https://github.com/LordOfPolls/pytkdocs.git", diff --git a/tests/test_protocols.py b/tests/test_protocols.py index a9a09370e..4d02b38b5 100644 --- a/tests/test_protocols.py +++ b/tests/test_protocols.py @@ -1,4 +1,4 @@ -from naff import InteractionContext, PrefixedContext, SendableContext +from naff import InteractionContext, PrefixedContext, HybridContext, SendableContext from typeguard import check_type __all__ = () @@ -7,3 +7,4 @@ def test_sendable_context() -> None: check_type("prefixed_context", PrefixedContext, SendableContext) check_type("interaction_context", InteractionContext, SendableContext) + check_type("hybrid_context", HybridContext, SendableContext)