From 4239331c0ac7a2fdef32220f0bdd62f3b8709ae2 Mon Sep 17 00:00:00 2001 From: Olivier Dautricourt Date: Sat, 15 Jan 2022 10:05:32 +0100 Subject: [PATCH] PubSub: Allow specifying custom encoder --- src/socketio/asyncio_aiopika_manager.py | 7 ++-- src/socketio/asyncio_pubsub_manager.py | 12 +++++-- src/socketio/asyncio_redis_manager.py | 10 ++++-- src/socketio/kafka_manager.py | 11 +++--- src/socketio/kombu_manager.py | 10 ++++-- src/socketio/pubsub_manager.py | 12 +++++-- src/socketio/redis_manager.py | 11 ++++-- src/socketio/zmq_manager.py | 17 ++++----- tests/common/test_pubsub_manager.py | 48 +++++++++++++++++++++++-- 9 files changed, 108 insertions(+), 30 deletions(-) diff --git a/src/socketio/asyncio_aiopika_manager.py b/src/socketio/asyncio_aiopika_manager.py index 96dcec65..4aea479c 100644 --- a/src/socketio/asyncio_aiopika_manager.py +++ b/src/socketio/asyncio_aiopika_manager.py @@ -38,7 +38,8 @@ class AsyncAioPikaManager(AsyncPubSubManager): # pragma: no cover name = 'asyncaiopika' def __init__(self, url='amqp://guest:guest@localhost:5672//', - channel='socketio', write_only=False, logger=None): + channel='socketio', write_only=False, logger=None, + encoder=pickle): if aio_pika is None: raise RuntimeError('aio_pika package is not installed ' '(Run "pip install aio_pika" in your ' @@ -70,7 +71,7 @@ async def _publish(self, data): channel = await self._channel(connection) exchange = await self._exchange(channel) await exchange.publish( - aio_pika.Message(body=pickle.dumps(data), + aio_pika.Message(body=self.encoder.dumps(data), delivery_mode=aio_pika.DeliveryMode.PERSISTENT), routing_key='*' ) @@ -94,7 +95,7 @@ async def _listen(self): async with self.listener_queue.iterator() as queue_iter: async for message in queue_iter: with message.process(): - yield pickle.loads(message.body) + yield message.body except Exception: self._get_logger().error('Cannot receive from rabbitmq... ' 'retrying in ' diff --git a/src/socketio/asyncio_pubsub_manager.py b/src/socketio/asyncio_pubsub_manager.py index ff37f2df..2c09dcc5 100644 --- a/src/socketio/asyncio_pubsub_manager.py +++ b/src/socketio/asyncio_pubsub_manager.py @@ -24,12 +24,14 @@ class AsyncPubSubManager(AsyncManager): """ name = 'asyncpubsub' - def __init__(self, channel='socketio', write_only=False, logger=None): + def __init__(self, channel='socketio', write_only=False, logger=None, + encoder=pickle): super().__init__() self.channel = channel self.write_only = write_only self.host_id = uuid.uuid4().hex self.logger = logger + self.encoder = encoder def initialize(self): super().initialize() @@ -153,7 +155,13 @@ async def _thread(self): if isinstance(message, dict): data = message else: - if isinstance(message, bytes): # pragma: no cover + if self.encoder: + try: + data = self.encoder.loads(message) + except: + pass + if data is None and \ + isinstance(message, bytes): # pragma: no cover try: data = pickle.loads(message) except: diff --git a/src/socketio/asyncio_redis_manager.py b/src/socketio/asyncio_redis_manager.py index d9da5f9a..07fc8d88 100644 --- a/src/socketio/asyncio_redis_manager.py +++ b/src/socketio/asyncio_redis_manager.py @@ -32,11 +32,14 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover and receiving. :param redis_options: additional keyword arguments to be passed to ``aioredis.from_url()``. + :param encoder: The encoder to use for publishing and decoding data, + defaults to pickle. """ name = 'aioredis' def __init__(self, url='redis://localhost:6379/0', channel='socketio', - write_only=False, logger=None, redis_options=None): + write_only=False, logger=None, redis_options=None, + encoder=pickle): if aioredis is None: raise RuntimeError('Redis package is not installed ' '(Run "pip install aioredis" in your ' @@ -46,7 +49,8 @@ def __init__(self, url='redis://localhost:6379/0', channel='socketio', self.redis_url = url self.redis_options = redis_options or {} self._redis_connect() - super().__init__(channel=channel, write_only=write_only, logger=logger) + super().__init__(channel=channel, write_only=write_only, logger=logger, + encoder=encoder) def _redis_connect(self): self.redis = aioredis.Redis.from_url(self.redis_url, @@ -60,7 +64,7 @@ async def _publish(self, data): if not retry: self._redis_connect() return await self.redis.publish( - self.channel, pickle.dumps(data)) + self.channel, self.encoder.dumps(data)) except aioredis.exceptions.RedisError: if retry: self._get_logger().error('Cannot publish to redis... ' diff --git a/src/socketio/kafka_manager.py b/src/socketio/kafka_manager.py index 739871a3..191882c1 100644 --- a/src/socketio/kafka_manager.py +++ b/src/socketio/kafka_manager.py @@ -33,18 +33,21 @@ class KafkaManager(PubSubManager): # pragma: no cover :param write_only: If set to ``True``, only initialize to emit events. The default of ``False`` initializes the class for emitting and receiving. + :param encoder: The encoder to use for publishing and decoding data, + defaults to pickle. """ name = 'kafka' def __init__(self, url='kafka://localhost:9092', channel='socketio', - write_only=False): + write_only=False, encoder=pickle): if kafka is None: raise RuntimeError('kafka-python package is not installed ' '(Run "pip install kafka-python" in your ' 'virtualenv).') super(KafkaManager, self).__init__(channel=channel, - write_only=write_only) + write_only=write_only, + encoder=encoder) urls = [url] if isinstance(url, str) else url self.kafka_urls = [url[8:] if url != 'kafka://' else 'localhost:9092' @@ -54,7 +57,7 @@ def __init__(self, url='kafka://localhost:9092', channel='socketio', bootstrap_servers=self.kafka_urls) def _publish(self, data): - self.producer.send(self.channel, value=pickle.dumps(data)) + self.producer.send(self.channel, value=self.encoder.dumps(data)) self.producer.flush() def _kafka_listen(self): @@ -64,4 +67,4 @@ def _kafka_listen(self): def _listen(self): for message in self._kafka_listen(): if message.topic == self.channel: - yield pickle.loads(message.value) + yield message.value diff --git a/src/socketio/kombu_manager.py b/src/socketio/kombu_manager.py index 61eebd00..5e21cb7e 100644 --- a/src/socketio/kombu_manager.py +++ b/src/socketio/kombu_manager.py @@ -42,20 +42,24 @@ class KombuManager(PubSubManager): # pragma: no cover ``kombu.Queue()``. :param producer_options: additional keyword arguments to be passed to ``kombu.Producer()``. + :param encoder: The encoder to use for publishing and decoding data, + defaults to pickle. """ name = 'kombu' def __init__(self, url='amqp://guest:guest@localhost:5672//', channel='socketio', write_only=False, logger=None, connection_options=None, exchange_options=None, - queue_options=None, producer_options=None): + queue_options=None, producer_options=None, + encoder=pickle): if kombu is None: raise RuntimeError('Kombu package is not installed ' '(Run "pip install kombu" in your ' 'virtualenv).') super(KombuManager, self).__init__(channel=channel, write_only=write_only, - logger=logger) + logger=logger, + encoder=encoder) self.url = url self.connection_options = connection_options or {} self.exchange_options = exchange_options or {} @@ -103,7 +107,7 @@ def _publish(self, data): connection = self._connection() publish = connection.ensure(self.producer, self.producer.publish, errback=self.__error_callback) - publish(pickle.dumps(data)) + publish(self.encoder.dumps(data)) def _listen(self): reader_queue = self._queue() diff --git a/src/socketio/pubsub_manager.py b/src/socketio/pubsub_manager.py index 9b6f36de..25d64160 100644 --- a/src/socketio/pubsub_manager.py +++ b/src/socketio/pubsub_manager.py @@ -23,12 +23,14 @@ class PubSubManager(BaseManager): """ name = 'pubsub' - def __init__(self, channel='socketio', write_only=False, logger=None): + def __init__(self, channel='socketio', write_only=False, logger=None, + encoder=None): super(PubSubManager, self).__init__() self.channel = channel self.write_only = write_only self.host_id = uuid.uuid4().hex self.logger = logger + self.encoder = encoder def initialize(self): super(PubSubManager, self).initialize() @@ -151,7 +153,13 @@ def _thread(self): if isinstance(message, dict): data = message else: - if isinstance(message, bytes): # pragma: no cover + if self.encoder: + try: + data = self.encoder.loads(message) + except: + pass + if data is None and \ + isinstance(message, bytes): # pragma: no cover try: data = pickle.loads(message) except: diff --git a/src/socketio/redis_manager.py b/src/socketio/redis_manager.py index ab40739e..129bc093 100644 --- a/src/socketio/redis_manager.py +++ b/src/socketio/redis_manager.py @@ -36,11 +36,14 @@ class RedisManager(PubSubManager): # pragma: no cover and receiving. :param redis_options: additional keyword arguments to be passed to ``Redis.from_url()``. + :param encoder: The encoder to use for publishing and decoding data, + defaults to pickle. """ name = 'redis' def __init__(self, url='redis://localhost:6379/0', channel='socketio', - write_only=False, logger=None, redis_options=None): + write_only=False, logger=None, redis_options=None, + encoder=pickle): if redis is None: raise RuntimeError('Redis package is not installed ' '(Run "pip install redis" in your ' @@ -50,7 +53,8 @@ def __init__(self, url='redis://localhost:6379/0', channel='socketio', self._redis_connect() super(RedisManager, self).__init__(channel=channel, write_only=write_only, - logger=logger) + logger=logger, + encoder=encoder) def initialize(self): super(RedisManager, self).initialize() @@ -78,7 +82,8 @@ def _publish(self, data): try: if not retry: self._redis_connect() - return self.redis.publish(self.channel, pickle.dumps(data)) + return self.redis.publish(self.channel, + self.encoder.dumps(data)) except redis.exceptions.RedisError: if retry: logger.error('Cannot publish to redis... retrying') diff --git a/src/socketio/zmq_manager.py b/src/socketio/zmq_manager.py index 54538cf1..770d1386 100644 --- a/src/socketio/zmq_manager.py +++ b/src/socketio/zmq_manager.py @@ -29,6 +29,8 @@ class ZmqManager(PubSubManager): # pragma: no cover :param write_only: If set to ``True``, only initialize to emit events. The default of ``False`` initializes the class for emitting and receiving. + :param encoder: The encoder to use for publishing and decoding data, + defaults to pickle. A zmq message broker must be running for the zmq_manager to work. you can write your own or adapt one from the following simple broker @@ -50,7 +52,8 @@ class ZmqManager(PubSubManager): # pragma: no cover def __init__(self, url='zmq+tcp://localhost:5555+5556', channel='socketio', write_only=False, - logger=None): + logger=None, + encoder=pickle): if zmq is None: raise RuntimeError('zmq package is not installed ' '(Run "pip install pyzmq" in your ' @@ -77,17 +80,18 @@ def __init__(self, url='zmq+tcp://localhost:5555+5556', self.channel = channel super(ZmqManager, self).__init__(channel=channel, write_only=write_only, - logger=logger) + logger=logger, + encoder=encoder) def _publish(self, data): - pickled_data = pickle.dumps( + encoded_data = self.encoder.dumps( { 'type': 'message', 'channel': self.channel, 'data': data } ) - return self.sink.send(pickled_data) + return self.sink.send(encoded_data) def zmq_listen(self): while True: @@ -98,10 +102,7 @@ def zmq_listen(self): def _listen(self): for message in self.zmq_listen(): if isinstance(message, bytes): - try: - message = pickle.loads(message) - except Exception: - pass + yield message if isinstance(message, dict) and \ message['type'] == 'message' and \ message['channel'] == self.channel and \ diff --git a/tests/common/test_pubsub_manager.py b/tests/common/test_pubsub_manager.py index 066349f0..ec47ef49 100644 --- a/tests/common/test_pubsub_manager.py +++ b/tests/common/test_pubsub_manager.py @@ -1,5 +1,8 @@ import functools import logging +import pickle +import json +import marshal import unittest from unittest import mock @@ -365,8 +368,6 @@ def test_background_thread(self): self.pm._handle_close_room = mock.MagicMock() def messages(): - import pickle - yield {'method': 'emit', 'value': 'foo'} yield {'missing': 'method'} yield '{"method": "callback", "value": "bar"}' @@ -394,3 +395,46 @@ def messages(): self.pm._handle_close_room.assert_called_once_with( {'method': 'close_room', 'value': 'baz'} ) + + def test_background_thread_with_encoder(self): + mock_server = mock.MagicMock() + pm = pubsub_manager.PubSubManager(encoder=marshal) + pm.set_server(mock_server) + pm._publish = mock.MagicMock() + pm._handle_emit = mock.MagicMock() + pm._handle_callback = mock.MagicMock() + pm._handle_disconnect = mock.MagicMock() + pm._handle_close_room = mock.MagicMock() + + pm.initialize() + + def messages(): + yield {'method': 'emit', 'value': 'foo'} + yield marshal.dumps({'method': 'callback', 'value': 'bar'}) + yield json.dumps( + {'method': 'disconnect', 'sid': '123', 'namespace': '/foo'} + ) + yield pickle.dumps({'method': 'close_room', 'value': 'baz'}) + yield {'method': 'bogus'} + yield 'bad json' + yield b'bad encoding' + + pm._listen = mock.MagicMock(side_effect=messages) + + try: + pm._thread() + except StopIteration: + pass + + pm._handle_emit.assert_called_once_with( + {'method': 'emit', 'value': 'foo'} + ) + pm._handle_callback.assert_called_once_with( + {'method': 'callback', 'value': 'bar'} + ) + pm._handle_disconnect.assert_called_once_with( + {'method': 'disconnect', 'sid': '123', 'namespace': '/foo'} + ) + pm._handle_close_room.assert_called_once_with( + {'method': 'close_room', 'value': 'baz'} + )