diff --git a/naff/api/events/processors/guild_events.py b/naff/api/events/processors/guild_events.py index 1519ce945..e397b6d9b 100644 --- a/naff/api/events/processors/guild_events.py +++ b/naff/api/events/processors/guild_events.py @@ -39,7 +39,7 @@ async def _on_raw_guild_create(self, event: "RawGatewayEvent") -> None: self._guild_event.set() - if self.fetch_members: # noqa + if self.fetch_members and not guild.chunked.is_set(): # noqa # delays events until chunking has completed await guild.chunk() diff --git a/naff/api/gateway/gateway.py b/naff/api/gateway/gateway.py index 52effc601..149cb4759 100644 --- a/naff/api/gateway/gateway.py +++ b/naff/api/gateway/gateway.py @@ -7,7 +7,7 @@ from typing import TypeVar, TYPE_CHECKING from naff.api import events -from naff.client.const import logger, MISSING +from naff.client.const import logger, MISSING, __api_version__ from naff.client.utils.input_utils import OverriddenJson from naff.client.utils.serializer import dict_filter_none from naff.models.discord.enums import Status @@ -48,7 +48,6 @@ class GatewayClient(WebsocketClient): Multiple `WebsocketClient` instances can be used to implement same-process sharding. Attributes: - buffer: A buffer to hold incoming data until its complete sequence: The sequence of this connection session_id: The session ID of this connection @@ -83,7 +82,7 @@ def __init__(self, state: "ConnectionState", shard: tuple[int, int]) -> None: self._ready = asyncio.Event() self._close_gateway = asyncio.Event() - # Santity check, it is extremely important that an instance isn't reused. + # Sanity check, it is extremely important that an instance isn't reused. self._entered = False async def __aenter__(self: SELF) -> SELF: @@ -177,6 +176,7 @@ async def dispatch_opcode(self, data, op: OPCODE) -> None: match op: case OPCODE.HEARTBEAT: + logger.debug("Received heartbeat request from gateway") return await self.send_heartbeat() case OPCODE.HEARTBEAT_ACK: @@ -192,12 +192,12 @@ async def dispatch_opcode(self, data, op: OPCODE) -> None: return self._acknowledged.set() case OPCODE.RECONNECT: - logger.info("Gateway requested reconnect. Reconnecting...") + logger.debug("Gateway requested reconnect. Reconnecting...") return await self.reconnect(resume=True, url=self.ws_resume_url) case OPCODE.INVALIDATE_SESSION: logger.warning("Gateway has invalidated session! Reconnecting...") - return await self.reconnect(resume=data, url=self.ws_resume_url if data else None) + return await self.reconnect() case _: return logger.debug(f"Unhandled OPCODE: {op} = {OPCODE(op).name}") @@ -209,7 +209,9 @@ async def dispatch_event(self, data, seq, event) -> None: self._trace = data.get("_trace", []) self.sequence = seq self.session_id = data["session_id"] - self.ws_resume_url = data["resume_gateway_url"] + self.ws_resume_url = ( + f"{data['resume_gateway_url']}?encoding=json&v={__api_version__}&compress=zlib-stream" + ) logger.info(f"Shard {self.shard[0]} has connected to gateway!") logger.debug(f"Session ID: {self.session_id} Trace: {self._trace}") # todo: future polls, improve guild caching here. run the debugger. you'll see why @@ -287,7 +289,7 @@ async def _resume_connection(self) -> None: logger.debug(f"{self.shard[0]} is attempting to resume a connection") async def send_heartbeat(self) -> None: - await self.send_json({"op": OPCODE.HEARTBEAT, "d": self.sequence}, True) + await self.send_json({"op": OPCODE.HEARTBEAT, "d": self.sequence}, bypass=True) logger.debug(f"❤ Shard {self.shard[0]} is sending a Heartbeat") async def change_presence(self, activity=None, status: Status = Status.ONLINE, since=None) -> None: diff --git a/naff/api/gateway/websocket.py b/naff/api/gateway/websocket.py index 76bee2b2e..7c97362b7 100644 --- a/naff/api/gateway/websocket.py +++ b/naff/api/gateway/websocket.py @@ -1,5 +1,6 @@ import asyncio import collections +import random import time import zlib from abc import abstractmethod @@ -275,12 +276,12 @@ async def _start_bee_gees(self) -> None: if self.heartbeat_interval is None: raise RuntimeError - # try: - # await asyncio.wait_for(self._kill_bee_gees.wait(), timeout=self.heartbeat_interval * random.uniform(0, 0.5)) - # except asyncio.TimeoutError: - # pass - # else: - # return + try: + await asyncio.wait_for(self._kill_bee_gees.wait(), timeout=self.heartbeat_interval * random.uniform(0, 0.5)) + except asyncio.TimeoutError: + pass + else: + return logger.debug(f"Sending heartbeat every {self.heartbeat_interval} seconds") while not self._kill_bee_gees.is_set(): diff --git a/naff/client/errors.py b/naff/client/errors.py index 821948b8b..30e602c2e 100644 --- a/naff/client/errors.py +++ b/naff/client/errors.py @@ -104,6 +104,11 @@ def __init__( self.text = data super().__init__(f"{self.status}|{self.response.reason}: {f'({self.code}) ' if self.code else ''}{self.text}") + def __str__(self) -> str: + errors = self.search_for_message(self.errors) + out = f"HTTPException: {self.status}|{self.response.reason}: " + "\n".join(errors) + return out + @staticmethod def search_for_message(errors: dict, lookup: Optional[dict] = None) -> list[str]: """ diff --git a/naff/client/smart_cache.py b/naff/client/smart_cache.py index df7e435b5..2eeed3493 100644 --- a/naff/client/smart_cache.py +++ b/naff/client/smart_cache.py @@ -232,7 +232,8 @@ def delete_member(self, guild_id: "Snowflake_Type", user_id: "Snowflake_Type") - guild_id = to_snowflake(guild_id) if member := self.member_cache.pop((guild_id, user_id), None): - member.guild._member_ids.discard(user_id) + if member.guild: + member.guild._member_ids.discard(user_id) self.delete_user_guild(user_id, guild_id) diff --git a/naff/models/discord/auto_mod.py b/naff/models/discord/auto_mod.py index e9d2973bc..c32fe0bae 100644 --- a/naff/models/discord/auto_mod.py +++ b/naff/models/discord/auto_mod.py @@ -139,6 +139,13 @@ class KeywordPresetTrigger(BaseTrigger): ) +@define() +class MentionSpamTrigger(BaseTrigger): + """A trigger that checks if content contains more mentions than allowed""" + + mention_total_limit: int = field(default=3, repr=True, metadata=docs("The maximum number of mentions allowed")) + + @define() class BlockMessage(BaseAction): """blocks the content of a message according to the rule""" @@ -320,4 +327,5 @@ def message(self) -> "Optional[Message]": AutoModTriggerType.KEYWORD: KeywordTrigger, AutoModTriggerType.HARMFUL_LINK: HarmfulLinkFilter, AutoModTriggerType.KEYWORD_PRESET: KeywordPresetTrigger, + AutoModTriggerType.MENTION_SPAM: MentionSpamTrigger, } diff --git a/naff/models/discord/enums.py b/naff/models/discord/enums.py index eb07a593c..c7b659a9c 100644 --- a/naff/models/discord/enums.py +++ b/naff/models/discord/enums.py @@ -865,20 +865,21 @@ class AuditLogEventType(CursedIntEnum): GUILD_HOME_FEATURE_ITEM_UPDATE = 172 -class AutoModTriggerType(IntEnum): +class AutoModTriggerType(CursedIntEnum): KEYWORD = 1 HARMFUL_LINK = 2 SPAM = 3 KEYWORD_PRESET = 4 + MENTION_SPAM = 5 -class AutoModAction(IntEnum): +class AutoModAction(CursedIntEnum): BLOCK_MESSAGE = 1 ALERT_MESSAGE = 2 TIMEOUT_USER = 3 -class AutoModEvent(IntEnum): +class AutoModEvent(CursedIntEnum): MESSAGE_SEND = 1 diff --git a/naff/models/discord/guild.py b/naff/models/discord/guild.py index 8c2e8f8b2..9aba08db6 100644 --- a/naff/models/discord/guild.py +++ b/naff/models/discord/guild.py @@ -1,5 +1,6 @@ import asyncio import time +from asyncio import QueueEmpty from collections import namedtuple from functools import cmp_to_key from typing import List, Optional, Union, Set, Dict, Any, TYPE_CHECKING @@ -119,6 +120,26 @@ def _process_dict(cls, data: Dict[str, Any], client: "Client") -> Dict[str, Any] return super()._process_dict(data, client) +class MemberIterator(AsyncIterator): + def __init__(self, guild: "Guild", limit: int = 0) -> None: + super().__init__(limit) + self.guild = guild + self._more = True + + async def fetch(self) -> list: + if self._more: + expected = self.get_limit + + rcv = await self.guild._client.http.list_members( + self.guild.id, limit=expected, after=self.last["id"] if self.last else MISSING + ) + if not rcv: + raise QueueEmpty + self._more = len(rcv) == expected + return rcv + raise QueueEmpty + + @define() class Guild(BaseGuild): """Guilds in Discord represent an isolated collection of users and channels, and are often referred to as "servers" in the UI.""" @@ -501,31 +522,15 @@ async def edit_nickname(self, new_nickname: Absent[str] = MISSING, reason: Absen async def http_chunk(self) -> None: """Populates all members of this guild using the REST API.""" start_time = time.perf_counter() - members = [] - - # request all guild members - after = MISSING - while True: - if members: - after = members[-1]["user"]["id"] - rcv: list = await self._client.http.list_members(self.id, limit=1000, after=after) - members.extend(rcv) - if len(rcv) < 1000: - # we're done - break - - # process all members - s = time.monotonic() - for member in members: + + iterator = MemberIterator(self) + async for member in iterator: self._client.cache.place_member_data(self.id, member) - if (time.monotonic() - s) > 0.05: - # look, i get this *could* be a thread, but because it needs to modify data in the main thread, - # it is still blocking. So by periodically yielding to the event loop, we can avoid blocking, and still - # process this data properly - await asyncio.sleep(0) - s = time.monotonic() + self.chunked.set() - logger.info(f"Cached {len(members)} members for {self.id} in {time.perf_counter() - start_time:.2f} seconds") + logger.info( + f"Cached {iterator.total_retrieved} members for {self.id} in {time.perf_counter() - start_time:.2f} seconds" + ) async def gateway_chunk(self, wait=True, presences=True) -> None: """ diff --git a/naff/models/misc/iterator.py b/naff/models/misc/iterator.py index d8cc7bbcd..d9f8197e9 100644 --- a/naff/models/misc/iterator.py +++ b/naff/models/misc/iterator.py @@ -36,6 +36,11 @@ def get_limit(self) -> int: """Get how the maximum number of items that should be retrieved.""" return min(self._limit - len(self._retrieved_objects), 100) if self._limit else 100 + @property + def total_retrieved(self) -> int: + """Get the total number of objects this iterator has retrieved.""" + return len(self._retrieved_objects) + async def add_object(self, obj) -> None: """Add an object to iterator's queue.""" return await self._queue.put(obj) diff --git a/pyproject.toml b/pyproject.toml index 72bb814ae..61f02e79d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "naff" -version = "1.9.0" +version = "1.10.0" description = "Not another freaking fork" authors = ["LordOfPolls "]