diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py index 2db1e4711323..c2bbc75d7383 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py @@ -11,20 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import AsyncMock, Mock from twisted.internet import defer -from synapse.replication.tcp.commands import PositionCommand, UserIpCommand +from synapse.replication.tcp.commands import ( + PositionCommand, + RemoteServerUpCommand, + UserIpCommand, +) +from synapse.server import HomeServer +from synapse.util import Clock from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import ThreadedMemoryReactorClock +from tests.unittest import override_config class ChannelsTestCase(BaseMultiWorkerStreamTestCase): def test_subscribed_to_enough_redis_channels(self) -> None: - # The default main process is subscribed to the USER_IP channel. + # The default main process is subscribed to the REMOTE_SERVER_UP and USER_IP + # channel. self.assertCountEqual( self.hs.get_replication_command_handler()._channels_to_subscribe_to, - [UserIpCommand.NAME], + [UserIpCommand.NAME, RemoteServerUpCommand.NAME], ) def test_background_worker_subscribed_to_user_ip(self) -> None: @@ -76,6 +86,59 @@ def test_non_background_worker_not_subscribed_to_user_ip(self) -> None: len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1 ) + @override_config({"federation_sender_instances": ["worker1"]}) + def test_federation_sender_subscribed_to_remote_server_up(self) -> None: + # The default main process and federation senders are subscribed to the + # REMOTE_SERVER_UP channel. + worker1 = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={ + "worker_name": "worker1", + "redis": {"enabled": True}, + }, + ) + + self.assertIn( + RemoteServerUpCommand.NAME, + worker1.get_replication_command_handler()._channels_to_subscribe_to, + ) + + # Advance so the Redis subscription gets processed + self.pump(0.1) + + # The counts are 2 because both the main process and the worker are subscribed. + self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2) + self.assertEqual( + len(self._redis_server._subscribers_by_channel[b"test/REMOTE_SERVER_UP"]), 2 + ) + + @override_config({"federation_sender_instances": ["worker1"]}) + def test_non_federation_sender_not_subscribed_to_remote_server_up(self) -> None: + # Only the default main process is subscribed to the REMOTE_SERVER_UP channel + # because it is the main process. The override above tells 'send_federation' to + # be false, so main is not a sender. + worker2 = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={ + "worker_name": "worker2", + "redis": {"enabled": True}, + }, + ) + self.assertNotIn( + RemoteServerUpCommand.NAME, + worker2.get_replication_command_handler()._channels_to_subscribe_to, + ) + + # Advance so the Redis subscription gets processed + self.pump(0.1) + + # The count is 2 because both the main process and the worker are subscribed. + self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2) + # For REMOTE_SERVER_UP, the count is 1 because only the main process is subscribed. + self.assertEqual( + len(self._redis_server._subscribers_by_channel[b"test/REMOTE_SERVER_UP"]), 1 + ) + def test_wait_for_stream_position(self) -> None: """Check that wait for stream position correctly waits for an update from the correct instance. @@ -202,3 +265,43 @@ def test_wait_for_stream_position_rdata(self) -> None: # Master should get told about `next_token2`, so the deferred should # resolve. self.assertTrue(d.called) + + +class ChannelsCapabilityTestCase(BaseMultiWorkerStreamTestCase): + def make_homeserver( + self, reactor: ThreadedMemoryReactorClock, clock: Clock + ) -> HomeServer: + self.replication_data_handler = Mock( + spec=["on_remote_server_up", "on_position"] + ) + self.replication_data_handler.on_remote_server_up = Mock() + self.replication_data_handler.on_position = AsyncMock() + + hs = self.setup_test_homeserver( + replication_data_handler=self.replication_data_handler, + ) + return hs + + def test_sending_command_while_not_subscribed(self) -> None: + # Test that sending a command from a worker that is not on a specific channel + # actually allows receiving on a 'worker'(in this case main) that is. + # + # Proves: a worker doesn't have to be subscribed to a channel to send commands. + mock_on_remote_server_up = self.replication_data_handler.on_remote_server_up + + # worker1 is setup to listen on ["test", "test/USER_IP"] and main is on all + worker1 = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={ + "worker_name": "worker1", + "run_background_tasks_on": "worker1", + "redis": {"enabled": True}, + }, + ) + worker_cmd_handler = worker1.get_replication_command_handler() + # Pump the reactor, so the replication connections are established + self.pump() + worker_cmd_handler.send_remote_server_up("test2") + # Actually allow the replication to take place + self.replicate() + mock_on_remote_server_up.assert_called_once()