Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Max reconnect attempts #247

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aio_pika/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ async def main():
query=kw
)

connection = connection_class(url, loop=loop)
connection = connection_class(url, loop=loop, **kwargs)

await connection.connect(
timeout=timeout, client_properties=client_properties
Expand Down
5 changes: 5 additions & 0 deletions aio_pika/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ class QueueEmpty(AMQPError, asyncio.QueueEmpty):
pass


class MaxReconnectAttemptsReached(Exception):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, this class should be a subclass of ConnectionError

pass


__all__ = (
'AMQPChannelError',
'AMQPConnectionError',
Expand All @@ -53,6 +57,7 @@ class QueueEmpty(AMQPError, asyncio.QueueEmpty):
'DuplicateConsumerTag',
'IncompatibleProtocolError',
'InvalidFrameError',
'MaxReconnectAttemptsReached',
'MessageProcessError',
'MethodNotImplemented',
'ProbableAuthenticationError',
Expand Down
20 changes: 19 additions & 1 deletion aio_pika/robust_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Type

from aiormq.connection import parse_bool, parse_int
from .exceptions import CONNECTION_EXCEPTIONS
from .exceptions import CONNECTION_EXCEPTIONS, MaxReconnectAttemptsReached
from .connection import Connection, connect, ConnectionType
from .tools import CallbackCollection
from .types import TimeoutType
Expand All @@ -29,6 +29,7 @@ class RobustConnection(Connection):

CHANNEL_CLASS = RobustChannel
KWARGS_TYPES = (
('max_reconnect_attempts', parse_int, '0'),
('reconnect_interval', parse_int, '5'),
('fail_fast', parse_bool, '1'),
)
Expand All @@ -43,7 +44,9 @@ def __init__(self, url, loop=None, **kwargs):
self.fail_fast = self.kwargs['fail_fast']

self.__channels = set()
self._reconnect_attempt = None
self._reconnect_callbacks = CallbackCollection()
self._stop_callbacks = CallbackCollection()
self._closed = False

@property
Expand Down Expand Up @@ -77,6 +80,9 @@ def add_reconnect_callback(self, callback: Callable[[], None]):

self._reconnect_callbacks.add(callback)

def add_stop_callback(self, callback: Callable[[Exception], None]):
self._stop_callbacks.add(callback)

async def connect(self, timeout: TimeoutType = None, **kwargs):
if kwargs:
# Store connect kwargs for reconnects
Expand Down Expand Up @@ -104,6 +110,16 @@ async def reconnect(self):
if self.is_closed:
return

if self.kwargs['max_reconnect_attempts'] > 0:
if self._reconnect_attempt is None:
self._reconnect_attempt = 1
else:
self._reconnect_attempt += 1

if self._reconnect_attempt > self.kwargs['max_reconnect_attempts']:
await self.close(MaxReconnectAttemptsReached())
return

try:
await super().connect()
except CONNECTION_EXCEPTIONS:
Expand Down Expand Up @@ -131,6 +147,7 @@ def channel(self, channel_number: int = None,
return channel

async def _on_reconnect(self):
self._reconnect_attempt = None
for number, channel in self._channels.items():
try:
await channel.on_reconnect(self, number)
Expand All @@ -151,6 +168,7 @@ async def close(self, exc=asyncio.CancelledError):
return

self._closed = True
self._stop_callbacks(exc)

if self.connection is None:
return
Expand Down
2 changes: 1 addition & 1 deletion tests/test_amqp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,7 +1231,7 @@ async def test_on_return_raises(self):
)

for _ in range(100):
with pytest.raises(aio_pika.exceptions.DeliveryError) as e:
with pytest.raises(aio_pika.exceptions.DeliveryError):
await channel.default_exchange.publish(
Message(body=body), routing_key=queue_name,
)
Expand Down
44 changes: 40 additions & 4 deletions tests/test_amqp_robust.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aiormq import ChannelLockedResource

from aio_pika import connect_robust, Message
from aio_pika.exceptions import MaxReconnectAttemptsReached
from aio_pika.robust_channel import RobustChannel
from aio_pika.robust_connection import RobustConnection
from aio_pika.robust_queue import RobustQueue
Expand All @@ -27,6 +28,7 @@ def __init__(self, *, loop, shost='127.0.0.1', sport,
self.src_port = sport
self.dst_host = dhost
self.dst_port = dport
self._run_task = None
self.connections = set()

async def _pipe(self, reader: asyncio.StreamReader,
Expand Down Expand Up @@ -54,12 +56,19 @@ async def handle_client(self, creader: asyncio.StreamReader,
])

async def start(self):
return await asyncio.start_server(
self._run_task = await asyncio.start_server(
self.handle_client,
host=self.src_host,
port=self.src_port,
loop=self.loop,
)
return self._run_task

async def stop(self):
assert self._run_task is not None
self._run_task.close()
await self.disconnect()
self._run_task = None

async def disconnect(self):
tasks = list()
Expand All @@ -72,7 +81,8 @@ async def close(writer):
writer = self.connections.pop() # type: asyncio.StreamWriter
tasks.append(self.loop.create_task(close(writer)))

await asyncio.wait(tasks)
if tasks:
await asyncio.wait(tasks)


class TestCase(AMQPTestCase):
Expand All @@ -84,7 +94,7 @@ def get_unused_port() -> int:
sock.close()
return port

async def create_connection(self, cleanup=True):
async def create_connection(self, cleanup=True, max_reconnect_attempts=0):
self.proxy = Proxy(
dhost=AMQP_URL.host,
dport=AMQP_URL.port,
Expand All @@ -98,7 +108,11 @@ async def create_connection(self, cleanup=True):
self.proxy.src_host
).with_port(
self.proxy.src_port
).update_query(reconnect_interval=1)
).update_query(
reconnect_interval=1
).update_query(
max_reconnect_attempts=max_reconnect_attempts
)

client = await connect_robust(str(url), loop=self.loop)

Expand Down Expand Up @@ -210,6 +224,28 @@ async def reader():

assert len(shared) == 10

async def test_robust_reconnect_max_attempts(self):
client = await self.create_connection(max_reconnect_attempts=2)
self.assertIsInstance(client, RobustConnection)

first_close = asyncio.Future()
stopped = asyncio.Future()

def stop_callback(exc):
assert isinstance(exc, MaxReconnectAttemptsReached)
stopped.set_result(True)

def close_callback(f):
first_close.set_result(True)

client.add_stop_callback(stop_callback)
client.connection.closing.add_done_callback(close_callback)
await self.proxy.stop()
await first_close
# 1 interval before first try and 2 after attempts
await asyncio.wait_for(stopped,
timeout=client.reconnect_interval * 3 + 0.1)

async def test_channel_locked_resource2(self):
ch1 = await self.create_channel()
ch2 = await self.create_channel()
Expand Down