Skip to content

Commit

Permalink
QueueIterator raised StopAsyncIteration when channel is closed.
Browse files Browse the repository at this point in the history
  • Loading branch information
Darsstar committed Jan 17, 2024
1 parent 8804f3c commit c0e8172
Show file tree
Hide file tree
Showing 11 changed files with 293 additions and 94 deletions.
4 changes: 4 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
[run]
omit = aio_pika/compat.py
branch = True

[report]
exclude_also =
raise NotImplementedError
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ jobs:
tests
env:
FORCE_COLOR: 1
COV_CORE_CONFIG: .coveragerc
- run: poetry run coveralls
env:
COVERALLS_PARALLEL: 'true'
Expand Down
12 changes: 10 additions & 2 deletions aio_pika/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,14 @@ def iterator(self, **kwargs: Any) -> "AbstractQueueIterator":
raise NotImplementedError


class AbstractQueueIterator(AsyncIterable):
class AbstractQueueIterator(AsyncIterable[AbstractIncomingMessage]):
_amqp_queue: AbstractQueue
_queue: asyncio.Queue
_consumer_tag: ConsumerTag
_consume_kwargs: Dict[str, Any]

@abstractmethod
def close(self, *_: Any) -> Awaitable[Any]:
def close(self) -> Awaitable[Any]:
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -532,6 +532,10 @@ def is_closed(self) -> bool:
def close(self, exc: Optional[ExceptionType] = None) -> Awaitable[None]:
raise NotImplementedError

@abstractmethod
async def wait(self) -> None:
raise NotImplementedError

@abstractmethod
async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:
raise NotImplementedError
Expand Down Expand Up @@ -742,6 +746,10 @@ def is_closed(self) -> bool:
async def close(self, exc: ExceptionType = asyncio.CancelledError) -> None:
raise NotImplementedError

@abstractmethod
async def wait(self) -> None:
raise NotImplementedError

@abstractmethod
async def connect(self, timeout: TimeoutType = None) -> None:
raise NotImplementedError
Expand Down
31 changes: 23 additions & 8 deletions aio_pika/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def __init__(

self._connection: AbstractConnection = connection

# That's means user closed channel instance explicitly
self._closed: bool = False
self._closed: asyncio.Event = asyncio.Event()

self._channel: Optional[UnderlayChannel] = None
self._channel_number = channel_number
Expand All @@ -89,6 +88,8 @@ def __init__(
self.publisher_confirms = publisher_confirms
self.on_return_raises = on_return_raises

self.close_callbacks.add(self._set_closed_callback)

@property
def is_initialized(self) -> bool:
"""Returns True when the channel has been opened
Expand All @@ -99,7 +100,7 @@ def is_initialized(self) -> bool:
def is_closed(self) -> bool:
"""Returns True when the channel has been closed from the broker
side or after the close() method has been called."""
if not self.is_initialized or self._closed:
if not self.is_initialized or self._closed.is_set():
return True
channel = self._channel
if channel is None:
Expand All @@ -119,8 +120,11 @@ async def close(
return

log.debug("Closing channel %r", self)
self._closed = True
await self._channel.close()
self._closed.set()

async def wait(self) -> None:
await self._closed.wait()

async def get_underlay_channel(self) -> aiormq.abc.AbstractChannel:

Expand Down Expand Up @@ -174,12 +178,12 @@ async def _open(self) -> None:
await channel.close(e)
self._channel = None
raise
self._closed = False
self._closed.clear()

async def initialize(self, timeout: TimeoutType = None) -> None:
if self.is_initialized:
raise RuntimeError("Already initialized")
elif self._closed:
elif self._closed.is_set():
raise RuntimeError("Can't initialize closed channel")

await self._open()
Expand All @@ -197,7 +201,10 @@ async def _on_open(self) -> None:
type=ExchangeType.DIRECT,
)

async def _on_close(self, closing: asyncio.Future) -> None:
async def _on_close(
self,
closing: asyncio.Future
) -> Optional[BaseException]:
try:
exc = closing.exception()
except asyncio.CancelledError as e:
Expand All @@ -207,6 +214,14 @@ async def _on_close(self, closing: asyncio.Future) -> None:
if self._channel and self._channel.channel:
self._channel.channel.on_return_callbacks.discard(self._on_return)

return exc

async def _set_closed_callback(
self,
_: AbstractChannel, exc: BaseException
) -> None:
self._closed.set()

async def _on_initialized(self) -> None:
channel = await self.get_underlay_channel()
channel.on_return_callbacks.add(self._on_return)
Expand All @@ -219,7 +234,7 @@ async def reopen(self) -> None:
await self._open()

def __del__(self) -> None:
self._closed = True
self._closed.set()
self._channel = None

async def declare_exchange(
Expand Down
14 changes: 8 additions & 6 deletions aio_pika/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ class Connection(AbstractConnection):
),
)

_closed: bool
_closed: asyncio.Event

@property
def is_closed(self) -> bool:
return self._closed
return self._closed.is_set()

async def close(
self, exc: Optional[aiormq.abc.ExceptionType] = ConnectionClosed,
Expand All @@ -53,7 +53,10 @@ async def close(
if not transport:
return
await transport.close(exc)
self._closed = True
self._closed.set()

async def wait(self) -> None:
await self._closed.wait()

@classmethod
def _parse_parameters(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -74,7 +77,7 @@ def __init__(
):
self.loop = loop or asyncio.get_event_loop()
self.transport = None
self._closed = False
self._closed = asyncio.Event()
self._close_called = False

self.url = URL(url)
Expand Down Expand Up @@ -201,8 +204,7 @@ async def ready(self) -> None:
def __del__(self) -> None:
if (
self.is_closed or
self.loop.is_closed() or
not hasattr(self, "connection")
self.loop.is_closed()
):
return

Expand Down
87 changes: 76 additions & 11 deletions aio_pika/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,22 @@ class QueueIterator(AbstractQueueIterator):
def consumer_tag(self) -> Optional[ConsumerTag]:
return getattr(self, "_consumer_tag", None)

async def close(self, *_: Any) -> Any:
async def close(self) -> None:
await self._on_close(self._amqp_queue.channel, None)
self._closed.set()

async def _set_closed_callback(
self,
_channel: AbstractChannel,
exc: Optional[BaseException]
) -> None:
self._closed.set()

async def _on_close(
self,
_channel: AbstractChannel,
_exc: Optional[BaseException]
) -> None:
log.debug("Cancelling queue iterator %r", self)

if not hasattr(self, "_consumer_tag"):
Expand All @@ -436,7 +451,7 @@ async def close(self, *_: Any) -> Any:
consumer_tag = self._consumer_tag
del self._consumer_tag

self._amqp_queue.close_callbacks.remove(self.close)
self._amqp_queue.close_callbacks.discard(self._on_close)
await self._amqp_queue.cancel(consumer_tag)

log.debug("Queue iterator %r closed", self)
Expand Down Expand Up @@ -482,9 +497,14 @@ def __init__(self, queue: Queue, **kwargs: Any):
self._consumer_tag: ConsumerTag
self._amqp_queue: AbstractQueue = queue
self._queue = asyncio.Queue()
self._closed = asyncio.Event()
self._consume_kwargs = kwargs

self._amqp_queue.close_callbacks.add(self.close)
self._amqp_queue.close_callbacks.add(self._on_close, weak=True)
self._amqp_queue.close_callbacks.add(
self._set_closed_callback,
weak=True
)

async def on_message(self, message: AbstractIncomingMessage) -> None:
await self._queue.put(message)
Expand Down Expand Up @@ -513,22 +533,67 @@ async def __aexit__(
async def __anext__(self) -> IncomingMessage:
if not hasattr(self, "_consumer_tag"):
await self.consume()

if self._closed.is_set():
raise StopAsyncIteration

message = asyncio.create_task(
self._queue.get(),
name=f"waiting for message from {self}"
)
closed_channel = asyncio.create_task(
self._amqp_queue.channel.wait(),
name=f"waiting for channel {self._amqp_queue.channel} to close "
f"before a message from {self}"
)
closed = asyncio.create_task(
self._closed.wait(),
name=f"waiting for queue iterator to close "
f"before a message from {self}"
)

timeout = self._consume_kwargs.get("timeout")
sleep = asyncio.get_running_loop().create_future()

if timeout is not None:
sleep = asyncio.create_task(
asyncio.sleep(timeout),
name=f"waiting for {self} to timeout after {timeout} seconds"
)
else:
timeout = self.DEFAULT_CLOSE_TIMEOUT

pending = {message, closed_channel, closed, sleep}

try:
return await asyncio.wait_for(
self._queue.get(),
timeout=self._consume_kwargs.get("timeout"),
done, pending = await asyncio.wait(
pending, return_when=asyncio.FIRST_COMPLETED
)
except asyncio.CancelledError:
timeout = self._consume_kwargs.get(
"timeout",
self.DEFAULT_CLOSE_TIMEOUT,
)
# Increase coverage score
pass
finally:
for task in pending:
task.cancel()

await asyncio.wait(pending)

if not message.cancelled():
return message.result()

if not closed.cancelled() or not closed_channel.cancelled():
self._closed.set()
raise StopAsyncIteration

if not sleep.cancelled():
log.info(
"%r closing with timeout %d seconds",
self, timeout,
)
await asyncio.wait_for(self.close(), timeout=timeout)
raise
raise TimeoutError

raise asyncio.CancelledError


__all__ = ("Queue", "QueueIterator", "ConsumerTag")
19 changes: 14 additions & 5 deletions aio_pika/robust_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def __init__(
self.reopen_callbacks: CallbackCollection = CallbackCollection(self)
self.__restore_lock = asyncio.Lock()
self.__restored = asyncio.Event()
self.close_callbacks.add(self.__close_callback)

self.close_callbacks.remove(self._set_closed_callback)

async def ready(self) -> None:
await self._connection.ready()
Expand All @@ -94,23 +95,31 @@ async def restore(self, channel: Any = None) -> None:
await self.reopen()
self.__restored.set()

async def __close_callback(self, _: Any, exc: BaseException) -> None:
async def _on_close(
self,
closing: asyncio.Future
) -> Optional[BaseException]:
exc = await super()._on_close(closing)

if isinstance(exc, asyncio.CancelledError):
# This happens only if the channel is forced to close from the
# outside, for example, if the connection is closed.
# Of course, here you need to exit from this function
# as soon as possible and to avoid a recovery attempt.
self.__restored.clear()
return
self._closed.set()
return exc

in_restore_state = not self.__restored.is_set()
self.__restored.clear()

if self._closed or in_restore_state:
return
if self._closed.is_set() or in_restore_state:
return exc

await self.restore()

return exc

async def _open(self) -> None:
await super()._open()
await self.reopen_callbacks()
Expand Down
5 changes: 5 additions & 0 deletions aio_pika/robust_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ def iterator(self, **kwargs: Any) -> AbstractQueueIterator:


class RobustQueueIterator(QueueIterator):
def __init__(self, queue: Queue, **kwargs: Any):
super().__init__(queue, **kwargs)

self._amqp_queue.close_callbacks.discard(self._set_closed_callback)

async def consume(self) -> None:
while True:
try:
Expand Down
12 changes: 11 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit c0e8172

Please sign in to comment.