Skip to content

Commit

Permalink
Use hub.remove_writer instead of hub.remove for write fds (celery#4185)
Browse files Browse the repository at this point in the history
- fix main process Unrecoverable error: AssertionError() when read fd is deleted
- see celery#4185 (comment)
- tests:
    - change hub.remove to hub.remove_writer in test_poll_write_generator and test_poll_write_generator_stopped
    - add 3 more tests for schedule_writes to assert only hub.writers is removed when hub.readers have the same fd id
  • Loading branch information
Idan-vast committed Jun 14, 2024
1 parent cc304b2 commit 4823f57
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 7 deletions.
8 changes: 4 additions & 4 deletions celery/concurrency/asynpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ def on_poll_start():
None, WRITE | ERR, consolidate=True)
else:
iterate_file_descriptors_safely(
inactive, all_inqueues, hub_remove)
inactive, all_inqueues, hub.remove_writer)
self.on_poll_start = on_poll_start

def on_inqueue_close(fd, proc):
Expand Down Expand Up @@ -818,7 +818,7 @@ def schedule_writes(ready_fds, total_write_count=None):
# worker is already busy with another task
continue
if ready_fd not in all_inqueues:
hub_remove(ready_fd)
hub.remove_writer(ready_fd)
continue
try:
job = pop_message()
Expand All @@ -829,7 +829,7 @@ def schedule_writes(ready_fds, total_write_count=None):
# this may create a spinloop where the event loop
# always wakes up.
for inqfd in diff(active_writes):
hub_remove(inqfd)
hub.remove_writer(inqfd)
break

else:
Expand Down Expand Up @@ -927,7 +927,7 @@ def _write_job(proc, fd, job):
else:
errors = 0
finally:
hub_remove(fd)
hub.remove_writer(fd)
write_stats[proc.index] += 1
# message written, so this fd is now available
active_writes.discard(fd)
Expand Down
96 changes: 96 additions & 0 deletions t/unit/concurrency/test_prefork.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from unittest.mock import Mock, patch

import pytest
from billiard.pool import ApplyResult
from kombu.asynchronous import Hub

import t.skip
from celery.app.defaults import DEFAULTS
Expand Down Expand Up @@ -354,6 +356,100 @@ def _fake_hub(*args, **kwargs):
# Then: all items were removed from the managed data source
assert fd_iter == {}, "Expected all items removed from managed dict"

def _get_hub(self):
hub = Hub()
hub.readers = {}
hub.writers = {}
hub.timer = Mock(name='hub.timer')
hub.timer._queue = [Mock()]
hub.fire_timers = Mock(name='hub.fire_timers')
hub.fire_timers.return_value = 1.7
hub.poller = Mock(name='hub.poller')
hub.close = Mock(name='hub.close()')
return hub

def test_schedule_writes_hub_remove_writer_ready_fd_not_in_all_inqueues(self):
pool = asynpool.AsynPool(threads=False)
hub = self._get_hub()

writer = Mock(name='writer')
reader = Mock(name='reader')

# add 2 fake fds with the same id
hub.add_reader(6, reader, 6)
hub.add_writer(6, writer, 6)
pool._all_inqueues.clear()
pool._create_write_handlers(hub)

# check schedule_writes write fds remove not remove the reader one from the hub.
hub.consolidate_callback(ready_fds=[6])
assert 6 in hub.readers
assert 6 not in hub.writers

def test_schedule_writes_hub_remove_writers_from_active_writers_when_get_index_error(self):
pool = asynpool.AsynPool(threads=False)
hub = self._get_hub()

writer = Mock(name='writer')
reader = Mock(name='reader')

# add 3 fake fds with the same id to reader and writer
hub.add_reader(6, reader, 6)
hub.add_reader(8, reader, 8)
hub.add_reader(9, reader, 9)
hub.add_writer(6, writer, 6)
hub.add_writer(8, writer, 8)
hub.add_writer(9, writer, 9)

# add fake fd to pool _all_inqueues to make sure we try to read from outbound_buffer
# set active_writes to 6 to make sure we remove all write fds except 6
pool._active_writes = {6}
pool._all_inqueues = {2, 6, 8, 9}

pool._create_write_handlers(hub)

# clear outbound_buffer to get IndexError when trying to pop any message
# in this case all active_writers fds will be removed from the hub
pool.outbound_buffer.clear()

hub.consolidate_callback(ready_fds=[2])
if {6, 8, 9} <= hub.readers.keys() and not {8, 9} <= hub.writers.keys():
assert True
else:
assert False

assert 6 in hub.writers

def test_schedule_writes_hub_remove_fd_only_from_writers_when_write_job_is_done(self):
pool = asynpool.AsynPool(threads=False)
hub = self._get_hub()

writer = Mock(name='writer')
reader = Mock(name='reader')

# add one writer and one reader with the same fd
hub.add_writer(2, writer, 2)
hub.add_reader(2, reader, 2)
assert 2 in hub.writers

# For test purposes to reach _write_job in schedule writes
pool._all_inqueues = {2}
worker = Mock("worker")
# this lambda need to return a number higher than 4
# to pass the while loop in _write_job function and to reach the hub.remove_writer
worker.send_job_offset = lambda header, HW: 5

pool._fileno_to_inq[2] = worker
pool._create_write_handlers(hub)

result = ApplyResult({}, lambda x: True)
result._payload = [None, None, -1]
pool.outbound_buffer.appendleft(result)

hub.consolidate_callback(ready_fds=[2])
assert 2 not in hub.writers
assert 2 in hub.readers

def test_register_with_event_loop__no_on_tick_dupes(self):
"""Ensure AsynPool's register_with_event_loop only registers
on_poll_start in the event loop the first time it's called. This
Expand Down
6 changes: 3 additions & 3 deletions t/unit/worker/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def test_poll_err_writable(self):

def test_poll_write_generator(self):
x = X(self.app)
x.hub.remove = Mock(name='hub.remove()')
x.hub.remove_writer = Mock(name='hub.remove_writer()')

def Gen():
yield 1
Expand All @@ -376,7 +376,7 @@ def Gen():
with pytest.raises(socket.error):
asynloop(*x.args)
assert gen.gi_frame.f_lasti != -1
x.hub.remove.assert_not_called()
x.hub.remove_writer.assert_not_called()

def test_poll_write_generator_stopped(self):
x = X(self.app)
Expand All @@ -388,7 +388,7 @@ def Gen():
x.hub.add_writer(6, gen)
x.hub.on_tick.add(x.close_then_error(Mock(name='tick'), 2))
x.hub.poller.poll.return_value = [(6, WRITE)]
x.hub.remove = Mock(name='hub.remove()')
x.hub.remove_writer = Mock(name='hub.remove_writer()')
with pytest.raises(socket.error):
asynloop(*x.args)
assert gen.gi_frame is None
Expand Down

0 comments on commit 4823f57

Please sign in to comment.