From d8d852db3e38a65619468c9da812fcd10ac5781f Mon Sep 17 00:00:00 2001 From: Nikita Kharlov Date: Tue, 10 Sep 2019 12:01:54 +0300 Subject: [PATCH] max reconnect attempts --- aio_pika/connection.py | 2 +- aio_pika/exceptions.py | 5 +++++ aio_pika/robust_connection.py | 29 +++++++++++++++++++++++- tests/test_amqp_robust.py | 42 ++++++++++++++++++++++++++++++----- 4 files changed, 71 insertions(+), 7 deletions(-) diff --git a/aio_pika/connection.py b/aio_pika/connection.py index bfbca1fe..926e0fcf 100644 --- a/aio_pika/connection.py +++ b/aio_pika/connection.py @@ -294,7 +294,7 @@ async def main(): query=kw ) - connection = connection_class(url, loop=loop) + connection = connection_class(url, loop=loop, **kwargs) await connection.connect(timeout=timeout) return connection diff --git a/aio_pika/exceptions.py b/aio_pika/exceptions.py index ba10e5c0..a1d3bd0f 100644 --- a/aio_pika/exceptions.py +++ b/aio_pika/exceptions.py @@ -39,6 +39,10 @@ class QueueEmpty(AMQPError, asyncio.QueueEmpty): pass +class MaxReconnectAttemptsReached(Exception): + pass + + __all__ = ( 'AMQPChannelError', 'AMQPConnectionError', @@ -51,6 +55,7 @@ class QueueEmpty(AMQPError, asyncio.QueueEmpty): 'DuplicateConsumerTag', 'IncompatibleProtocolError', 'InvalidFrameError', + 'MaxReconnectAttemptsReached', 'MessageProcessError', 'MethodNotImplemented', 'ProbableAuthenticationError', diff --git a/aio_pika/robust_connection.py b/aio_pika/robust_connection.py index 9feb15f6..25095ec6 100644 --- a/aio_pika/robust_connection.py +++ b/aio_pika/robust_connection.py @@ -4,7 +4,7 @@ from typing import Callable, Type from aiormq.connection import parse_bool, parse_int -from .exceptions import CONNECTION_EXCEPTIONS +from .exceptions import CONNECTION_EXCEPTIONS, MaxReconnectAttemptsReached from .connection import Connection, connect, ConnectionType from .tools import CallbackCollection from .types import TimeoutType @@ -29,6 +29,7 @@ class RobustConnection(Connection): CHANNEL_CLASS = RobustChannel KWARGS_TYPES = ( + ('max_reconnect_attempts', parse_int, '0'), ('reconnect_interval', parse_int, '5'), ('fail_fast', parse_bool, '1'), ) @@ -41,8 +42,13 @@ def __init__(self, url, loop=None, **kwargs): self.reconnect_interval = self.kwargs['reconnect_interval'] self.fail_fast = self.kwargs['fail_fast'] + self._stop_future = self.loop.create_future() + self._stop_future.add_done_callback(self._on_stop) + self.__channels = set() + self._reconnect_attempt = None self._on_reconnect_callbacks = CallbackCollection() + self._on_stop_callbacks = CallbackCollection() self._closed = False @property @@ -63,11 +69,18 @@ def _on_connection_close(self, connection, closing, *args, **kwargs): super()._on_connection_close(connection, closing) + if isinstance(closing.exception(), MaxReconnectAttemptsReached): + return + self.loop.call_later( self.reconnect_interval, lambda: self.loop.create_task(self.reconnect()) ) + def _on_stop(self, future): + for cb in self._on_stop_callbacks: + cb(future.exception()) + def add_reconnect_callback(self, callback: Callable[[], None]): """ Add callback which will be called after reconnect. @@ -76,6 +89,9 @@ def add_reconnect_callback(self, callback: Callable[[], None]): self._on_reconnect_callbacks.add(callback) + def add_stop_callback(self, callback: Callable[[Exception], None]): + self._on_stop_callbacks.add(callback) + async def connect(self, timeout: TimeoutType = None): while True: try: @@ -97,6 +113,16 @@ async def reconnect(self): if self.is_closed: return + if self.kwargs['max_reconnect_attempts'] > 0: + if self._reconnect_attempt is None: + self._reconnect_attempt = 1 + else: + self._reconnect_attempt += 1 + + if self._reconnect_attempt > self.kwargs['max_reconnect_attempts']: + self._stop_future.set_exception(MaxReconnectAttemptsReached()) + return + try: await super().connect() except CONNECTION_EXCEPTIONS: @@ -124,6 +150,7 @@ def channel(self, channel_number: int = None, return channel async def _on_reconnect(self): + self._reconnect_attempt = None for number, channel in self._channels.items(): try: await channel.on_reconnect(self, number) diff --git a/tests/test_amqp_robust.py b/tests/test_amqp_robust.py index a65e6bd1..7ba81a24 100644 --- a/tests/test_amqp_robust.py +++ b/tests/test_amqp_robust.py @@ -8,6 +8,7 @@ from aiormq import ChannelLockedResource from aio_pika import connect_robust, Message +from aio_pika.exceptions import MaxReconnectAttemptsReached from aio_pika.robust_channel import RobustChannel from aio_pika.robust_connection import RobustConnection from aio_pika.robust_queue import RobustQueue @@ -27,6 +28,7 @@ def __init__(self, *, loop, shost='127.0.0.1', sport, self.src_port = sport self.dst_host = dhost self.dst_port = dport + self._run_task = None self.connections = set() async def _pipe(self, reader: asyncio.StreamReader, @@ -54,14 +56,18 @@ async def handle_client(self, creader: asyncio.StreamReader, ]) async def start(self): - result = await asyncio.start_server( + self._run_task = await asyncio.start_server( self.handle_client, host=self.src_host, port=self.src_port, loop=self.loop, ) - return result + async def stop(self): + assert self._run_task is not None + self._run_task.close() + await self.disconnect() + self._run_task = None async def disconnect(self): tasks = list() @@ -74,7 +80,8 @@ async def close(writer): writer = self.connections.pop() # type: asyncio.StreamWriter tasks.append(self.loop.create_task(close(writer))) - await asyncio.wait(tasks) + if tasks: + await asyncio.wait(tasks) class TestCase(AMQPTestCase): @@ -86,7 +93,7 @@ def get_unused_port() -> int: sock.close() return port - async def create_connection(self, cleanup=True): + async def create_connection(self, cleanup=True, max_reconnect_attempts=0): self.proxy = Proxy( dhost=AMQP_URL.host, dport=AMQP_URL.port, @@ -100,7 +107,11 @@ async def create_connection(self, cleanup=True): self.proxy.src_host ).with_port( self.proxy.src_port - ).update_query(reconnect_interval=1) + ).update_query( + reconnect_interval=1 + ).update_query( + max_reconnect_attempts=max_reconnect_attempts + ) client = await connect_robust(str(url), loop=self.loop) @@ -212,6 +223,27 @@ async def reader(): assert len(shared) == 10 + async def test_robust_reconnect_max_attempts(self): + client = await self.create_connection(max_reconnect_attempts=2) + self.assertIsInstance(client, RobustConnection) + + first_close = asyncio.Future() + stopped = asyncio.Future() + + def stop_callback(exc): + assert isinstance(exc, MaxReconnectAttemptsReached) + stopped.set_result(True) + + def close_callback(f): + first_close.set_result(True) + + client.add_stop_callback(stop_callback) + client.connection.closing.add_done_callback(close_callback) + await self.proxy.stop() + await first_close + # 1 interval before first try and 2 after attempts + await asyncio.wait_for(stopped, timeout=client.reconnect_interval * 3 + 0.1) + async def test_channel_locked_resource2(self): ch1 = await self.create_channel() ch2 = await self.create_channel()