From 20f4f3c8f23806844cc444ee28930eb505183620 Mon Sep 17 00:00:00 2001 From: GautamKumar Date: Mon, 30 Sep 2024 16:28:33 +0200 Subject: [PATCH] (fix): restarting of session on OsError KurimuzonAkuma/pyrogram#90 --- pyrogram/session/session.py | 227 ++++++++++-------------------------- 1 file changed, 64 insertions(+), 163 deletions(-) diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 5ad6ad1af7..491d81078e 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -20,7 +20,6 @@ import bisect import logging import os -from time import time from hashlib import sha1 from io import BytesIO from typing import Optional @@ -52,14 +51,11 @@ def __init__(self): class Session: START_TIMEOUT = 5 WAIT_TIMEOUT = 15 - REART_TIMEOUT = 5 SLEEP_THRESHOLD = 10 MAX_RETRIES = 10 ACKS_THRESHOLD = 10 PING_INTERVAL = 5 STORED_MSG_IDS_MAX_SIZE = 1000 * 2 - RECONNECT_THRESHOLD = 13 - STOP_RANGE = range(2) TRANSPORT_ERRORS = { 404: "auth key not found", @@ -110,20 +106,12 @@ def __init__( self.recv_task = None self.is_started = asyncio.Event() + self.restart_event = asyncio.Event() self.loop = asyncio.get_event_loop() - self.instant_stop = False # set internally - self.last_reconnect_attempt = None - self.currently_restarting = False - self.currently_stopping = False - async def start(self): while True: - if self.instant_stop: - log.info("session init force stopped (loop)") - return # stop instantly - self.connection = self.client.connection_factory( dc_id=self.dc_id, test_mode=self.test_mode, @@ -183,106 +171,51 @@ async def start(self): log.info("Session started") - async def stop(self, restart: bool = False): - if self.currently_stopping: - return # don't stop twice - if self.instant_stop: - log.info("session stop process stopped") - return # stop doing anything instantly, client is manually handling + async def stop(self): + self.is_started.clear() - try: - self.currently_stopping = True - self.is_started.clear() - self.stored_msg_ids.clear() - - if restart: - self.instant_stop = True # tell all funcs that we want to stop - - self.ping_task_event.set() - for _ in self.STOP_RANGE: - try: - if self.ping_task is not None: - await asyncio.wait_for( - self.ping_task, timeout=self.REART_TIMEOUT - ) - break - except TimeoutError: - self.ping_task.cancel() - continue # next stage - self.ping_task_event.clear() + self.stored_msg_ids.clear() + + self.ping_task_event.set() + + if self.ping_task is not None: + await self.ping_task + + self.ping_task_event.clear() + await self.connection.close() + + if self.recv_task: + await self.recv_task + + if not self.is_media and callable(self.client.disconnect_handler): try: - await asyncio.wait_for( - self.connection.close(), timeout=self.REART_TIMEOUT - ) + await self.client.disconnect_handler(self.client) except Exception as e: log.exception(e) - for _ in self.STOP_RANGE: - try: - if self.recv_task: - await asyncio.wait_for( - self.recv_task, timeout=self.REART_TIMEOUT - ) - break - except TimeoutError: - self.recv_task.cancel() - continue # next stage - - if not self.is_media and callable(self.client.disconnect_handler): - try: - await self.client.disconnect_handler(self.client) - except Exception as e: - log.exception(e) - - log.info("session stopped") - finally: - self.currently_stopping = False - if restart: - self.instant_stop = False # reset + log.info("Session stopped") async def restart(self): - if self.currently_restarting: - return # don't restart twice - if self.instant_stop: - return # stop instantly - - try: - self.currently_restarting = True - now = time() - if ( - self.last_reconnect_attempt - and (now - self.last_reconnect_attempt) < self.RECONNECT_THRESHOLD - ): - to_wait = self.RECONNECT_THRESHOLD + int( - self.RECONNECT_THRESHOLD - (now - self.last_reconnect_attempt) - ) - log.warning( - "[pyrogram] Client [%s] is reconnecting too frequently, sleeping for %s seconds", - self.client.name, - to_wait - ) - await asyncio.sleep(to_wait) - - self.last_reconnect_attempt = now - await self.stop(restart=True) - await self.start() - finally: - self.currently_restarting = False + self.restart_event.set() + await self.stop() + await self.start() + self.restart_event.clear() async def handle_packet(self, packet): - if self.instant_stop: - log.info("Stopped packet handler") - return # stop instantly - - data = await self.loop.run_in_executor( - pyrogram.crypto_executor, - mtproto.unpack, - BytesIO(packet), - self.session_id, - self.auth_key, - self.auth_key_id - ) + try: + data = await self.loop.run_in_executor( + pyrogram.crypto_executor, + mtproto.unpack, + BytesIO(packet), + self.session_id, + self.auth_key, + self.auth_key_id + ) + except ValueError as e: + log.debug(e) + self.loop.create_task(self.restart()) + return messages = ( data.body.messages @@ -360,17 +293,9 @@ async def handle_packet(self, packet): self.pending_acks.clear() async def ping_worker(self): - if self.instant_stop: - log.info("PingTask force stopped") - return # stop instantly - log.info("PingTask started") while True: - if self.instant_stop: - log.info("PingTask force stopped (loop)") - return # stop instantly - try: await asyncio.wait_for(self.ping_task_event.wait(), self.PING_INTERVAL) except asyncio.TimeoutError: @@ -396,10 +321,6 @@ async def recv_worker(self): log.info("NetworkTask started") while True: - if self.instant_stop: - log.info("NetworkTask force stopped (loop)") - return # stop instantly - packet = await self.connection.recv() if packet is None or len(packet) == 4: @@ -412,10 +333,8 @@ async def recv_worker(self): # "and log in again with your phone number or bot token." # ) log.warning( - "[%s] Server sent transport error: %s (%s)", - self.client.name, - error_code, - Session.TRANSPORT_ERRORS.get(error_code, "unknown error"), + "Server sent transport error: %s (%s)", + error_code, Session.TRANSPORT_ERRORS.get(error_code, "unknown error") ) if self.is_started.is_set(): @@ -431,11 +350,8 @@ async def send( self, data: TLObject, wait_response: bool = True, - timeout: float = WAIT_TIMEOUT, + timeout: float = WAIT_TIMEOUT ): - if self.instant_stop: - return # stop instantly - message = self.msg_factory(data) msg_id = message.msg_id @@ -493,6 +409,11 @@ async def invoke( timeout: float = WAIT_TIMEOUT, sleep_threshold: float = SLEEP_THRESHOLD ): + try: + await asyncio.wait_for(self.is_started.wait(), self.WAIT_TIMEOUT) + except asyncio.TimeoutError: + pass + if isinstance(query, Session.CUR_ALWD_INNR_QRYS): inner_query = query.query else: @@ -500,20 +421,7 @@ async def invoke( query_name = ".".join(inner_query.QUALNAME.split(".")[1:]) - while retries > 0: - # sleep until the restart is performed - if self.currently_restarting: - while self.currently_restarting: - if self.instant_stop: - return # stop instantly - await asyncio.sleep(1) - - if self.instant_stop: - return # stop instantly - - if not self.is_started.is_set(): - await self.is_started.wait() - + while True: try: return await self.send(query, timeout=timeout) except (FloodWait, FloodPremiumWait) as e: @@ -522,12 +430,8 @@ async def invoke( if amount > sleep_threshold >= 0: raise - log.warning( - '[%s] Waiting for %s seconds before continuing (required by "%s")', - self.client.name, - amount, - query_name, - ) + log.warning('[%s] Waiting for %s seconds before continuing (required by "%s")', + self.client.name, amount, query_name) await asyncio.sleep(amount) except ( @@ -550,26 +454,23 @@ async def invoke( ): raise e from None - if (isinstance(e, (OSError, RuntimeError)) and "handler" in str(e)) or ( - isinstance(e, TimeoutError) - ): - (log.warning if retries < 2 else log.info)( - '[%s] [%s] reconnecting session requesting "%s", due to: %s', - self.client.name, - Session.MAX_RETRIES - retries, - query_name, - str(e) or repr(e), - ) + (log.warning if retries < 2 else log.info)( + '[%s] Retrying "%s" due to: %s', + Session.MAX_RETRIES - retries + 1, + query_name, str(e) or repr(e) + ) + + # restart was never being called after Exception block + if not self.restart_event.is_set(): self.loop.create_task(self.restart()) else: - (log.warning if retries < 2 else log.info)( - '[%s] [%s] Retrying "%s" due to: %s', - self.client.name, - Session.MAX_RETRIES - retries, - query_name, - str(e) or repr(e), - ) - - await asyncio.sleep(1) - + # multiple Exceptions can be raised in a row, so we need to wait for the restart to finish + try: + await asyncio.wait_for(self.restart_event.wait(), self.WAIT_TIMEOUT) + except asyncio.TimeoutError: + pass + + await asyncio.sleep(0.5) + + return await self.invoke(query, retries - 1, timeout) raise TimeoutError("Exceeded maximum number of retries")