From 2c85e39f415e9d96867dc317592582b685888ed9 Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Thu, 16 Nov 2023 15:37:08 +0100 Subject: [PATCH 1/7] feat(commands): don't parse self/ctx parameter annotations of prefix command callbacks (#847) --- changelog/847.feature.rst | 1 + disnake/ext/commands/core.py | 40 +---------- disnake/ext/commands/help.py | 10 --- disnake/ext/commands/params.py | 30 +++++--- disnake/utils.py | 70 ++++++++++++++++++- tests/ext/commands/test_params.py | 110 +++++++++++++++++------------- tests/test_utils.py | 82 ++++++++++++++++++++++ 7 files changed, 235 insertions(+), 108 deletions(-) create mode 100644 changelog/847.feature.rst diff --git a/changelog/847.feature.rst b/changelog/847.feature.rst new file mode 100644 index 0000000000..7418ed0783 --- /dev/null +++ b/changelog/847.feature.rst @@ -0,0 +1 @@ +|commands| Skip evaluating annotations of ``self`` (if present) and ``ctx`` parameters in prefix commands. These may now use stringified annotations with types that aren't available at runtime. diff --git a/disnake/ext/commands/core.py b/disnake/ext/commands/core.py index 2bb108e966..fda34b5a95 100644 --- a/disnake/ext/commands/core.py +++ b/disnake/ext/commands/core.py @@ -381,7 +381,7 @@ def callback(self, function: CommandCallback[CogT, Any, P, T]) -> None: except AttributeError: globalns = {} - params = get_signature_parameters(function, globalns) + params = get_signature_parameters(function, globalns, skip_standard_params=True) for param in params.values(): if param.annotation is Greedy: raise TypeError("Unparameterized Greedy[...] is disallowed in signature.") @@ -607,21 +607,7 @@ def clean_params(self) -> Dict[str, inspect.Parameter]: Useful for inspecting signature. """ - result = self.params.copy() - if self.cog is not None: - # first parameter is self - try: - del result[next(iter(result))] - except StopIteration: - raise ValueError("missing 'self' parameter") from None - - try: - # first/second parameter is context - del result[next(iter(result))] - except StopIteration: - raise ValueError("missing 'context' parameter") from None - - return result + return self.params.copy() @property def full_parent_name(self) -> str: @@ -693,27 +679,7 @@ async def _parse_arguments(self, ctx: Context) -> None: kwargs = ctx.kwargs view = ctx.view - iterator = iter(self.params.items()) - - if self.cog is not None: - # we have 'self' as the first parameter so just advance - # the iterator and resume parsing - try: - next(iterator) - except StopIteration: - raise disnake.ClientException( - f'Callback for {self.name} command is missing "self" parameter.' - ) from None - - # next we have the 'ctx' as the next parameter - try: - next(iterator) - except StopIteration: - raise disnake.ClientException( - f'Callback for {self.name} command is missing "ctx" parameter.' - ) from None - - for name, param in iterator: + for name, param in self.params.items(): ctx.current_parameter = param if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): transformed = await self.transform(ctx, param) diff --git a/disnake/ext/commands/help.py b/disnake/ext/commands/help.py index 5841a8ba11..483d4f4bd2 100644 --- a/disnake/ext/commands/help.py +++ b/disnake/ext/commands/help.py @@ -202,16 +202,6 @@ async def _parse_arguments(self, ctx) -> None: async def _on_error_cog_implementation(self, dummy, ctx, error) -> None: await self._injected.on_help_command_error(ctx, error) - @property - def clean_params(self): - result = self.params.copy() - try: - del result[next(iter(result))] - except StopIteration: - raise ValueError("Missing context parameter") from None - else: - return result - def _inject_into_cog(self, cog) -> None: # Warning: hacky diff --git a/disnake/ext/commands/params.py b/disnake/ext/commands/params.py index 5aae2de611..95679ed802 100644 --- a/disnake/ext/commands/params.py +++ b/disnake/ext/commands/params.py @@ -42,7 +42,12 @@ from disnake.ext import commands from disnake.i18n import Localized from disnake.interactions import ApplicationCommandInteraction -from disnake.utils import get_signature_parameters, get_signature_return, maybe_coroutine +from disnake.utils import ( + get_signature_parameters, + get_signature_return, + maybe_coroutine, + signature_has_self_param, +) from . import errors from .converter import CONVERTER_MAPPING @@ -771,7 +776,7 @@ def parse_converter_annotation(self, converter: Callable, fallback_annotation: A # (we need `__call__` here to get the correct global namespace later, since # classes do not have `__globals__`) converter_func = converter.__call__ - _, parameters = isolate_self(get_signature_parameters(converter_func)) + _, parameters = isolate_self(converter_func) if len(parameters) != 1: raise TypeError( @@ -879,9 +884,16 @@ def safe_call(function: Callable[..., T], /, *possible_args: Any, **possible_kwa def isolate_self( - parameters: Dict[str, inspect.Parameter], + function: Callable, + parameters: Optional[Dict[str, inspect.Parameter]] = None, ) -> Tuple[Tuple[Optional[inspect.Parameter], ...], Dict[str, inspect.Parameter]]: - """Create parameters without self and the first interaction""" + """Create parameters without self and the first interaction. + + Optionally accepts a `{str: inspect.Parameter}` dict as an optimization, + calls `get_signature_parameters(function)` if not provided. + """ + if parameters is None: + parameters = get_signature_parameters(function) if not parameters: return (None, None), {} @@ -891,7 +903,7 @@ def isolate_self( cog_param: Optional[inspect.Parameter] = None inter_param: Optional[inspect.Parameter] = None - if parametersl[0].name == "self": + if signature_has_self_param(function): cog_param = parameters.pop(parametersl[0].name) parametersl.pop(0) if parametersl: @@ -941,15 +953,11 @@ def collect_params( ) -> Tuple[Optional[str], Optional[str], List[ParamInfo], Dict[str, Injection]]: """Collect all parameters in a function. - Optionally accepts a `{str: inspect.Parameter}` dict as an optimization, - calls `get_signature_parameters(function)` if not provided. + Optionally accepts a `{str: inspect.Parameter}` dict as an optimization. Returns: (`cog parameter`, `interaction parameter`, `param infos`, `injections`) """ - if parameters is None: - parameters = get_signature_parameters(function) - - (cog_param, inter_param), parameters = isolate_self(parameters) + (cog_param, inter_param), parameters = isolate_self(function, parameters) doc = disnake.utils.parse_docstring(function)["params"] diff --git a/disnake/utils.py b/disnake/utils.py index d40cd4e8fe..a74d50ab94 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -12,6 +12,7 @@ import pkgutil import re import sys +import types import unicodedata import warnings from base64 import b64encode @@ -1227,7 +1228,10 @@ def _get_function_globals(function: Callable[..., Any]) -> Dict[str, Any]: def get_signature_parameters( - function: Callable[..., Any], globalns: Optional[Dict[str, Any]] = None + function: Callable[..., Any], + globalns: Optional[Dict[str, Any]] = None, + *, + skip_standard_params: bool = False, ) -> Dict[str, inspect.Parameter]: # if no globalns provided, unwrap (where needed) and get global namespace from there if globalns is None: @@ -1237,9 +1241,23 @@ def get_signature_parameters( cache: Dict[str, Any] = {} signature = inspect.signature(function) + iterator = iter(signature.parameters.items()) + + if skip_standard_params: + # skip `self` (if present) and `ctx` parameters, + # since their annotations are irrelevant + skip = 2 if signature_has_self_param(function) else 1 + + for _ in range(skip): + try: + next(iterator) + except StopIteration: + raise ValueError( + f"Expected command callback to have at least {skip} parameter(s)" + ) from None # eval all parameter annotations - for name, parameter in signature.parameters.items(): + for name, parameter in iterator: annotation = parameter.annotation if annotation is _inspect_empty: params[name] = parameter @@ -1270,6 +1288,54 @@ def get_signature_return(function: Callable[..., Any]) -> Any: return ret +def signature_has_self_param(function: Callable[..., Any]) -> bool: + # If a function was defined in a class and is not bound (i.e. is not types.MethodType), + # it should have a `self` parameter. + # Bound methods technically also have a `self` parameter, but this is + # used in conjunction with `inspect.signature`, which drops that parameter. + # + # There isn't really any way to reliably detect whether a function + # was defined in a class, other than `__qualname__`, thanks to PEP 3155. + # As noted in the PEP, this doesn't work with rebinding, but that should be a pretty rare edge case. + # + # + # There are a few possible situations here - for the purposes of this method, + # we want to detect the first case only: + # (1) The preceding component for *methods in classes* will be the class name, resulting in `Clazz.func`. + # (2) For *unbound* functions (not methods), `__qualname__ == __name__`. + # (3) Bound methods (i.e. types.MethodType) don't have a `self` parameter in the context of this function (see first paragraph). + # (we currently don't expect to handle bound methods anywhere, except the default help command implementation). + # (4) A somewhat special case are lambdas defined in a class namespace (but not inside a method), which use `Clazz.` and shouldn't match (1). + # (lambdas at class level are a bit funky; we currently only expect them in the `Param(converter=)` kwarg, which doesn't take a `self` parameter). + # (5) Similarly, *nested functions* use `containing_func..func` and shouldn't have a `self` parameter. + # + # Working solely based on this string is certainly not ideal, + # but the compiler does a bunch of processing just for that attribute, + # and there's really no other way to retrieve this information through other means later. + # (3.10: https://github.com/python/cpython/blob/e07086db03d2dc1cd2e2a24f6c9c0ddd422b4cf0/Python/compile.c#L744) + # + # Not reliable for classmethod/staticmethod. + + qname = function.__qualname__ + if qname == function.__name__: + # (2) + return False + + if isinstance(function, types.MethodType): + # (3) + return False + + # "a.b.c.d" => "a.b.c", "d" + parent, basename = qname.rsplit(".", 1) + + if basename == "": + # (4) + return False + + # (5) + return not parent.endswith(".") + + TimestampStyle = Literal["f", "F", "d", "D", "t", "T", "R"] diff --git a/tests/ext/commands/test_params.py b/tests/ext/commands/test_params.py index 61c812c8f0..96f2c08c32 100644 --- a/tests/ext/commands/test_params.py +++ b/tests/ext/commands/test_params.py @@ -10,7 +10,6 @@ import disnake from disnake import Member, Role, User from disnake.ext import commands -from disnake.ext.commands import params OptionType = disnake.OptionType @@ -67,53 +66,6 @@ async def test_verify_type__invalid_member(self, annotation, arg_types) -> None: with pytest.raises(commands.errors.MemberNotFound): await info.verify_type(mock.Mock(), arg_mock) - def test_isolate_self(self) -> None: - def func(a: int) -> None: - ... - - (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) - assert cog is None - assert inter is None - assert parameters == ({"a": mock.ANY}) - - def test_isolate_self_inter(self) -> None: - def func(i: disnake.ApplicationCommandInteraction, a: int) -> None: - ... - - (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) - assert cog is None - assert inter is not None - assert parameters == ({"a": mock.ANY}) - - def test_isolate_self_cog_inter(self) -> None: - def func(self, i: disnake.ApplicationCommandInteraction, a: int) -> None: - ... - - (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) - assert cog is not None - assert inter is not None - assert parameters == ({"a": mock.ANY}) - - def test_isolate_self_generic(self) -> None: - def func(i: disnake.ApplicationCommandInteraction[commands.Bot], a: int) -> None: - ... - - (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) - assert cog is None - assert inter is not None - assert parameters == ({"a": mock.ANY}) - - def test_isolate_self_union(self) -> None: - def func( - i: Union[commands.Context, disnake.ApplicationCommandInteraction[commands.Bot]], a: int - ) -> None: - ... - - (cog, inter), parameters = params.isolate_self(params.get_signature_parameters(func)) - assert cog is None - assert inter is not None - assert parameters == ({"a": mock.ANY}) - # this uses `Range` for testing `_BaseRange`, `String` should work equally class TestBaseRange: @@ -260,3 +212,65 @@ def test_optional(self, annotation_str) -> None: assert info.min_value == 1 assert info.max_value == 2 assert info.type == int + + +class TestIsolateSelf: + def test_function_simple(self) -> None: + def func(a: int) -> None: + ... + + (cog, inter), params = commands.params.isolate_self(func) + assert cog is None + assert inter is None + assert params.keys() == {"a"} + + def test_function_inter(self) -> None: + def func(inter: disnake.ApplicationCommandInteraction, a: int) -> None: + ... + + (cog, inter), params = commands.params.isolate_self(func) + assert cog is None # should not be set + assert inter is not None + assert params.keys() == {"a"} + + def test_unbound_method(self) -> None: + class Cog(commands.Cog): + def func(self, inter: disnake.ApplicationCommandInteraction, a: int) -> None: + ... + + (cog, inter), params = commands.params.isolate_self(Cog.func) + assert cog is not None # *should* be set here + assert inter is not None + assert params.keys() == {"a"} + + # I don't think the param parsing logic ever handles bound methods, but testing for regressions anyway + def test_bound_method(self) -> None: + class Cog(commands.Cog): + def func(self, inter: disnake.ApplicationCommandInteraction, a: int) -> None: + ... + + (cog, inter), params = commands.params.isolate_self(Cog().func) + assert cog is None # should not be set here, since method is already bound + assert inter is not None + assert params.keys() == {"a"} + + def test_generic(self) -> None: + def func(inter: disnake.ApplicationCommandInteraction[commands.Bot], a: int) -> None: + ... + + (cog, inter), params = commands.params.isolate_self(func) + assert cog is None + assert inter is not None + assert params.keys() == {"a"} + + def test_inter_union(self) -> None: + def func( + inter: Union[commands.Context, disnake.ApplicationCommandInteraction[commands.Bot]], + a: int, + ) -> None: + ... + + (cog, inter), params = commands.params.isolate_self(func) + assert cog is None + assert inter is not None + assert params.keys() == {"a"} diff --git a/tests/test_utils.py b/tests/test_utils.py index a8f52e6b1f..d767264a95 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ import asyncio import datetime +import functools import inspect import os import sys @@ -880,3 +881,84 @@ def test_as_valid_locale(locale, expected) -> None: ) def test_humanize_list(values, expected) -> None: assert utils.humanize_list(values, "plus") == expected + + +# used for `test_signature_has_self_param` +def _toplevel(): + def inner() -> None: + ... + + return inner + + +def decorator(f): + @functools.wraps(f) + def wrap(self, *args, **kwargs): + return f(self, *args, **kwargs) + + return wrap + + +# used for `test_signature_has_self_param` +class _Clazz: + def func(self): + def inner() -> None: + ... + + return inner + + @classmethod + def cmethod(cls) -> None: + ... + + @staticmethod + def smethod() -> None: + ... + + class Nested: + def func(self): + def inner() -> None: + ... + + return inner + + rebind = _toplevel + + @decorator + def decorated(self) -> None: + ... + + _lambda = lambda: None + + +@pytest.mark.parametrize( + ("function", "expected"), + [ + # top-level function + (_toplevel, False), + # methods in class + (_Clazz.func, True), + (_Clazz().func, False), + # unfortunately doesn't work + (_Clazz.rebind, False), + (_Clazz().rebind, False), + # classmethod/staticmethod isn't supported, but checked to ensure consistency + (_Clazz.cmethod, False), + (_Clazz.smethod, True), + # nested class methods + (_Clazz.Nested.func, True), + (_Clazz.Nested().func, False), + # inner methods + (_toplevel(), False), + (_Clazz().func(), False), + (_Clazz.Nested().func(), False), + # decorated method + (_Clazz.decorated, True), + (_Clazz().decorated, False), + # lambda (class-level) + (_Clazz._lambda, False), + (_Clazz()._lambda, False), + ], +) +def test_signature_has_self_param(function, expected) -> None: + assert utils.signature_has_self_param(function) == expected From 4bd0d25ac787adb405d270ad6236245f82a1fd07 Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Fri, 17 Nov 2023 20:09:17 +0100 Subject: [PATCH 2/7] build(deps): update pyright to v1.1.336 (#1122) --- disnake/abc.py | 2 +- disnake/activity.py | 2 +- disnake/asset.py | 2 +- disnake/audit_logs.py | 2 +- disnake/channel.py | 12 ++--- disnake/client.py | 32 ++++++++--- disnake/components.py | 48 +++++++++++------ disnake/emoji.py | 2 +- disnake/enums.py | 4 +- disnake/ext/commands/base_core.py | 2 +- disnake/ext/commands/bot_base.py | 17 ++---- disnake/ext/commands/converter.py | 6 +-- disnake/ext/commands/cooldowns.py | 2 +- disnake/ext/commands/core.py | 4 +- disnake/ext/commands/help.py | 6 ++- disnake/ext/commands/params.py | 7 +-- disnake/ext/commands/slash_core.py | 4 +- disnake/ext/tasks/__init__.py | 4 +- disnake/gateway.py | 2 +- disnake/guild.py | 4 -- disnake/http.py | 25 ++++----- disnake/i18n.py | 2 +- disnake/interactions/base.py | 2 +- disnake/iterators.py | 10 ++-- disnake/message.py | 7 +-- disnake/shard.py | 3 +- disnake/state.py | 8 ++- disnake/types/audit_log.py | 84 ++++++++++++++--------------- disnake/types/automod.py | 4 +- disnake/types/template.py | 2 +- disnake/ui/action_row.py | 5 +- disnake/ui/button.py | 2 +- disnake/ui/item.py | 2 +- disnake/ui/modal.py | 2 +- disnake/ui/select/channel.py | 2 +- disnake/ui/select/mentionable.py | 2 +- disnake/ui/select/role.py | 2 +- disnake/ui/select/string.py | 2 +- disnake/ui/select/user.py | 2 +- disnake/utils.py | 12 +++-- disnake/voice_client.py | 2 +- docs/extensions/builder.py | 2 +- examples/basic_voice.py | 4 +- examples/interactions/injections.py | 2 +- examples/interactions/modal.py | 2 + pyproject.toml | 2 +- test_bot/cogs/modals.py | 2 +- tests/ui/test_decorators.py | 8 +-- 48 files changed, 193 insertions(+), 175 deletions(-) diff --git a/disnake/abc.py b/disnake/abc.py index 605bb725aa..b9a60f3ee5 100644 --- a/disnake/abc.py +++ b/disnake/abc.py @@ -390,7 +390,7 @@ async def _edit( if p_id is not None and (parent := self.guild.get_channel(p_id)): overwrites_payload = [c._asdict() for c in parent._overwrites] - if overwrites is not MISSING and overwrites is not None: + if overwrites not in (MISSING, None): overwrites_payload = [] for target, perm in overwrites.items(): if not isinstance(perm, PermissionOverwrite): diff --git a/disnake/activity.py b/disnake/activity.py index 92460cd35d..3c290edd17 100644 --- a/disnake/activity.py +++ b/disnake/activity.py @@ -921,7 +921,7 @@ def create_activity( elif game_type is ActivityType.listening and "sync_id" in data and "session_id" in data: activity = Spotify(**data) else: - activity = Activity(**data) + activity = Activity(**data) # type: ignore if isinstance(activity, (Activity, CustomActivity)) and activity.emoji and state: activity.emoji._state = state diff --git a/disnake/asset.py b/disnake/asset.py index fad72c79ce..bc7b505697 100644 --- a/disnake/asset.py +++ b/disnake/asset.py @@ -24,7 +24,7 @@ ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"] AnyState = Union[ConnectionState, _WebhookState[BaseWebhook]] -AssetBytes = Union[bytes, "AssetMixin"] +AssetBytes = Union[utils._BytesLike, "AssetMixin"] VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"}) VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"} diff --git a/disnake/audit_logs.py b/disnake/audit_logs.py index e8ab022edf..256aaa04dc 100644 --- a/disnake/audit_logs.py +++ b/disnake/audit_logs.py @@ -245,7 +245,7 @@ def _transform_datetime(entry: AuditLogEntry, data: Optional[str]) -> Optional[d def _transform_privacy_level( - entry: AuditLogEntry, data: int + entry: AuditLogEntry, data: Optional[int] ) -> Optional[Union[enums.StagePrivacyLevel, enums.GuildScheduledEventPrivacyLevel]]: if data is None: return None diff --git a/disnake/channel.py b/disnake/channel.py index 7eef52b942..ffb11f2d2c 100644 --- a/disnake/channel.py +++ b/disnake/channel.py @@ -473,7 +473,7 @@ async def edit( overwrites=overwrites, flags=flags, reason=reason, - **kwargs, + **kwargs, # type: ignore ) if payload is not None: # the payload will always be the proper channel payload @@ -1628,7 +1628,7 @@ async def edit( slowmode_delay=slowmode_delay, flags=flags, reason=reason, - **kwargs, + **kwargs, # type: ignore ) if payload is not None: # the payload will always be the proper channel payload @@ -2453,7 +2453,7 @@ async def edit( flags=flags, slowmode_delay=slowmode_delay, reason=reason, - **kwargs, + **kwargs, # type: ignore ) if payload is not None: # the payload will always be the proper channel payload @@ -2946,7 +2946,7 @@ async def edit( overwrites=overwrites, flags=flags, reason=reason, - **kwargs, + **kwargs, # type: ignore ) if payload is not None: # the payload will always be the proper channel payload @@ -3619,7 +3619,7 @@ async def edit( default_sort_order=default_sort_order, default_layout=default_layout, reason=reason, - **kwargs, + **kwargs, # type: ignore ) if payload is not None: # the payload will always be the proper channel payload @@ -3994,7 +3994,7 @@ async def create_thread( stickers=stickers, ) - if auto_archive_duration is not None: + if auto_archive_duration not in (MISSING, None): auto_archive_duration = cast( "ThreadArchiveDurationLiteral", try_enum_to_int(auto_archive_duration) ) diff --git a/disnake/client.py b/disnake/client.py index f71842c7b3..b25b44cbd9 100644 --- a/disnake/client.py +++ b/disnake/client.py @@ -25,6 +25,7 @@ Optional, Sequence, Tuple, + TypedDict, TypeVar, Union, overload, @@ -79,6 +80,8 @@ from .widget import Widget if TYPE_CHECKING: + from typing_extensions import NotRequired + from .abc import GuildChannel, PrivateChannel, Snowflake, SnowflakeTime from .app_commands import APIApplicationCommand from .asset import AssetBytes @@ -207,6 +210,17 @@ class GatewayParams(NamedTuple): zlib: bool = True +# used for typing the ws parameter dict in the connect() loop +class _WebSocketParams(TypedDict): + initial: bool + shard_id: Optional[int] + gateway: Optional[str] + + sequence: NotRequired[Optional[int]] + resume: NotRequired[bool] + session: NotRequired[Optional[str]] + + class Client: """Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. @@ -1080,7 +1094,7 @@ async def connect( if not ignore_session_start_limit and self.session_start_limit.remaining == 0: raise SessionStartLimitReached(self.session_start_limit) - ws_params = { + ws_params: _WebSocketParams = { "initial": True, "shard_id": self.shard_id, "gateway": initial_gateway, @@ -1104,6 +1118,7 @@ async def connect( while True: await self.ws.poll_event() + except ReconnectWebSocket as e: _log.info("Got a request to %s the websocket.", e.op) self.dispatch("disconnect") @@ -1116,6 +1131,7 @@ async def connect( gateway=self.ws.resume_gateway if e.resume else initial_gateway, ) continue + except ( OSError, HTTPException, @@ -1196,7 +1212,8 @@ async def close(self) -> None: # if an error happens during disconnects, disregard it. pass - if self.ws is not None and self.ws.open: + # can be None if not connected + if self.ws is not None and self.ws.open: # pyright: ignore[reportUnnecessaryComparison] await self.ws.close(code=1000) await self.http.close() @@ -1874,16 +1891,15 @@ async def change_presence( await self.ws.change_presence(activity=activity, status=status_str) + activities = () if activity is None else (activity,) for guild in self._connection.guilds: me = guild.me - if me is None: + if me is None: # pyright: ignore[reportUnnecessaryComparison] + # may happen if guild is unavailable continue - if activity is not None: - me.activities = (activity,) # type: ignore - else: - me.activities = () - + # Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...] + me.activities = activities # type: ignore me.status = status # Guild stuff diff --git a/disnake/components.py b/disnake/components.py index e6f3d14904..7614fd424b 100644 --- a/disnake/components.py +++ b/disnake/components.py @@ -9,6 +9,7 @@ Dict, Generic, List, + Literal, Optional, Tuple, Type, @@ -22,11 +23,12 @@ from .utils import MISSING, assert_never, get_slots if TYPE_CHECKING: - from typing_extensions import Self + from typing_extensions import Self, TypeAlias from .emoji import Emoji from .types.components import ( ActionRow as ActionRowPayload, + AnySelectMenu as AnySelectMenuPayload, BaseSelectMenu as BaseSelectMenuPayload, ButtonComponent as ButtonComponentPayload, ChannelSelectMenu as ChannelSelectMenuPayload, @@ -63,12 +65,16 @@ "MentionableSelectMenu", "ChannelSelectMenu", ] -MessageComponent = Union["Button", "AnySelectMenu"] -if TYPE_CHECKING: # TODO: remove when we add modal select support - from typing_extensions import TypeAlias +SelectMenuType = Literal[ + ComponentType.string_select, + ComponentType.user_select, + ComponentType.role_select, + ComponentType.mentionable_select, + ComponentType.channel_select, +] -# ModalComponent = Union["TextInput", "AnySelectMenu"] +MessageComponent = Union["Button", "AnySelectMenu"] ModalComponent: TypeAlias = "TextInput" NestedComponent = Union[MessageComponent, ModalComponent] @@ -131,8 +137,6 @@ class ActionRow(Component, Generic[ComponentT]): Attributes ---------- - type: :class:`ComponentType` - The type of component. children: List[Union[:class:`Button`, :class:`BaseSelectMenu`, :class:`TextInput`]] The children components that this holds, if any. """ @@ -142,10 +146,9 @@ class ActionRow(Component, Generic[ComponentT]): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: ActionRowPayload) -> None: - self.type: ComponentType = try_enum(ComponentType, data["type"]) - self.children: List[ComponentT] = [ - _component_factory(d) for d in data.get("components", []) - ] + self.type: Literal[ComponentType.action_row] = ComponentType.action_row + children = [_component_factory(d) for d in data.get("components", [])] + self.children: List[ComponentT] = children # type: ignore def to_dict(self) -> ActionRowPayload: return { @@ -195,7 +198,7 @@ class Button(Component): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: ButtonComponentPayload) -> None: - self.type: ComponentType = try_enum(ComponentType, data["type"]) + self.type: Literal[ComponentType.button] = ComponentType.button self.style: ButtonStyle = try_enum(ButtonStyle, data["style"]) self.custom_id: Optional[str] = data.get("custom_id") self.url: Optional[str] = data.get("url") @@ -209,7 +212,7 @@ def __init__(self, data: ButtonComponentPayload) -> None: def to_dict(self) -> ButtonComponentPayload: payload: ButtonComponentPayload = { - "type": 2, + "type": self.type.value, "style": self.style.value, "disabled": self.disabled, } @@ -273,8 +276,13 @@ class BaseSelectMenu(Component): __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ - def __init__(self, data: BaseSelectMenuPayload) -> None: - self.type: ComponentType = try_enum(ComponentType, data["type"]) + # n.b: ideally this would be `BaseSelectMenuPayload`, + # but pyright made TypedDict keys invariant and doesn't + # fully support readonly items yet (which would help avoid this) + def __init__(self, data: AnySelectMenuPayload) -> None: + component_type = try_enum(ComponentType, data["type"]) + self.type: SelectMenuType = component_type # type: ignore + self.custom_id: str = data["custom_id"] self.placeholder: Optional[str] = data.get("placeholder") self.min_values: int = data.get("min_values", 1) @@ -329,6 +337,7 @@ class StringSelectMenu(BaseSelectMenu): __slots__: Tuple[str, ...] = ("options",) __repr_info__: ClassVar[Tuple[str, ...]] = BaseSelectMenu.__repr_info__ + __slots__ + type: Literal[ComponentType.string_select] def __init__(self, data: StringSelectMenuPayload) -> None: super().__init__(data) @@ -372,6 +381,8 @@ class UserSelectMenu(BaseSelectMenu): __slots__: Tuple[str, ...] = () + type: Literal[ComponentType.user_select] + if TYPE_CHECKING: def to_dict(self) -> UserSelectMenuPayload: @@ -405,6 +416,8 @@ class RoleSelectMenu(BaseSelectMenu): __slots__: Tuple[str, ...] = () + type: Literal[ComponentType.role_select] + if TYPE_CHECKING: def to_dict(self) -> RoleSelectMenuPayload: @@ -438,6 +451,8 @@ class MentionableSelectMenu(BaseSelectMenu): __slots__: Tuple[str, ...] = () + type: Literal[ComponentType.mentionable_select] + if TYPE_CHECKING: def to_dict(self) -> MentionableSelectMenuPayload: @@ -475,6 +490,7 @@ class ChannelSelectMenu(BaseSelectMenu): __slots__: Tuple[str, ...] = ("channel_types",) __repr_info__: ClassVar[Tuple[str, ...]] = BaseSelectMenu.__repr_info__ + __slots__ + type: Literal[ComponentType.channel_select] def __init__(self, data: ChannelSelectMenuPayload) -> None: super().__init__(data) @@ -643,7 +659,7 @@ class TextInput(Component): def __init__(self, data: TextInputPayload) -> None: style = data.get("style", TextInputStyle.short.value) - self.type: ComponentType = try_enum(ComponentType, data["type"]) + self.type: Literal[ComponentType.text_input] = ComponentType.text_input self.custom_id: str = data["custom_id"] self.style: TextInputStyle = try_enum(TextInputStyle, style) self.label: Optional[str] = data.get("label") diff --git a/disnake/emoji.py b/disnake/emoji.py index 0f3d02c27d..2a24877b07 100644 --- a/disnake/emoji.py +++ b/disnake/emoji.py @@ -151,7 +151,7 @@ def roles(self) -> List[Role]: and count towards a separate limit of 25 emojis. """ guild = self.guild - if guild is None: + if guild is None: # pyright: ignore[reportUnnecessaryComparison] return [] return [role for role in guild.roles if self._roles.has(role.id)] diff --git a/disnake/enums.py b/disnake/enums.py index b4bf3d994d..cb603c5425 100644 --- a/disnake/enums.py +++ b/disnake/enums.py @@ -466,7 +466,7 @@ def category(self) -> Optional[AuditLogActionCategory]: @property def target_type(self) -> Optional[str]: v = self.value - if v == -1: + if v == -1: # pyright: ignore[reportUnnecessaryComparison] return "all" elif v < 10: return "guild" @@ -627,7 +627,7 @@ class ComponentType(Enum): action_row = 1 button = 2 string_select = 3 - select = string_select # backwards compatibility + select = 3 # backwards compatibility text_input = 4 user_select = 5 role_select = 6 diff --git a/disnake/ext/commands/base_core.py b/disnake/ext/commands/base_core.py index 7198394be8..3599ea0908 100644 --- a/disnake/ext/commands/base_core.py +++ b/disnake/ext/commands/base_core.py @@ -303,7 +303,7 @@ def _prepare_cooldowns(self, inter: ApplicationCommandInteraction) -> None: dt = inter.created_at current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() bucket = self._buckets.get_bucket(inter, current) # type: ignore - if bucket is not None: + if bucket is not None: # pyright: ignore[reportUnnecessaryComparison] retry_after = bucket.update_rate_limit(current) if retry_after: raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore diff --git a/disnake/ext/commands/bot_base.py b/disnake/ext/commands/bot_base.py index d55dc63490..1bba906c82 100644 --- a/disnake/ext/commands/bot_base.py +++ b/disnake/ext/commands/bot_base.py @@ -10,18 +10,7 @@ import sys import traceback import warnings -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterable, - List, - Optional, - Type, - TypeVar, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Optional, Type, TypeVar, Union import disnake @@ -414,7 +403,7 @@ def _remove_module_references(self, name: str) -> None: super()._remove_module_references(name) # remove all the commands from the module for cmd in self.all_commands.copy().values(): - if cmd.module is not None and _is_submodule(name, cmd.module): + if cmd.module and _is_submodule(name, cmd.module): if isinstance(cmd, GroupMixin): cmd.recursively_remove_all_commands() self.remove_command(cmd.name) @@ -513,7 +502,7 @@ class be provided, it must be similar enough to :class:`.Context`\'s ``cls`` parameter. """ view = StringView(message.content) - ctx = cast("CXT", cls(prefix=None, view=view, bot=self, message=message)) + ctx = cls(prefix=None, view=view, bot=self, message=message) if message.author.id == self.user.id: # type: ignore return ctx diff --git a/disnake/ext/commands/converter.py b/disnake/ext/commands/converter.py index 8bca2bd6dd..29672b2e54 100644 --- a/disnake/ext/commands/converter.py +++ b/disnake/ext/commands/converter.py @@ -1133,7 +1133,7 @@ def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]: raise TypeError("Greedy[...] expects a type or a Converter instance.") if converter in (str, type(None)) or origin is Greedy: - raise TypeError(f"Greedy[{converter.__name__}] is invalid.") # type: ignore + raise TypeError(f"Greedy[{converter.__name__}] is invalid.") if origin is Union and type(None) in args: raise TypeError(f"Greedy[{converter!r}] is invalid.") @@ -1161,7 +1161,7 @@ def get_converter(param: inspect.Parameter) -> Any: return converter -_GenericAlias = type(List[T]) +_GenericAlias = type(List[Any]) def is_generic_type(tp: Any, *, _GenericAlias: Type = _GenericAlias) -> bool: @@ -1222,7 +1222,7 @@ async def _actual_conversion( raise ConversionError(converter, exc) from exc try: - return converter(argument) + return converter(argument) # type: ignore except CommandError: raise except Exception as exc: diff --git a/disnake/ext/commands/cooldowns.py b/disnake/ext/commands/cooldowns.py index 4268f76fff..354754550a 100644 --- a/disnake/ext/commands/cooldowns.py +++ b/disnake/ext/commands/cooldowns.py @@ -228,7 +228,7 @@ def get_bucket(self, message: Message, current: Optional[float] = None) -> Coold key = self._bucket_key(message) if key not in self._cache: bucket = self.create_bucket(message) - if bucket is not None: + if bucket is not None: # pyright: ignore[reportUnnecessaryComparison] self._cache[key] = bucket else: bucket = self._cache[key] diff --git a/disnake/ext/commands/core.py b/disnake/ext/commands/core.py index fda34b5a95..2ddcb10075 100644 --- a/disnake/ext/commands/core.py +++ b/disnake/ext/commands/core.py @@ -755,7 +755,7 @@ def _prepare_cooldowns(self, ctx: Context) -> None: dt = ctx.message.edited_at or ctx.message.created_at current = dt.replace(tzinfo=datetime.timezone.utc).timestamp() bucket = self._buckets.get_bucket(ctx.message, current) - if bucket is not None: + if bucket is not None: # pyright: ignore[reportUnnecessaryComparison] retry_after = bucket.update_rate_limit(current) if retry_after: raise CommandOnCooldown(bucket, retry_after, self._buckets.type) # type: ignore @@ -1718,7 +1718,7 @@ def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: decorator.predicate = predicate else: - @functools.wraps(predicate) + @functools.wraps(predicate) # type: ignore async def wrapper(ctx): return predicate(ctx) # type: ignore diff --git a/disnake/ext/commands/help.py b/disnake/ext/commands/help.py index 483d4f4bd2..ecd3988b86 100644 --- a/disnake/ext/commands/help.py +++ b/disnake/ext/commands/help.py @@ -368,7 +368,11 @@ def invoked_with(self): """ command_name = self._command_impl.name ctx = self.context - if ctx is None or ctx.command is None or ctx.command.qualified_name != command_name: + if ( + ctx is disnake.utils.MISSING + or ctx.command is None + or ctx.command.qualified_name != command_name + ): return command_name return ctx.invoked_with diff --git a/disnake/ext/commands/params.py b/disnake/ext/commands/params.py index 95679ed802..9114b8b353 100644 --- a/disnake/ext/commands/params.py +++ b/disnake/ext/commands/params.py @@ -85,9 +85,6 @@ if sys.version_info >= (3, 10): from types import EllipsisType, UnionType -elif TYPE_CHECKING: - UnionType = object() - EllipsisType = ellipsis # noqa: F821 else: UnionType = object() EllipsisType = type(Ellipsis) @@ -543,7 +540,7 @@ def __init__( self.max_length = max_length self.large = large - def copy(self) -> ParamInfo: + def copy(self) -> Self: # n. b. this method needs to be manually updated when a new attribute is added. cls = self.__class__ ins = cls.__new__(cls) @@ -1339,7 +1336,7 @@ def option_enum( choices = choices or kwargs first, *_ = choices.values() - return Enum("", choices, type=type(first)) + return Enum("", choices, type=type(first)) # type: ignore class ConverterMethod(classmethod): diff --git a/disnake/ext/commands/slash_core.py b/disnake/ext/commands/slash_core.py index 1b318a21d0..4652c552f8 100644 --- a/disnake/ext/commands/slash_core.py +++ b/disnake/ext/commands/slash_core.py @@ -666,7 +666,7 @@ async def _call_relevant_autocompleter(self, inter: ApplicationCommandInteractio group = self.children.get(chain[0]) if not isinstance(group, SubCommandGroup): raise AssertionError("the first subcommand is not a SubCommandGroup instance") - subcmd = group.children.get(chain[1]) if group is not None else None + subcmd = group.children.get(chain[1]) else: raise ValueError("Command chain is too long") @@ -695,7 +695,7 @@ async def invoke_children(self, inter: ApplicationCommandInteraction) -> None: group = self.children.get(chain[0]) if not isinstance(group, SubCommandGroup): raise AssertionError("the first subcommand is not a SubCommandGroup instance") - subcmd = group.children.get(chain[1]) if group is not None else None + subcmd = group.children.get(chain[1]) else: raise ValueError("Command chain is too long") diff --git a/disnake/ext/tasks/__init__.py b/disnake/ext/tasks/__init__.py index 1c23e0e912..6532c3d088 100644 --- a/disnake/ext/tasks/__init__.py +++ b/disnake/ext/tasks/__init__.py @@ -708,7 +708,7 @@ class Object(Protocol[T_co, P]): def __new__(cls) -> T_co: ... - def __init__(*args: P.args, **kwargs: P.kwargs) -> None: + def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None: ... @@ -734,7 +734,7 @@ def loop( def loop( - cls: Type[Object[L_co, Concatenate[LF, P]]] = Loop[LF], + cls: Type[Object[L_co, Concatenate[LF, P]]] = Loop[Any], **kwargs: Any, ) -> Callable[[LF], L_co]: """A decorator that schedules a task in the background for you with diff --git a/disnake/gateway.py b/disnake/gateway.py index 2081493509..cd0cb6d44a 100644 --- a/disnake/gateway.py +++ b/disnake/gateway.py @@ -274,7 +274,7 @@ async def close(self, *, code: int = 4000, message: bytes = b"") -> bool: class HeartbeatWebSocket(Protocol): - HEARTBEAT: Final[Literal[1, 3]] # type: ignore + HEARTBEAT: Final[Literal[1, 3]] thread_id: int loop: asyncio.AbstractEventLoop diff --git a/disnake/guild.py b/disnake/guild.py index 3927992fb5..ba140f2298 100644 --- a/disnake/guild.py +++ b/disnake/guild.py @@ -3136,10 +3136,6 @@ async def integrations(self) -> List[Integration]: def convert(d): factory, _ = _integration_factory(d["type"]) - if factory is None: - raise InvalidData( - "Unknown integration type {type!r} for integration ID {id}".format_map(d) - ) return factory(guild=self, data=d) return [convert(d) for d in data] diff --git a/disnake/http.py b/disnake/http.py index f8c4b44694..06b3801861 100644 --- a/disnake/http.py +++ b/disnake/http.py @@ -248,19 +248,18 @@ def recreate(self) -> None: ) async def ws_connect(self, url: str, *, compress: int = 0) -> aiohttp.ClientWebSocketResponse: - kwargs = { - "proxy_auth": self.proxy_auth, - "proxy": self.proxy, - "max_msg_size": 0, - "timeout": 30.0, - "autoclose": False, - "headers": { + return await self.__session.ws_connect( + url, + proxy_auth=self.proxy_auth, + proxy=self.proxy, + max_msg_size=0, + timeout=30.0, + autoclose=False, + headers={ "User-Agent": self.user_agent, }, - "compress": compress, - } - - return await self.__session.ws_connect(url, **kwargs) + compress=compress, + ) async def request( self, @@ -276,9 +275,7 @@ async def request( lock = self._locks.get(bucket) if lock is None: - lock = asyncio.Lock() - if bucket is not None: - self._locks[bucket] = lock + self._locks[bucket] = lock = asyncio.Lock() # header creation headers: Dict[str, str] = { diff --git a/disnake/i18n.py b/disnake/i18n.py index 344787ad5b..c2781a9eb8 100644 --- a/disnake/i18n.py +++ b/disnake/i18n.py @@ -409,7 +409,7 @@ def _load_file(self, path: Path) -> None: except Exception as e: raise RuntimeError(f"Unable to load '{path}': {e}") from e - def _load_dict(self, data: Dict[str, str], locale: str) -> None: + def _load_dict(self, data: Dict[str, Optional[str]], locale: str) -> None: if not isinstance(data, dict) or not all( o is None or isinstance(o, str) for o in data.values() ): diff --git a/disnake/interactions/base.py b/disnake/interactions/base.py index bdcbe3cae2..01637be96a 100644 --- a/disnake/interactions/base.py +++ b/disnake/interactions/base.py @@ -1855,7 +1855,7 @@ def __init__( guild and guild.get_channel_or_thread(channel_id) or factory( - guild=guild_fallback, # type: ignore + guild=guild_fallback, state=state, data=channel, # type: ignore ) diff --git a/disnake/iterators.py b/disnake/iterators.py index ea8347effd..f7d694598a 100644 --- a/disnake/iterators.py +++ b/disnake/iterators.py @@ -106,7 +106,7 @@ def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]: def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: return _MappedAsyncIterator(self, func) - def filter(self, predicate: _Func[T, bool]) -> _FilteredAsyncIterator[T]: + def filter(self, predicate: Optional[_Func[T, bool]]) -> _FilteredAsyncIterator[T]: return _FilteredAsyncIterator(self, predicate) async def flatten(self) -> List[T]: @@ -152,11 +152,11 @@ async def next(self) -> OT: class _FilteredAsyncIterator(_AsyncIterator[T]): - def __init__(self, iterator: _AsyncIterator[T], predicate: _Func[T, bool]) -> None: + def __init__(self, iterator: _AsyncIterator[T], predicate: Optional[_Func[T, bool]]) -> None: self.iterator = iterator if predicate is None: - predicate = lambda x: bool(x) + predicate = bool # similar to the `filter` builtin, a `None` filter drops falsy items self.predicate: _Func[T, bool] = predicate @@ -626,8 +626,8 @@ async def _fill(self) -> None: } for element in entries: - # TODO: remove this if statement later - if element["action_type"] is None: + # https://github.com/discord/discord-api-docs/issues/5055#issuecomment-1266363766 + if element["action_type"] is None: # pyright: ignore[reportUnnecessaryComparison] continue await self.entries.put( diff --git a/disnake/message.py b/disnake/message.py index 92aba532c7..e3967e1160 100644 --- a/disnake/message.py +++ b/disnake/message.py @@ -658,13 +658,14 @@ def __repr__(self) -> str: return f"" def to_dict(self) -> MessageReferencePayload: - result: MessageReferencePayload = {"channel_id": self.channel_id} + result: MessageReferencePayload = { + "channel_id": self.channel_id, + "fail_if_not_exists": self.fail_if_not_exists, + } if self.message_id is not None: result["message_id"] = self.message_id if self.guild_id is not None: result["guild_id"] = self.guild_id - if self.fail_if_not_exists is not None: - result["fail_if_not_exists"] = self.fail_if_not_exists return result to_message_reference_dict = to_dict diff --git a/disnake/shard.py b/disnake/shard.py index 102c66e4ae..a82ae13efd 100644 --- a/disnake/shard.py +++ b/disnake/shard.py @@ -589,7 +589,8 @@ async def change_presence( activities = () if activity is None else (activity,) for guild in guilds: me = guild.me - if me is None: + if me is None: # pyright: ignore[reportUnnecessaryComparison] + # may happen if guild is unavailable continue # Member.activities is typehinted as Tuple[ActivityType, ...], we may be setting it as Tuple[BaseActivity, ...] diff --git a/disnake/state.py b/disnake/state.py index ca915aa33f..714a92759b 100644 --- a/disnake/state.py +++ b/disnake/state.py @@ -25,7 +25,6 @@ Tuple, TypeVar, Union, - cast, overload, ) @@ -600,7 +599,6 @@ def _get_guild_channel( if channel is None: if "author" in data: # MessagePayload - data = cast("MessagePayload", data) user_id = int(data["author"]["id"]) else: # TypingStartEvent @@ -637,8 +635,6 @@ async def query_members( ): guild_id = guild.id ws = self._get_websocket(guild_id) - if ws is None: - raise RuntimeError("Somehow do not have a websocket for this guild_id") request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) self._chunk_requests[request.nonce] = request @@ -1796,6 +1792,8 @@ def parse_voice_server_update(self, data: gateway.VoiceServerUpdateEvent) -> Non logging_coroutine(coro, info="Voice Protocol voice server update handler") ) + # FIXME: this should be refactored. The `GroupChannel` path will never be hit, + # `raw.timestamp` exists so no need to parse it twice, and `.get_user` should be used before falling back def parse_typing_start(self, data: gateway.TypingStartEvent) -> None: channel, guild = self._get_guild_channel(data) raw = RawTypingEvent(data) @@ -1810,7 +1808,7 @@ def parse_typing_start(self, data: gateway.TypingStartEvent) -> None: self.dispatch("raw_typing", raw) - if channel is not None: + if channel is not None: # pyright: ignore[reportUnnecessaryComparison] member = None if raw.member is not None: member = raw.member diff --git a/disnake/types/audit_log.py b/disnake/types/audit_log.py index d3b3a5484f..f9640b3ad9 100644 --- a/disnake/types/audit_log.py +++ b/disnake/types/audit_log.py @@ -103,8 +103,8 @@ class _AuditLogChange_Str(TypedDict): "permissions", "tags", ] - new_value: str - old_value: str + new_value: NotRequired[str] + old_value: NotRequired[str] class _AuditLogChange_AssetHash(TypedDict): @@ -116,8 +116,8 @@ class _AuditLogChange_AssetHash(TypedDict): "avatar_hash", "asset", ] - new_value: str - old_value: str + new_value: NotRequired[str] + old_value: NotRequired[str] class _AuditLogChange_Snowflake(TypedDict): @@ -134,8 +134,8 @@ class _AuditLogChange_Snowflake(TypedDict): "inviter_id", "guild_id", ] - new_value: Snowflake - old_value: Snowflake + new_value: NotRequired[Snowflake] + old_value: NotRequired[Snowflake] class _AuditLogChange_Bool(TypedDict): @@ -157,8 +157,8 @@ class _AuditLogChange_Bool(TypedDict): "premium_progress_bar_enabled", "enabled", ] - new_value: bool - old_value: bool + new_value: NotRequired[bool] + old_value: NotRequired[bool] class _AuditLogChange_Int(TypedDict): @@ -175,104 +175,104 @@ class _AuditLogChange_Int(TypedDict): "auto_archive_duration", "default_auto_archive_duration", ] - new_value: int - old_value: int + new_value: NotRequired[int] + old_value: NotRequired[int] class _AuditLogChange_ListSnowflake(TypedDict): key: Literal["exempt_roles", "exempt_channels"] - new_value: List[Snowflake] - old_value: List[Snowflake] + new_value: NotRequired[List[Snowflake]] + old_value: NotRequired[List[Snowflake]] class _AuditLogChange_ListRole(TypedDict): key: Literal["$add", "$remove"] - new_value: List[Role] - old_value: List[Role] + new_value: NotRequired[List[Role]] + old_value: NotRequired[List[Role]] class _AuditLogChange_MFALevel(TypedDict): key: Literal["mfa_level"] - new_value: MFALevel - old_value: MFALevel + new_value: NotRequired[MFALevel] + old_value: NotRequired[MFALevel] class _AuditLogChange_VerificationLevel(TypedDict): key: Literal["verification_level"] - new_value: VerificationLevel - old_value: VerificationLevel + new_value: NotRequired[VerificationLevel] + old_value: NotRequired[VerificationLevel] class _AuditLogChange_ExplicitContentFilter(TypedDict): key: Literal["explicit_content_filter"] - new_value: ExplicitContentFilterLevel - old_value: ExplicitContentFilterLevel + new_value: NotRequired[ExplicitContentFilterLevel] + old_value: NotRequired[ExplicitContentFilterLevel] class _AuditLogChange_DefaultMessageNotificationLevel(TypedDict): key: Literal["default_message_notifications"] - new_value: DefaultMessageNotificationLevel - old_value: DefaultMessageNotificationLevel + new_value: NotRequired[DefaultMessageNotificationLevel] + old_value: NotRequired[DefaultMessageNotificationLevel] class _AuditLogChange_ChannelType(TypedDict): key: Literal["type"] - new_value: ChannelType - old_value: ChannelType + new_value: NotRequired[ChannelType] + old_value: NotRequired[ChannelType] class _AuditLogChange_IntegrationExpireBehaviour(TypedDict): key: Literal["expire_behavior"] - new_value: IntegrationExpireBehavior - old_value: IntegrationExpireBehavior + new_value: NotRequired[IntegrationExpireBehavior] + old_value: NotRequired[IntegrationExpireBehavior] class _AuditLogChange_VideoQualityMode(TypedDict): key: Literal["video_quality_mode"] - new_value: VideoQualityMode - old_value: VideoQualityMode + new_value: NotRequired[VideoQualityMode] + old_value: NotRequired[VideoQualityMode] class _AuditLogChange_Overwrites(TypedDict): key: Literal["permission_overwrites"] - new_value: List[PermissionOverwrite] - old_value: List[PermissionOverwrite] + new_value: NotRequired[List[PermissionOverwrite]] + old_value: NotRequired[List[PermissionOverwrite]] class _AuditLogChange_Datetime(TypedDict): key: Literal["communication_disabled_until"] - new_value: datetime.datetime - old_value: datetime.datetime + new_value: NotRequired[datetime.datetime] + old_value: NotRequired[datetime.datetime] class _AuditLogChange_ApplicationCommandPermissions(TypedDict): key: str - new_value: ApplicationCommandPermissions - old_value: ApplicationCommandPermissions + new_value: NotRequired[ApplicationCommandPermissions] + old_value: NotRequired[ApplicationCommandPermissions] class _AuditLogChange_AutoModTriggerType(TypedDict): key: Literal["trigger_type"] - new_value: AutoModTriggerType - old_value: AutoModTriggerType + new_value: NotRequired[AutoModTriggerType] + old_value: NotRequired[AutoModTriggerType] class _AuditLogChange_AutoModEventType(TypedDict): key: Literal["event_type"] - new_value: AutoModEventType - old_value: AutoModEventType + new_value: NotRequired[AutoModEventType] + old_value: NotRequired[AutoModEventType] class _AuditLogChange_AutoModActions(TypedDict): key: Literal["actions"] - new_value: List[AutoModAction] - old_value: List[AutoModAction] + new_value: NotRequired[List[AutoModAction]] + old_value: NotRequired[List[AutoModAction]] class _AuditLogChange_AutoModTriggerMetadata(TypedDict): key: Literal["trigger_metadata"] - new_value: AutoModTriggerMetadata - old_value: AutoModTriggerMetadata + new_value: NotRequired[AutoModTriggerMetadata] + old_value: NotRequired[AutoModTriggerMetadata] AuditLogChange = Union[ diff --git a/disnake/types/automod.py b/disnake/types/automod.py index 156952d092..f7ac372e5e 100644 --- a/disnake/types/automod.py +++ b/disnake/types/automod.py @@ -8,9 +8,9 @@ from .snowflake import Snowflake, SnowflakeList -AutoModTriggerType = Literal[1, 2, 3, 4, 5] +AutoModTriggerType = Literal[1, 3, 4, 5] AutoModEventType = Literal[1] -AutoModActionType = Literal[1, 2] +AutoModActionType = Literal[1, 2, 3] AutoModPresetType = Literal[1, 2, 3] diff --git a/disnake/types/template.py b/disnake/types/template.py index ddb2c26cb7..e0008659aa 100644 --- a/disnake/types/template.py +++ b/disnake/types/template.py @@ -20,7 +20,7 @@ class Template(TypedDict): description: Optional[str] usage_count: int creator_id: Snowflake - creator: User + creator: Optional[User] # unsure when this can be null, but the spec says so created_at: str updated_at: str source_guild_id: Snowflake diff --git a/disnake/ui/action_row.py b/disnake/ui/action_row.py index fe7244a776..21ea01cb74 100644 --- a/disnake/ui/action_row.py +++ b/disnake/ui/action_row.py @@ -159,7 +159,8 @@ def __init__(self: ActionRow[ModalUIComponent], *components: ModalUIComponent) - def __init__(self: ActionRow[StrictUIComponentT], *components: StrictUIComponentT) -> None: ... - def __init__(self, *components: UIComponentT) -> None: + # n.b. this should be `*components: UIComponentT`, but pyright does not like it + def __init__(self, *components: Union[MessageUIComponent, ModalUIComponent]) -> None: self._children: List[UIComponentT] = [] for component in components: @@ -167,7 +168,7 @@ def __init__(self, *components: UIComponentT) -> None: raise TypeError( f"components should be of type WrappedComponent, got {type(component).__name__}." ) - self.append_item(component) + self.append_item(component) # type: ignore def __repr__(self) -> str: return f"" diff --git a/disnake/ui/button.py b/disnake/ui/button.py index d5e1fc7708..a961ba29ab 100644 --- a/disnake/ui/button.py +++ b/disnake/ui/button.py @@ -275,7 +275,7 @@ def button( def button( - cls: Type[Object[B_co, P]] = Button[Any], **kwargs: Any + cls: Type[Object[B_co, ...]] = Button[Any], **kwargs: Any ) -> Callable[[ItemCallbackType[B_co]], DecoratedItem[B_co]]: """A decorator that attaches a button to a component. diff --git a/disnake/ui/item.py b/disnake/ui/item.py index 971ca8dcb3..464eb4d588 100644 --- a/disnake/ui/item.py +++ b/disnake/ui/item.py @@ -184,5 +184,5 @@ class Object(Protocol[T_co, P]): def __new__(cls) -> T_co: ... - def __init__(*args: P.args, **kwargs: P.kwargs) -> None: + def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None: ... diff --git a/disnake/ui/modal.py b/disnake/ui/modal.py index a7a5503a28..adf21ffa9c 100644 --- a/disnake/ui/modal.py +++ b/disnake/ui/modal.py @@ -55,7 +55,7 @@ def __init__( custom_id: str = MISSING, timeout: float = 600, ) -> None: - if timeout is None: + if timeout is None: # pyright: ignore[reportUnnecessaryComparison] raise ValueError("Timeout may not be None") rows = components_to_rows(components) diff --git a/disnake/ui/select/channel.py b/disnake/ui/select/channel.py index a98472b547..57dd9cfbe9 100644 --- a/disnake/ui/select/channel.py +++ b/disnake/ui/select/channel.py @@ -168,7 +168,7 @@ def channel_select( def channel_select( - cls: Type[Object[S_co, P]] = ChannelSelect[Any], **kwargs: Any + cls: Type[Object[S_co, ...]] = ChannelSelect[Any], **kwargs: Any ) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: """A decorator that attaches a channel select menu to a component. diff --git a/disnake/ui/select/mentionable.py b/disnake/ui/select/mentionable.py index 4f0d591201..860903f7f1 100644 --- a/disnake/ui/select/mentionable.py +++ b/disnake/ui/select/mentionable.py @@ -144,7 +144,7 @@ def mentionable_select( def mentionable_select( - cls: Type[Object[S_co, P]] = MentionableSelect[Any], **kwargs: Any + cls: Type[Object[S_co, ...]] = MentionableSelect[Any], **kwargs: Any ) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: """A decorator that attaches a mentionable (user/member/role) select menu to a component. diff --git a/disnake/ui/select/role.py b/disnake/ui/select/role.py index 69b1bcaa57..fe2da2f97a 100644 --- a/disnake/ui/select/role.py +++ b/disnake/ui/select/role.py @@ -142,7 +142,7 @@ def role_select( def role_select( - cls: Type[Object[S_co, P]] = RoleSelect[Any], **kwargs: Any + cls: Type[Object[S_co, ...]] = RoleSelect[Any], **kwargs: Any ) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: """A decorator that attaches a role select menu to a component. diff --git a/disnake/ui/select/string.py b/disnake/ui/select/string.py index d38c9ea6ba..3eeedc1f22 100644 --- a/disnake/ui/select/string.py +++ b/disnake/ui/select/string.py @@ -268,7 +268,7 @@ def string_select( def string_select( - cls: Type[Object[S_co, P]] = StringSelect[Any], **kwargs: Any + cls: Type[Object[S_co, ...]] = StringSelect[Any], **kwargs: Any ) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: """A decorator that attaches a string select menu to a component. diff --git a/disnake/ui/select/user.py b/disnake/ui/select/user.py index 179b9d6c74..4868894a83 100644 --- a/disnake/ui/select/user.py +++ b/disnake/ui/select/user.py @@ -143,7 +143,7 @@ def user_select( def user_select( - cls: Type[Object[S_co, P]] = UserSelect[Any], **kwargs: Any + cls: Type[Object[S_co, ...]] = UserSelect[Any], **kwargs: Any ) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: """A decorator that attaches a user select menu to a component. diff --git a/disnake/utils.py b/disnake/utils.py index a74d50ab94..6fa6ae82d7 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -134,6 +134,7 @@ class _RequestLike(Protocol): V = TypeVar("V") T_co = TypeVar("T_co", covariant=True) _Iter = Union[Iterator[T], AsyncIterator[T]] +_BytesLike = Union[bytes, bytearray, memoryview] class CachedSlotProperty(Generic[T, T_co]): @@ -489,7 +490,7 @@ def _maybe_cast(value: V, converter: Callable[[V], T], default: T = None) -> Opt } -def _get_mime_type_for_image(data: bytes) -> str: +def _get_mime_type_for_image(data: _BytesLike) -> str: if data[0:8] == b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A": return "image/png" elif data[0:3] == b"\xff\xd8\xff" or data[6:10] in (b"JFIF", b"Exif"): @@ -502,14 +503,14 @@ def _get_mime_type_for_image(data: bytes) -> str: raise ValueError("Unsupported image type given") -def _bytes_to_base64_data(data: bytes) -> str: +def _bytes_to_base64_data(data: _BytesLike) -> str: fmt = "data:{mime};base64,{data}" mime = _get_mime_type_for_image(data) b64 = b64encode(data).decode("ascii") return fmt.format(mime=mime, data=b64) -def _get_extension_for_image(data: bytes) -> Optional[str]: +def _get_extension_for_image(data: _BytesLike) -> Optional[str]: try: mime_type = _get_mime_type_for_image(data) except ValueError: @@ -538,7 +539,7 @@ async def _assetbytes_to_base64_data(data: Optional[AssetBytes]) -> Optional[str if HAS_ORJSON: def _to_json(obj: Any) -> str: - return orjson.dumps(obj).decode("utf-8") + return orjson.dumps(obj).decode("utf-8") # type: ignore _from_json = orjson.loads # type: ignore @@ -571,7 +572,8 @@ async def maybe_coroutine( return value # type: ignore # typeguard doesn't narrow in the negative case -async def async_all(gen: Iterable[Union[Awaitable[bool], bool]], *, check=_isawaitable) -> bool: +async def async_all(gen: Iterable[Union[Awaitable[bool], bool]]) -> bool: + check = _isawaitable for elem in gen: if check(elem): elem = await elem diff --git a/disnake/voice_client.py b/disnake/voice_client.py index 52750ecebd..a6cc13e0ba 100644 --- a/disnake/voice_client.py +++ b/disnake/voice_client.py @@ -279,7 +279,7 @@ async def on_voice_server_update(self, data: VoiceServerUpdateEvent) -> None: self.server_id = int(data["guild_id"]) endpoint = data.get("endpoint") - if endpoint is None or self.token is None: + if endpoint is None or not self.token: _log.warning( "Awaiting endpoint... This requires waiting. " "If timeout occurred considering raising the timeout and reconnecting." diff --git a/docs/extensions/builder.py b/docs/extensions/builder.py index 5133af0f85..61e366d2ca 100644 --- a/docs/extensions/builder.py +++ b/docs/extensions/builder.py @@ -65,7 +65,7 @@ def disable_mathjax(app: Sphinx, config: Config) -> None: # inspired by https://github.com/readthedocs/sphinx-hoverxref/blob/003b84fee48262f1a969c8143e63c177bd98aa26/hoverxref/extension.py#L151 for listener in app.events.listeners.get("html-page-context", []): - module_name = inspect.getmodule(listener.handler).__name__ # type: ignore + module_name = inspect.getmodule(listener.handler).__name__ if module_name == "sphinx.ext.mathjax": app.disconnect(listener.id) diff --git a/examples/basic_voice.py b/examples/basic_voice.py index 6d224b21e5..45046c780f 100644 --- a/examples/basic_voice.py +++ b/examples/basic_voice.py @@ -33,8 +33,6 @@ "source_address": "0.0.0.0", # bind to ipv4 since ipv6 addresses cause issues sometimes } -ffmpeg_options = {"options": "-vn"} - ytdl = youtube_dl.YoutubeDL(ytdl_format_options) @@ -59,7 +57,7 @@ async def from_url( filename = data["url"] if stream else ytdl.prepare_filename(data) - return cls(disnake.FFmpegPCMAudio(filename, **ffmpeg_options), data=data) + return cls(disnake.FFmpegPCMAudio(filename, options="-vn"), data=data) class Music(commands.Cog): diff --git a/examples/interactions/injections.py b/examples/interactions/injections.py index 27576d60bc..30c7554dd6 100644 --- a/examples/interactions/injections.py +++ b/examples/interactions/injections.py @@ -114,7 +114,7 @@ async def get_game_user( if user is None: return await db.get_game_user(id=inter.author.id) - game_user: GameUser = await db.search_game_user(username=user, server=server) + game_user: Optional[GameUser] = await db.search_game_user(username=user, server=server) if game_user is None: raise commands.CommandError(f"User with username {user!r} could not be found") diff --git a/examples/interactions/modal.py b/examples/interactions/modal.py index 00b2364789..f271c82f4c 100644 --- a/examples/interactions/modal.py +++ b/examples/interactions/modal.py @@ -2,6 +2,8 @@ """An example demonstrating two methods of sending modals and handling modal responses.""" +# pyright: reportUnknownLambdaType=false + import asyncio import os diff --git a/pyproject.toml b/pyproject.toml index 4756d55c4a..73f3ad1b9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,7 @@ codemod = [ ] typing = [ # this is not pyright itself, but the python wrapper - "pyright==1.1.291", + "pyright==1.1.336", "typing-extensions~=4.8.0", # only used for type-checking, version does not matter "pytz", diff --git a/test_bot/cogs/modals.py b/test_bot/cogs/modals.py index 13c84bddf2..c5d514a25c 100644 --- a/test_bot/cogs/modals.py +++ b/test_bot/cogs/modals.py @@ -65,7 +65,7 @@ async def create_tag_low(self, inter: disnake.AppCmdInter[commands.Bot]) -> None modal_inter: disnake.ModalInteraction = await self.bot.wait_for( "modal_submit", - check=lambda i: i.custom_id == "create_tag2" and i.author.id == inter.author.id, + check=lambda i: i.custom_id == "create_tag2" and i.author.id == inter.author.id, # type: ignore # unknown parameter type ) embed = disnake.Embed(title="Tag Creation") diff --git a/tests/ui/test_decorators.py b/tests/ui/test_decorators.py index 5fab1bb787..e9c3680873 100644 --- a/tests/ui/test_decorators.py +++ b/tests/ui/test_decorators.py @@ -30,16 +30,16 @@ def __init__(self, *, param: float = 42.0) -> None: class TestDecorator: def test_default(self) -> None: - with create_callback(ui.Button) as func: + with create_callback(ui.Button[ui.View]) as func: res = ui.button(custom_id="123")(func) - assert_type(res, ui.item.DecoratedItem[ui.Button]) + assert_type(res, ui.item.DecoratedItem[ui.Button[ui.View]]) assert func.__discord_ui_model_type__ is ui.Button assert func.__discord_ui_model_kwargs__ == {"custom_id": "123"} - with create_callback(ui.StringSelect) as func: + with create_callback(ui.StringSelect[ui.View]) as func: res = ui.string_select(custom_id="123")(func) - assert_type(res, ui.item.DecoratedItem[ui.StringSelect]) + assert_type(res, ui.item.DecoratedItem[ui.StringSelect[ui.View]]) assert func.__discord_ui_model_type__ is ui.StringSelect assert func.__discord_ui_model_kwargs__ == {"custom_id": "123"} From cd48c9259d7286cec6163e50364d99a14e75a757 Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Fri, 17 Nov 2023 20:39:17 +0100 Subject: [PATCH 3/7] feat: add `TeamMember.role` (#1094) --- changelog/1094.doc.rst | 1 + changelog/1094.feature.0.rst | 1 + changelog/1094.feature.1.rst | 1 + disnake/enums.py | 10 ++++++++++ disnake/ext/commands/bot.py | 4 ++-- disnake/ext/commands/common_bot_base.py | 13 +++++++++++-- disnake/team.py | 23 +++++++++++------------ disnake/types/team.py | 5 +++-- docs/api/app_info.rst | 22 ++++++++++++++++++++++ 9 files changed, 62 insertions(+), 18 deletions(-) create mode 100644 changelog/1094.doc.rst create mode 100644 changelog/1094.feature.0.rst create mode 100644 changelog/1094.feature.1.rst diff --git a/changelog/1094.doc.rst b/changelog/1094.doc.rst new file mode 100644 index 0000000000..13fe750a4b --- /dev/null +++ b/changelog/1094.doc.rst @@ -0,0 +1 @@ +Add inherited attributes to :class:`TeamMember`, and fix :attr:`TeamMember.avatar` documentation. diff --git a/changelog/1094.feature.0.rst b/changelog/1094.feature.0.rst new file mode 100644 index 0000000000..dcb2dcf367 --- /dev/null +++ b/changelog/1094.feature.0.rst @@ -0,0 +1 @@ +Add :attr:`TeamMember.role`. diff --git a/changelog/1094.feature.1.rst b/changelog/1094.feature.1.rst new file mode 100644 index 0000000000..1b7eb2cd33 --- /dev/null +++ b/changelog/1094.feature.1.rst @@ -0,0 +1 @@ +|commands| Update :meth:`Bot.is_owner ` to take team member roles into account. diff --git a/disnake/enums.py b/disnake/enums.py index cb603c5425..912cb36183 100644 --- a/disnake/enums.py +++ b/disnake/enums.py @@ -35,6 +35,7 @@ "ActivityType", "NotificationLevel", "TeamMembershipState", + "TeamMemberRole", "WebhookType", "ExpireBehaviour", "ExpireBehavior", @@ -551,6 +552,15 @@ class TeamMembershipState(Enum): accepted = 2 +class TeamMemberRole(Enum): + admin = "admin" + developer = "developer" + read_only = "read_only" + + def __str__(self) -> str: + return self.name + + class WebhookType(Enum): incoming = 1 channel_follower = 2 diff --git a/disnake/ext/commands/bot.py b/disnake/ext/commands/bot.py index 5c3ba59eac..825f96e6ae 100644 --- a/disnake/ext/commands/bot.py +++ b/disnake/ext/commands/bot.py @@ -184,7 +184,7 @@ class Bot(BotBase, InteractionBotBase, disnake.Client): owner_ids: Optional[Collection[:class:`int`]] The IDs of the users that own the bot. This is similar to :attr:`owner_id`. If this is not set and the application is team based, then it is - fetched automatically using :meth:`~.Bot.application_info`. + fetched automatically using :meth:`~.Bot.application_info` (taking team roles into account). For performance reasons it is recommended to use a :class:`set` for the collection. You cannot set both ``owner_id`` and ``owner_ids``. @@ -403,7 +403,7 @@ class InteractionBot(InteractionBotBase, disnake.Client): owner_ids: Optional[Collection[:class:`int`]] The IDs of the users that own the bot. This is similar to :attr:`owner_id`. If this is not set and the application is team based, then it is - fetched automatically using :meth:`~.Bot.application_info`. + fetched automatically using :meth:`~.Bot.application_info` (taking team roles into account). For performance reasons it is recommended to use a :class:`set` for the collection. You cannot set both ``owner_id`` and ``owner_ids``. diff --git a/disnake/ext/commands/common_bot_base.py b/disnake/ext/commands/common_bot_base.py index f0d8fd5566..737736c170 100644 --- a/disnake/ext/commands/common_bot_base.py +++ b/disnake/ext/commands/common_bot_base.py @@ -81,8 +81,13 @@ async def _fill_owners(self) -> None: app: disnake.AppInfo = await self.application_info() # type: ignore if app.team: - self.owners = set(app.team.members) - self.owner_ids = {m.id for m in app.team.members} + self.owners = owners = { + member + for member in app.team.members + # these roles can access the bot token, consider them bot owners + if member.role in (disnake.TeamMemberRole.admin, disnake.TeamMemberRole.developer) + } + self.owner_ids = {m.id for m in owners} else: self.owner = app.owner self.owner_id = app.owner.id @@ -130,6 +135,10 @@ async def is_owner(self, user: Union[disnake.User, disnake.Member]) -> bool: The function also checks if the application is team-owned if :attr:`owner_ids` is not set. + .. versionchanged:: 2.10 + Also takes team roles into account; only team members with the :attr:`~disnake.TeamMemberRole.admin` + or :attr:`~disnake.TeamMemberRole.developer` roles are considered bot owners. + Parameters ---------- user: :class:`.abc.User` diff --git a/disnake/team.py b/disnake/team.py index dd0ee48d76..a1f126304e 100644 --- a/disnake/team.py +++ b/disnake/team.py @@ -7,7 +7,7 @@ from . import utils from .asset import Asset -from .enums import TeamMembershipState, try_enum +from .enums import TeamMemberRole, TeamMembershipState, try_enum from .user import BaseUser if TYPE_CHECKING: @@ -21,7 +21,8 @@ class Team: - """Represents an application team for a bot provided by Discord. + """Represents an application team. + Teams are groups of users who share access to an application's configuration. Attributes ---------- @@ -30,7 +31,7 @@ class Team: name: :class:`str` The team name. owner_id: :class:`int` - The team's owner ID. + The team owner's ID. members: List[:class:`TeamMember`] A list of the members in the team. @@ -44,7 +45,7 @@ def __init__(self, state: ConnectionState, data: TeamPayload) -> None: self.id: int = int(data["id"]) self.name: str = data["name"] - self._icon: Optional[str] = data["icon"] + self._icon: Optional[str] = data.get("icon") self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_user_id") self.members: List[TeamMember] = [ TeamMember(self, self._state, member) for member in data["members"] @@ -113,29 +114,27 @@ class TeamMember(BaseUser): See the `help article `__ for details. global_name: Optional[:class:`str`] - The team members's global display name, if set. + The team member's global display name, if set. This takes precedence over :attr:`.name` when shown. .. versionadded:: 2.9 - avatar: Optional[:class:`str`] - The avatar hash the team member has. Could be None. - bot: :class:`bool` - Specifies if the user is a bot account. team: :class:`Team` The team that the member is from. membership_state: :class:`TeamMembershipState` - The membership state of the member (e.g. invited or accepted) + The membership state of the member (e.g. invited or accepted). + role: :class:`TeamMemberRole` + The role of the team member in the team. """ - __slots__ = ("team", "membership_state", "permissions") + __slots__ = ("team", "membership_state", "role") def __init__(self, team: Team, state: ConnectionState, data: TeamMemberPayload) -> None: self.team: Team = team self.membership_state: TeamMembershipState = try_enum( TeamMembershipState, data["membership_state"] ) - self.permissions: List[str] = data["permissions"] + self.role: TeamMemberRole = try_enum(TeamMemberRole, data.get("role")) super().__init__(state=state, data=data["user"]) def __repr__(self) -> str: diff --git a/disnake/types/team.py b/disnake/types/team.py index 5662365e03..0829c18b5c 100644 --- a/disnake/types/team.py +++ b/disnake/types/team.py @@ -8,13 +8,14 @@ from .user import PartialUser TeamMembershipState = Literal[1, 2] +TeamMemberRole = Literal["admin", "developer", "read_only"] class TeamMember(TypedDict): - user: PartialUser membership_state: TeamMembershipState - permissions: List[str] team_id: Snowflake + user: PartialUser + role: TeamMemberRole class Team(TypedDict): diff --git a/docs/api/app_info.rst b/docs/api/app_info.rst index eed2a74143..7aa6c4148f 100644 --- a/docs/api/app_info.rst +++ b/docs/api/app_info.rst @@ -49,6 +49,7 @@ TeamMember .. autoclass:: TeamMember() :members: + :inherited-members: Data Classes ------------ @@ -89,6 +90,27 @@ TeamMembershipState Represents a member currently in the team. +TeamMemberRole +~~~~~~~~~~~~~~ + +.. class:: TeamMemberRole + + Represents the role of a team member retrieved through :func:`Client.application_info`. + + .. versionadded:: 2.10 + + .. attribute:: admin + + Admins have the most permissions. An admin can only take destructive actions on the team or team-owned apps if they are the team owner. + + .. attribute:: developer + + Developers can access information about a team and team-owned applications, and take limited actions on them, like configuring interaction endpoints or resetting the bot token. + + .. attribute:: read_only + + Read-only members can access information about a team and team-owned applications. + ApplicationRoleConnectionMetadataType ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From fb73dbbc4731eb19d653e71fe48f300556f6cd96 Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Wed, 22 Nov 2023 14:30:58 +0100 Subject: [PATCH 4/7] feat(permissions): add `create_guild_expressions` and `create_events` permissions (#1091) --- changelog/1091.feature.rst | 1 + disnake/abc.py | 2 ++ disnake/channel.py | 1 + disnake/emoji.py | 4 +-- disnake/ext/commands/base_core.py | 2 ++ disnake/ext/commands/core.py | 8 +++++ disnake/guild.py | 6 ++-- disnake/guild_scheduled_event.py | 10 +++--- disnake/permissions.py | 56 +++++++++++++++++++++++++++++-- disnake/sticker.py | 4 +-- 10 files changed, 79 insertions(+), 15 deletions(-) create mode 100644 changelog/1091.feature.rst diff --git a/changelog/1091.feature.rst b/changelog/1091.feature.rst new file mode 100644 index 0000000000..6ae43e7e5f --- /dev/null +++ b/changelog/1091.feature.rst @@ -0,0 +1 @@ +Add :attr:`Permissions.create_guild_expressions` and :attr:`Permissions.create_events`. diff --git a/disnake/abc.py b/disnake/abc.py index b9a60f3ee5..a6f84163ac 100644 --- a/disnake/abc.py +++ b/disnake/abc.py @@ -853,7 +853,9 @@ async def set_permissions( ban_members: Optional[bool] = ..., change_nickname: Optional[bool] = ..., connect: Optional[bool] = ..., + create_events: Optional[bool] = ..., create_forum_threads: Optional[bool] = ..., + create_guild_expressions: Optional[bool] = ..., create_instant_invite: Optional[bool] = ..., create_private_threads: Optional[bool] = ..., create_public_threads: Optional[bool] = ..., diff --git a/disnake/channel.py b/disnake/channel.py index ffb11f2d2c..00598927ca 100644 --- a/disnake/channel.py +++ b/disnake/channel.py @@ -1207,6 +1207,7 @@ def permissions_for( denied.update( manage_channels=True, manage_roles=True, + create_events=True, manage_events=True, manage_webhooks=True, ) diff --git a/disnake/emoji.py b/disnake/emoji.py index 2a24877b07..badedbce86 100644 --- a/disnake/emoji.py +++ b/disnake/emoji.py @@ -192,7 +192,7 @@ async def delete(self, *, reason: Optional[str] = None) -> None: Raises ------ Forbidden - You are not allowed to delete emojis. + You are not allowed to delete this emoji. HTTPException An error occurred deleting the emoji. """ @@ -227,7 +227,7 @@ async def edit( Raises ------ Forbidden - You are not allowed to edit emojis. + You are not allowed to edit this emoji. HTTPException An error occurred editing the emoji. diff --git a/disnake/ext/commands/base_core.py b/disnake/ext/commands/base_core.py index 3599ea0908..b5e0498399 100644 --- a/disnake/ext/commands/base_core.py +++ b/disnake/ext/commands/base_core.py @@ -636,7 +636,9 @@ def default_member_permissions( ban_members: bool = ..., change_nickname: bool = ..., connect: bool = ..., + create_events: bool = ..., create_forum_threads: bool = ..., + create_guild_expressions: bool = ..., create_instant_invite: bool = ..., create_private_threads: bool = ..., create_public_threads: bool = ..., diff --git a/disnake/ext/commands/core.py b/disnake/ext/commands/core.py index 2ddcb10075..b9d49ab269 100644 --- a/disnake/ext/commands/core.py +++ b/disnake/ext/commands/core.py @@ -1999,7 +1999,9 @@ def has_permissions( ban_members: bool = ..., change_nickname: bool = ..., connect: bool = ..., + create_events: bool = ..., create_forum_threads: bool = ..., + create_guild_expressions: bool = ..., create_instant_invite: bool = ..., create_private_threads: bool = ..., create_public_threads: bool = ..., @@ -2121,7 +2123,9 @@ def bot_has_permissions( ban_members: bool = ..., change_nickname: bool = ..., connect: bool = ..., + create_events: bool = ..., create_forum_threads: bool = ..., + create_guild_expressions: bool = ..., create_instant_invite: bool = ..., create_private_threads: bool = ..., create_public_threads: bool = ..., @@ -2221,7 +2225,9 @@ def has_guild_permissions( ban_members: bool = ..., change_nickname: bool = ..., connect: bool = ..., + create_events: bool = ..., create_forum_threads: bool = ..., + create_guild_expressions: bool = ..., create_instant_invite: bool = ..., create_private_threads: bool = ..., create_public_threads: bool = ..., @@ -2318,7 +2324,9 @@ def bot_has_guild_permissions( ban_members: bool = ..., change_nickname: bool = ..., connect: bool = ..., + create_events: bool = ..., create_forum_threads: bool = ..., + create_guild_expressions: bool = ..., create_instant_invite: bool = ..., create_private_threads: bool = ..., create_public_threads: bool = ..., diff --git a/disnake/guild.py b/disnake/guild.py index ba140f2298..449f303c27 100644 --- a/disnake/guild.py +++ b/disnake/guild.py @@ -2387,7 +2387,7 @@ async def create_scheduled_event( Creates a :class:`GuildScheduledEvent`. - You must have :attr:`.Permissions.manage_events` permission to do this. + You must have :attr:`~Permissions.manage_events` permission to do this. Based on the channel/entity type, there are different restrictions regarding other parameter values, as shown in this table: @@ -3274,7 +3274,7 @@ async def delete_sticker(self, sticker: Snowflake, *, reason: Optional[str] = No Raises ------ Forbidden - You are not allowed to delete stickers. + You are not allowed to delete this sticker. HTTPException An error occurred deleting the sticker. """ @@ -3429,7 +3429,7 @@ async def delete_emoji(self, emoji: Snowflake, *, reason: Optional[str] = None) Raises ------ Forbidden - You are not allowed to delete emojis. + You are not allowed to delete this emoji. HTTPException An error occurred deleting the emoji. """ diff --git a/disnake/guild_scheduled_event.py b/disnake/guild_scheduled_event.py index 63b23620fe..1b01be136c 100644 --- a/disnake/guild_scheduled_event.py +++ b/disnake/guild_scheduled_event.py @@ -253,7 +253,7 @@ async def delete(self) -> None: Deletes the guild scheduled event. - You must have :attr:`.Permissions.manage_events` permission to do this. + You must have :attr:`~Permissions.manage_events` permission to do this. Raises ------ @@ -382,7 +382,7 @@ async def edit( Edits the guild scheduled event. - You must have :attr:`.Permissions.manage_events` permission to do this. + You must have :attr:`~Permissions.manage_events` permission to do this. .. versionchanged:: 2.6 Updates must follow requirements of :func:`Guild.create_scheduled_event` @@ -536,7 +536,7 @@ async def start(self, *, reason: Optional[str] = None) -> GuildScheduledEvent: Changes the event status to :attr:`~GuildScheduledEventStatus.active`. - You must have :attr:`.Permissions.manage_events` permission to do this. + You must have :attr:`~Permissions.manage_events` permission to do this. .. versionadded:: 2.7 @@ -570,7 +570,7 @@ async def end(self, *, reason: Optional[str] = None) -> GuildScheduledEvent: Changes the event status to :attr:`~GuildScheduledEventStatus.completed`. - You must have :attr:`.Permissions.manage_events` permission to do this. + You must have :attr:`~Permissions.manage_events` permission to do this. .. versionadded:: 2.7 @@ -604,7 +604,7 @@ async def cancel(self, *, reason: Optional[str] = None) -> GuildScheduledEvent: Changes the event status to :attr:`~GuildScheduledEventStatus.cancelled`. - You must have :attr:`.Permissions.manage_events` permission to do this. + You must have :attr:`~Permissions.manage_events` permission to do this. .. versionadded:: 2.7 diff --git a/disnake/permissions.py b/disnake/permissions.py index a7df815caa..edad50e84d 100644 --- a/disnake/permissions.py +++ b/disnake/permissions.py @@ -164,7 +164,9 @@ def __init__( ban_members: bool = ..., change_nickname: bool = ..., connect: bool = ..., + create_events: bool = ..., create_forum_threads: bool = ..., + create_guild_expressions: bool = ..., create_instant_invite: bool = ..., create_private_threads: bool = ..., create_public_threads: bool = ..., @@ -291,6 +293,7 @@ def all_channel(cls) -> Self: ``True`` and the guild-specific ones set to ``False``. The guild-specific permissions are currently: + - :attr:`create_guild_expressions` - :attr:`manage_guild_expressions` - :attr:`view_audit_log` - :attr:`view_guild_insights` @@ -316,12 +319,16 @@ def all_channel(cls) -> Self: .. versionchanged:: 2.9 Added :attr:`use_soundboard` and :attr:`send_voice_messages` permissions. + + .. versionchanged:: 2.10 + Added :attr:`create_events` permission. """ instance = cls.all() instance.update( administrator=False, ban_members=False, change_nickname=False, + create_guild_expressions=False, kick_members=False, manage_guild=False, manage_guild_expressions=False, @@ -347,11 +354,15 @@ def general(cls) -> Self: .. versionchanged:: 2.9 Added :attr:`view_creator_monetization_analytics` permission. + + .. versionchanged:: 2.10 + Added :attr:`create_guild_expressions` permission. """ return cls( view_channel=True, manage_channels=True, manage_roles=True, + create_guild_expressions=True, manage_guild_expressions=True, view_audit_log=True, view_guild_insights=True, @@ -475,8 +486,12 @@ def events(cls) -> Self: "Events" permissions from the official Discord UI set to ``True``. .. versionadded:: 2.4 + + .. versionchanged:: 2.10 + Added :attr:`create_events` permission. """ return cls( + create_events=True, manage_events=True, ) @@ -532,7 +547,9 @@ def update( ban_members: bool = ..., change_nickname: bool = ..., connect: bool = ..., + create_events: bool = ..., create_forum_threads: bool = ..., + create_guild_expressions: bool = ..., create_instant_invite: bool = ..., create_private_threads: bool = ..., create_public_threads: bool = ..., @@ -830,8 +847,10 @@ def manage_webhooks(self) -> int: @flag_value def manage_guild_expressions(self) -> int: - """:class:`bool`: Returns ``True`` if a user can create, edit, or delete - emojis, stickers, and soundboard sounds. + """:class:`bool`: Returns ``True`` if a user can edit or delete + emojis, stickers, and soundboard sounds created by all users. + + See also :attr:`create_guild_expressions`. .. versionadded:: 2.9 """ @@ -879,7 +898,10 @@ def request_to_speak(self) -> int: @flag_value def manage_events(self) -> int: - """:class:`bool`: Returns ``True`` if a user can manage guild events. + """:class:`bool`: Returns ``True`` if a user can edit or delete guild scheduled events + created by all users. + + See also :attr:`create_events`. .. versionadded:: 2.0 """ @@ -978,6 +1000,28 @@ def use_soundboard(self) -> int: """ return 1 << 42 + @flag_value + def create_guild_expressions(self) -> int: + """:class:`bool`: Returns ``True`` if a user can create emojis, stickers, + and soundboard sounds, as well as edit and delete the ones they created. + + See also :attr:`manage_guild_expressions`. + + .. versionadded:: 2.10 + """ + return 1 << 43 + + @flag_value + def create_events(self) -> int: + """:class:`bool`: Returns ``True`` if a user can create guild scheduled events, + as well as edit and delete the ones they created. + + See also :attr:`manage_events`. + + .. versionadded:: 2.10 + """ + return 1 << 44 + @flag_value def use_external_sounds(self) -> int: """:class:`bool`: Returns ``True`` if a user can use custom soundboard sounds from other guilds. @@ -1066,7 +1110,9 @@ class PermissionOverwrite: ban_members: Optional[bool] change_nickname: Optional[bool] connect: Optional[bool] + create_events: Optional[bool] create_forum_threads: Optional[bool] + create_guild_expressions: Optional[bool] create_instant_invite: Optional[bool] create_private_threads: Optional[bool] create_public_threads: Optional[bool] @@ -1130,7 +1176,9 @@ def __init__( ban_members: Optional[bool] = ..., change_nickname: Optional[bool] = ..., connect: Optional[bool] = ..., + create_events: Optional[bool] = ..., create_forum_threads: Optional[bool] = ..., + create_guild_expressions: Optional[bool] = ..., create_instant_invite: Optional[bool] = ..., create_private_threads: Optional[bool] = ..., create_public_threads: Optional[bool] = ..., @@ -1261,7 +1309,9 @@ def update( ban_members: Optional[bool] = ..., change_nickname: Optional[bool] = ..., connect: Optional[bool] = ..., + create_events: Optional[bool] = ..., create_forum_threads: Optional[bool] = ..., + create_guild_expressions: Optional[bool] = ..., create_instant_invite: Optional[bool] = ..., create_private_threads: Optional[bool] = ..., create_public_threads: Optional[bool] = ..., diff --git a/disnake/sticker.py b/disnake/sticker.py index 0d94c1ebcc..be7479cf2b 100644 --- a/disnake/sticker.py +++ b/disnake/sticker.py @@ -450,7 +450,7 @@ async def edit( Raises ------ Forbidden - You are not allowed to edit stickers. + You are not allowed to edit this sticker. HTTPException An error occurred editing the sticker. @@ -498,7 +498,7 @@ async def delete(self, *, reason: Optional[str] = None) -> None: Raises ------ Forbidden - You are not allowed to delete stickers. + You are not allowed to delete this sticker. HTTPException An error occurred deleting the sticker. """ From cc5db415a8d1988b482acc4e8808489333f727b0 Mon Sep 17 00:00:00 2001 From: Snipy7374 <100313469+Snipy7374@users.noreply.github.com> Date: Wed, 22 Nov 2023 14:52:39 +0100 Subject: [PATCH 5/7] docs: add docstrings for some properties of `InvokableSlashCommand` and `SubCommand` and document `Option`'s attributes (#1112) --- changelog/1112.doc.rst | 1 + disnake/app_commands.py | 34 ++++++++++++++++++++++++++++++ disnake/ext/commands/slash_core.py | 4 ++++ docs/api/app_commands.rst | 2 +- 4 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 changelog/1112.doc.rst diff --git a/changelog/1112.doc.rst b/changelog/1112.doc.rst new file mode 100644 index 0000000000..e510f78951 --- /dev/null +++ b/changelog/1112.doc.rst @@ -0,0 +1 @@ +Document the :class:`.Option` attributes, the ``description`` and ``options`` properties for :class:`.ext.commands.InvokableSlashCommand` and the ``description`` and ``body`` properties for :class:`.ext.commands.SubCommand`. diff --git a/disnake/app_commands.py b/disnake/app_commands.py index 17c3fd713a..a6a58fe188 100644 --- a/disnake/app_commands.py +++ b/disnake/app_commands.py @@ -198,6 +198,40 @@ class Option: .. versionadded:: 2.6 + max_length: :class:`int` + The maximum length for this option if this is a string option. + + .. versionadded:: 2.6 + + Attributes + ---------- + name: :class:`str` + The option's name. + description: :class:`str` + The option's description. + type: :class:`OptionType` + The option type, e.g. :class:`OptionType.user`. + required: :class:`bool` + Whether this option is required. + choices: List[:class:`OptionChoice`] + The list of option choices. + options: List[:class:`Option`] + The list of sub options. Normally you don't have to specify it directly, + instead consider using ``@main_cmd.sub_command`` or ``@main_cmd.sub_command_group`` decorators. + channel_types: List[:class:`ChannelType`] + The list of channel types that your option supports, if the type is :class:`OptionType.channel`. + By default, it supports all channel types. + autocomplete: :class:`bool` + Whether this option can be autocompleted. + min_value: Union[:class:`int`, :class:`float`] + The minimum value permitted. + max_value: Union[:class:`int`, :class:`float`] + The maximum value permitted. + min_length: :class:`int` + The minimum length for this option if this is a string option. + + .. versionadded:: 2.6 + max_length: :class:`int` The maximum length for this option if this is a string option. diff --git a/disnake/ext/commands/slash_core.py b/disnake/ext/commands/slash_core.py index 4652c552f8..a23cf86bd3 100644 --- a/disnake/ext/commands/slash_core.py +++ b/disnake/ext/commands/slash_core.py @@ -332,10 +332,12 @@ def parents( @property def description(self) -> str: + """:class:`str`: The slash sub command's description. Shorthand for :attr:`self.body.description <.Option.description>`.""" return self.body.description @property def body(self) -> Option: + """:class:`.Option`: The API representation for this slash sub command. Shorthand for :attr:`.SubCommand.option`""" return self.option async def _call_autocompleter( @@ -508,10 +510,12 @@ def _ensure_assignment_on_copy(self, other: SlashCommandT) -> SlashCommandT: @property def description(self) -> str: + """:class:`str`: The slash command's description. Shorthand for :attr:`self.body.description <.SlashCommand.description>`.""" return self.body.description @property def options(self) -> List[Option]: + """List[:class:`.Option`]: The list of options the slash command has. Shorthand for :attr:`self.body.options <.SlashCommand.options>`.""" return self.body.options def sub_command( diff --git a/docs/api/app_commands.rst b/docs/api/app_commands.rst index 1d26bee539..a55e3670ab 100644 --- a/docs/api/app_commands.rst +++ b/docs/api/app_commands.rst @@ -97,7 +97,7 @@ Option .. attributetable:: Option -.. autoclass:: Option() +.. autoclass:: Option :members: OptionChoice From 594c12ad5230522a12008445217e71726cc60611 Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Sat, 25 Nov 2023 23:36:53 +0100 Subject: [PATCH 6/7] fix(appcmd): use covariant collection types for `choices` (#1136) --- changelog/1136.bugfix.rst | 1 + disnake/app_commands.py | 25 ++++++++++++++----------- disnake/ext/commands/params.py | 11 ++++++----- disnake/interactions/base.py | 7 +++++-- 4 files changed, 26 insertions(+), 18 deletions(-) create mode 100644 changelog/1136.bugfix.rst diff --git a/changelog/1136.bugfix.rst b/changelog/1136.bugfix.rst new file mode 100644 index 0000000000..571cba6cbd --- /dev/null +++ b/changelog/1136.bugfix.rst @@ -0,0 +1 @@ +Update ``choices`` type in app commands to accept any :class:`~py:typing.Sequence` or :class:`~py:typing.Mapping`, instead of the more constrained :class:`list`/:class:`dict` types. diff --git a/disnake/app_commands.py b/disnake/app_commands.py index a6a58fe188..727f35cb93 100644 --- a/disnake/app_commands.py +++ b/disnake/app_commands.py @@ -5,7 +5,7 @@ import math import re from abc import ABC -from typing import TYPE_CHECKING, ClassVar, Dict, List, Mapping, Optional, Tuple, Union +from typing import TYPE_CHECKING, ClassVar, List, Mapping, Optional, Sequence, Tuple, Union from .enums import ( ApplicationCommandPermissionType, @@ -37,10 +37,10 @@ ) Choices = Union[ - List["OptionChoice"], - List[ApplicationCommandOptionChoiceValue], - Dict[str, ApplicationCommandOptionChoiceValue], - List[Localized[str]], + Sequence["OptionChoice"], + Sequence[ApplicationCommandOptionChoiceValue], + Mapping[str, ApplicationCommandOptionChoiceValue], + Sequence[Localized[str]], ] APIApplicationCommand = Union["APIUserCommand", "APIMessageCommand", "APISlashCommand"] @@ -179,8 +179,8 @@ class Option: The option type, e.g. :class:`OptionType.user`. required: :class:`bool` Whether this option is required. - choices: Union[List[:class:`OptionChoice`], List[Union[:class:`str`, :class:`int`]], Dict[:class:`str`, Union[:class:`str`, :class:`int`]]] - The list of option choices. + choices: Union[Sequence[:class:`OptionChoice`], Sequence[Union[:class:`str`, :class:`int`, :class:`float`]], Mapping[:class:`str`, Union[:class:`str`, :class:`int`, :class:`float`]]] + The pre-defined choices for this option. options: List[:class:`Option`] The list of sub options. Normally you don't have to specify it directly, instead consider using ``@main_cmd.sub_command`` or ``@main_cmd.sub_command_group`` decorators. @@ -214,7 +214,7 @@ class Option: required: :class:`bool` Whether this option is required. choices: List[:class:`OptionChoice`] - The list of option choices. + The list of pre-defined choices. options: List[:class:`Option`] The list of sub options. Normally you don't have to specify it directly, instead consider using ``@main_cmd.sub_command`` or ``@main_cmd.sub_command_group`` decorators. @@ -304,6 +304,9 @@ def __init__( if autocomplete: raise TypeError("can not specify both choices and autocomplete args") + if isinstance(choices, str): # str matches `Sequence[str]`, but isn't meant to be used + raise TypeError("choices argument should be a list/sequence or dict, not str") + if isinstance(choices, Mapping): self.choices = [OptionChoice(name, value) for name, value in choices.items()] else: @@ -370,7 +373,7 @@ def from_dict(cls, data: ApplicationCommandOptionPayload) -> Option: def add_choice( self, name: LocalizedRequired, - value: Union[str, int], + value: ApplicationCommandOptionChoiceValue, ) -> None: """Adds an OptionChoice to the list of current choices, parameters are the same as for :class:`OptionChoice`. @@ -388,7 +391,7 @@ def add_option( description: LocalizedOptional = None, type: Optional[OptionType] = None, required: bool = False, - choices: Optional[List[OptionChoice]] = None, + choices: Optional[Choices] = None, options: Optional[list] = None, channel_types: Optional[List[ChannelType]] = None, autocomplete: bool = False, @@ -884,7 +887,7 @@ def add_option( description: LocalizedOptional = None, type: Optional[OptionType] = None, required: bool = False, - choices: Optional[List[OptionChoice]] = None, + choices: Optional[Choices] = None, options: Optional[list] = None, channel_types: Optional[List[ChannelType]] = None, autocomplete: bool = False, diff --git a/disnake/ext/commands/params.py b/disnake/ext/commands/params.py index 9114b8b353..e472c1ae13 100644 --- a/disnake/ext/commands/params.py +++ b/disnake/ext/commands/params.py @@ -6,6 +6,7 @@ import asyncio import collections.abc +import copy import inspect import itertools import math @@ -451,8 +452,8 @@ class ParamInfo: .. versionchanged:: 2.5 Added support for localizations. - choices: Union[List[:class:`.OptionChoice`], List[Union[:class:`str`, :class:`int`]], Dict[:class:`str`, Union[:class:`str`, :class:`int`]]] - The list of choices of this slash command option. + choices: Union[Sequence[:class:`.OptionChoice`], Sequence[Union[:class:`str`, :class:`int`, :class:`float`]], Mapping[:class:`str`, Union[:class:`str`, :class:`int`, :class:`float`]]] + The pre-defined choices for this option. ge: :class:`float` The lowest allowed value for this option. le: :class:`float` @@ -554,7 +555,7 @@ def copy(self) -> Self: ins.converter = self.converter ins.convert_default = self.convert_default ins.autocomplete = self.autocomplete - ins.choices = self.choices.copy() + ins.choices = copy.copy(self.choices) ins.type = self.type ins.channel_types = self.channel_types.copy() ins.max_value = self.max_value @@ -1155,8 +1156,8 @@ def Param( .. versionchanged:: 2.5 Added support for localizations. - choices: Union[List[:class:`.OptionChoice`], List[Union[:class:`str`, :class:`int`]], Dict[:class:`str`, Union[:class:`str`, :class:`int`]]] - A list of choices for this option. + choices: Union[Sequence[:class:`.OptionChoice`], Sequence[Union[:class:`str`, :class:`int`, :class:`float`]], Mapping[:class:`str`, Union[:class:`str`, :class:`int`, :class:`float`]]] + The pre-defined choices for this slash command option. converter: Callable[[:class:`.ApplicationCommandInteraction`, Any], Any] A function that will convert the original input to a desired format. Kwarg aliases: ``conv``. diff --git a/disnake/interactions/base.py b/disnake/interactions/base.py index 01637be96a..f0058d2cc8 100644 --- a/disnake/interactions/base.py +++ b/disnake/interactions/base.py @@ -1260,8 +1260,8 @@ async def autocomplete(self, *, choices: Choices) -> None: Parameters ---------- - choices: Union[List[:class:`OptionChoice`], List[Union[:class:`str`, :class:`int`]], Dict[:class:`str`, Union[:class:`str`, :class:`int`]]] - The list of choices to suggest. + choices: Union[Sequence[:class:`OptionChoice`], Sequence[Union[:class:`str`, :class:`int`, :class:`float`]], Mapping[:class:`str`, Union[:class:`str`, :class:`int`, :class:`float`]]] + The choices to suggest. Raises ------ @@ -1277,6 +1277,9 @@ async def autocomplete(self, *, choices: Choices) -> None: if isinstance(choices, Mapping): choices_data = [{"name": n, "value": v} for n, v in choices.items()] else: + if isinstance(choices, str): # str matches `Sequence[str]`, but isn't meant to be used + raise TypeError("choices argument should be a list/sequence or dict, not str") + choices_data = [] value: ApplicationCommandOptionChoicePayload i18n = self._parent.client.i18n From 85cf39375f9d1c43014cb96c785f25e7809b58be Mon Sep 17 00:00:00 2001 From: shiftinv <8530778+shiftinv@users.noreply.github.com> Date: Fri, 8 Dec 2023 15:36:44 +0100 Subject: [PATCH 7/7] feat(commands): support 3.12's `type` statement and `TypeAliasType` (#1128) --- changelog/1128.feature.rst | 1 + disnake/utils.py | 48 +++++++++++++++++++++---- tests/test_utils.py | 69 ++++++++++++++++++++++++++++++++++-- tests/utils_helper_module.py | 26 ++++++++++++++ 4 files changed, 135 insertions(+), 9 deletions(-) create mode 100644 changelog/1128.feature.rst create mode 100644 tests/utils_helper_module.py diff --git a/changelog/1128.feature.rst b/changelog/1128.feature.rst new file mode 100644 index 0000000000..66c35b1935 --- /dev/null +++ b/changelog/1128.feature.rst @@ -0,0 +1 @@ +|commands| Support Python 3.12's ``type`` statement and :class:`py:typing.TypeAliasType` annotations in command signatures. diff --git a/disnake/utils.py b/disnake/utils.py index 6fa6ae82d7..9061cd0f61 100644 --- a/disnake/utils.py +++ b/disnake/utils.py @@ -1122,6 +1122,24 @@ def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]: return tuple(p for p in parameters if p is not none_cls) + (none_cls,) +def _resolve_typealiastype( + tp: Any, globals: Dict[str, Any], locals: Dict[str, Any], cache: Dict[str, Any] +): + # Use __module__ to get the (global) namespace in which the type alias was defined. + if mod := sys.modules.get(tp.__module__): + mod_globals = mod.__dict__ + if mod_globals is not globals or mod_globals is not locals: + # if the namespace changed (usually when a TypeAliasType was imported from a different module), + # drop the cache since names can resolve differently now + cache = {} + globals = locals = mod_globals + + # Accessing `__value__` automatically evaluates the type alias in the annotation scope. + # (recurse to resolve possible forwardrefs, aliases, etc.) + return evaluate_annotation(tp.__value__, globals, locals, cache) + + +# FIXME: this should be split up into smaller functions for clarity and easier maintenance def evaluate_annotation( tp: Any, globals: Dict[str, Any], @@ -1147,23 +1165,31 @@ def evaluate_annotation( cache[tp] = evaluated return evaluated + # GenericAlias / UnionType if hasattr(tp, "__args__"): - implicit_str = True - is_literal = False - orig_args = args = tp.__args__ if not hasattr(tp, "__origin__"): if tp.__class__ is UnionType: - converted = Union[args] # type: ignore + converted = Union[tp.__args__] # type: ignore return evaluate_annotation(converted, globals, locals, cache) return tp - if tp.__origin__ is Union: + + implicit_str = True + is_literal = False + orig_args = args = tp.__args__ + orig_origin = origin = tp.__origin__ + + # origin can be a TypeAliasType too, resolve it and continue + if hasattr(origin, "__value__"): + origin = _resolve_typealiastype(origin, globals, locals, cache) + + if origin is Union: try: if args.index(type(None)) != len(args) - 1: args = normalise_optional_params(tp.__args__) except ValueError: pass - if tp.__origin__ is Literal: + if origin is Literal: if not PY_310: args = flatten_literal_params(tp.__args__) implicit_str = False @@ -1179,13 +1205,21 @@ def evaluate_annotation( ): raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.") + if origin != orig_origin: + # we can't use `copy_with` in this case, so just skip all of the following logic + return origin[evaluated_args] + if evaluated_args == orig_args: return tp try: return tp.copy_with(evaluated_args) except AttributeError: - return tp.__origin__[evaluated_args] + return origin[evaluated_args] + + # TypeAliasType, 3.12+ + if hasattr(tp, "__value__"): + return _resolve_typealiastype(tp, globals, locals, cache) return tp diff --git a/tests/test_utils.py b/tests/test_utils.py index d767264a95..46237c2019 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,7 +9,7 @@ import warnings from dataclasses import dataclass from datetime import timedelta, timezone -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, TypeVar, Union from unittest import mock import pytest @@ -18,7 +18,13 @@ import disnake from disnake import utils -from . import helpers +from . import helpers, utils_helper_module + +if TYPE_CHECKING: + from typing_extensions import TypeAliasType +elif sys.version_info >= (3, 12): + # non-3.12 tests shouldn't be using this + from typing import TypeAliasType def test_missing() -> None: @@ -785,6 +791,65 @@ def test_resolve_annotation_literal() -> None: utils.resolve_annotation(Literal[timezone.utc, 3], globals(), locals(), {}) # type: ignore +@pytest.mark.skipif(sys.version_info < (3, 12), reason="syntax requires py3.12") +class TestResolveAnnotationTypeAliasType: + def test_simple(self) -> None: + # this is equivalent to `type CoolList = List[int]` + CoolList = TypeAliasType("CoolList", List[int]) + assert utils.resolve_annotation(CoolList, globals(), locals(), {}) == List[int] + + def test_generic(self) -> None: + # this is equivalent to `type CoolList[T] = List[T]; CoolList[int]` + T = TypeVar("T") + CoolList = TypeAliasType("CoolList", List[T], type_params=(T,)) + + annotation = CoolList[int] + assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[int] + + # alias and arg in local scope + def test_forwardref_local(self) -> None: + T = TypeVar("T") + IntOrStr = Union[int, str] + CoolList = TypeAliasType("CoolList", List[T], type_params=(T,)) + + annotation = CoolList["IntOrStr"] + assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[IntOrStr] + + # alias and arg in other module scope + def test_forwardref_module(self) -> None: + resolved = utils.resolve_annotation( + utils_helper_module.ListWithForwardRefAlias, globals(), locals(), {} + ) + assert resolved == List[Union[int, str]] + + # combination of the previous two, alias in other module scope and arg in local scope + def test_forwardref_mixed(self) -> None: + LocalIntOrStr = Union[int, str] + + annotation = utils_helper_module.GenericListAlias["LocalIntOrStr"] + assert utils.resolve_annotation(annotation, globals(), locals(), {}) == List[LocalIntOrStr] + + # two different forwardrefs with same name + def test_forwardref_duplicate(self) -> None: + DuplicateAlias = int + + # first, resolve an annotation where `DuplicateAlias` resolves to the local int + cache = {} + assert ( + utils.resolve_annotation(List["DuplicateAlias"], globals(), locals(), cache) + == List[int] + ) + + # then, resolve an annotation where the globalns changes and `DuplicateAlias` resolves to something else + # (i.e. this should not resolve to `List[int]` despite {"DuplicateAlias": int} in the cache) + assert ( + utils.resolve_annotation( + utils_helper_module.ListWithDuplicateAlias, globals(), locals(), cache + ) + == List[str] + ) + + @pytest.mark.parametrize( ("dt", "style", "expected"), [ diff --git a/tests/utils_helper_module.py b/tests/utils_helper_module.py new file mode 100644 index 0000000000..7711e861b8 --- /dev/null +++ b/tests/utils_helper_module.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: MIT + +"""Separate module file for some test_utils.py type annotation tests.""" + +import sys +from typing import TYPE_CHECKING, List, TypeVar, Union + +version = sys.version_info # assign to variable to trick pyright + +if TYPE_CHECKING: + from typing_extensions import TypeAliasType +elif version >= (3, 12): + # non-3.12 tests shouldn't be using this + from typing import TypeAliasType + +if version >= (3, 12): + CoolUniqueIntOrStrAlias = Union[int, str] + ListWithForwardRefAlias = TypeAliasType( + "ListWithForwardRefAlias", List["CoolUniqueIntOrStrAlias"] + ) + + T = TypeVar("T") + GenericListAlias = TypeAliasType("GenericListAlias", List[T], type_params=(T,)) + + DuplicateAlias = str + ListWithDuplicateAlias = TypeAliasType("ListWithDuplicateAlias", List["DuplicateAlias"])