Skip to content

Commit

Permalink
PubSub: Allow specifying custom encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
Ododo committed Jan 17, 2022
1 parent b785498 commit af08e53
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 30 deletions.
7 changes: 4 additions & 3 deletions src/socketio/asyncio_aiopika_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class AsyncAioPikaManager(AsyncPubSubManager): # pragma: no cover
name = 'asyncaiopika'

def __init__(self, url='amqp://guest:guest@localhost:5672//',
channel='socketio', write_only=False, logger=None):
channel='socketio', write_only=False, logger=None,
encoder=pickle):
if aio_pika is None:
raise RuntimeError('aio_pika package is not installed '
'(Run "pip install aio_pika" in your '
Expand Down Expand Up @@ -70,7 +71,7 @@ async def _publish(self, data):
channel = await self._channel(connection)
exchange = await self._exchange(channel)
await exchange.publish(
aio_pika.Message(body=pickle.dumps(data),
aio_pika.Message(body=self.encoder.dumps(data),
delivery_mode=aio_pika.DeliveryMode.PERSISTENT),
routing_key='*'
)
Expand All @@ -94,7 +95,7 @@ async def _listen(self):
async with self.listener_queue.iterator() as queue_iter:
async for message in queue_iter:
with message.process():
yield pickle.loads(message.body)
yield message.body
except Exception:
self._get_logger().error('Cannot receive from rabbitmq... '
'retrying in '
Expand Down
12 changes: 10 additions & 2 deletions src/socketio/asyncio_pubsub_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ class AsyncPubSubManager(AsyncManager):
"""
name = 'asyncpubsub'

def __init__(self, channel='socketio', write_only=False, logger=None):
def __init__(self, channel='socketio', write_only=False, logger=None,
encoder=pickle):
super().__init__()
self.channel = channel
self.write_only = write_only
self.host_id = uuid.uuid4().hex
self.logger = logger
self.encoder = encoder

def initialize(self):
super().initialize()
Expand Down Expand Up @@ -153,7 +155,13 @@ async def _thread(self):
if isinstance(message, dict):
data = message
else:
if isinstance(message, bytes): # pragma: no cover
if self.encoder:
try:
data = self.encoder.loads(message)
except:
pass
if data is None and \
isinstance(message, bytes): # pragma: no cover
try:
data = pickle.loads(message)
except:
Expand Down
10 changes: 7 additions & 3 deletions src/socketio/asyncio_redis_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover
and receiving.
:param redis_options: additional keyword arguments to be passed to
``aioredis.from_url()``.
:param encoder: The encoder to use for publishing and decoding data,
defaults to pickle.
"""
name = 'aioredis'

def __init__(self, url='redis://localhost:6379/0', channel='socketio',
write_only=False, logger=None, redis_options=None):
write_only=False, logger=None, redis_options=None,
encoder=pickle):
if aioredis is None:
raise RuntimeError('Redis package is not installed '
'(Run "pip install aioredis" in your '
Expand All @@ -46,7 +49,8 @@ def __init__(self, url='redis://localhost:6379/0', channel='socketio',
self.redis_url = url
self.redis_options = redis_options or {}
self._redis_connect()
super().__init__(channel=channel, write_only=write_only, logger=logger)
super().__init__(channel=channel, write_only=write_only, logger=logger,
encoder=encoder)

def _redis_connect(self):
self.redis = aioredis.Redis.from_url(self.redis_url,
Expand All @@ -60,7 +64,7 @@ async def _publish(self, data):
if not retry:
self._redis_connect()
return await self.redis.publish(
self.channel, pickle.dumps(data))
self.channel, self.encoder.dumps(data))
except aioredis.exceptions.RedisError:
if retry:
self._get_logger().error('Cannot publish to redis... '
Expand Down
11 changes: 7 additions & 4 deletions src/socketio/kafka_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,21 @@ class KafkaManager(PubSubManager): # pragma: no cover
:param write_only: If set to ``True``, only initialize to emit events. The
default of ``False`` initializes the class for emitting
and receiving.
:param encoder: The encoder to use for publishing and decoding data,
defaults to pickle.
"""
name = 'kafka'

def __init__(self, url='kafka://localhost:9092', channel='socketio',
write_only=False):
write_only=False, encoder=pickle):
if kafka is None:
raise RuntimeError('kafka-python package is not installed '
'(Run "pip install kafka-python" in your '
'virtualenv).')

super(KafkaManager, self).__init__(channel=channel,
write_only=write_only)
write_only=write_only,
encoder=encoder)

urls = [url] if isinstance(url, str) else url
self.kafka_urls = [url[8:] if url != 'kafka://' else 'localhost:9092'
Expand All @@ -54,7 +57,7 @@ def __init__(self, url='kafka://localhost:9092', channel='socketio',
bootstrap_servers=self.kafka_urls)

def _publish(self, data):
self.producer.send(self.channel, value=pickle.dumps(data))
self.producer.send(self.channel, value=self.encoder.dumps(data))
self.producer.flush()

def _kafka_listen(self):
Expand All @@ -64,4 +67,4 @@ def _kafka_listen(self):
def _listen(self):
for message in self._kafka_listen():
if message.topic == self.channel:
yield pickle.loads(message.value)
yield message.value
10 changes: 7 additions & 3 deletions src/socketio/kombu_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,24 @@ class KombuManager(PubSubManager): # pragma: no cover
``kombu.Queue()``.
:param producer_options: additional keyword arguments to be passed to
``kombu.Producer()``.
:param encoder: The encoder to use for publishing and decoding data,
defaults to pickle.
"""
name = 'kombu'

def __init__(self, url='amqp://guest:guest@localhost:5672//',
channel='socketio', write_only=False, logger=None,
connection_options=None, exchange_options=None,
queue_options=None, producer_options=None):
queue_options=None, producer_options=None,
encoder=pickle):
if kombu is None:
raise RuntimeError('Kombu package is not installed '
'(Run "pip install kombu" in your '
'virtualenv).')
super(KombuManager, self).__init__(channel=channel,
write_only=write_only,
logger=logger)
logger=logger,
encoder=encoder)
self.url = url
self.connection_options = connection_options or {}
self.exchange_options = exchange_options or {}
Expand Down Expand Up @@ -103,7 +107,7 @@ def _publish(self, data):
connection = self._connection()
publish = connection.ensure(self.producer, self.producer.publish,
errback=self.__error_callback)
publish(pickle.dumps(data))
publish(self.encoder.dumps(data))

def _listen(self):
reader_queue = self._queue()
Expand Down
12 changes: 10 additions & 2 deletions src/socketio/pubsub_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ class PubSubManager(BaseManager):
"""
name = 'pubsub'

def __init__(self, channel='socketio', write_only=False, logger=None):
def __init__(self, channel='socketio', write_only=False, logger=None,
encoder=None):
super(PubSubManager, self).__init__()
self.channel = channel
self.write_only = write_only
self.host_id = uuid.uuid4().hex
self.logger = logger
self.encoder = encoder

def initialize(self):
super(PubSubManager, self).initialize()
Expand Down Expand Up @@ -151,7 +153,13 @@ def _thread(self):
if isinstance(message, dict):
data = message
else:
if isinstance(message, bytes): # pragma: no cover
if self.encoder:
try:
data = self.encoder.loads(message)
except:
pass
if data is None and \
isinstance(message, bytes): # pragma: no cover
try:
data = pickle.loads(message)
except:
Expand Down
11 changes: 8 additions & 3 deletions src/socketio/redis_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@ class RedisManager(PubSubManager): # pragma: no cover
and receiving.
:param redis_options: additional keyword arguments to be passed to
``Redis.from_url()``.
:param encoder: The encoder to use for publishing and decoding data,
defaults to pickle.
"""
name = 'redis'

def __init__(self, url='redis://localhost:6379/0', channel='socketio',
write_only=False, logger=None, redis_options=None):
write_only=False, logger=None, redis_options=None,
encoder=pickle):
if redis is None:
raise RuntimeError('Redis package is not installed '
'(Run "pip install redis" in your '
Expand All @@ -50,7 +53,8 @@ def __init__(self, url='redis://localhost:6379/0', channel='socketio',
self._redis_connect()
super(RedisManager, self).__init__(channel=channel,
write_only=write_only,
logger=logger)
logger=logger,
encoder=encoder)

def initialize(self):
super(RedisManager, self).initialize()
Expand Down Expand Up @@ -78,7 +82,8 @@ def _publish(self, data):
try:
if not retry:
self._redis_connect()
return self.redis.publish(self.channel, pickle.dumps(data))
return self.redis.publish(self.channel,
self.encoder.dumps(data))
except redis.exceptions.RedisError:
if retry:
logger.error('Cannot publish to redis... retrying')
Expand Down
17 changes: 9 additions & 8 deletions src/socketio/zmq_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class ZmqManager(PubSubManager): # pragma: no cover
:param write_only: If set to ``True``, only initialize to emit events. The
default of ``False`` initializes the class for emitting
and receiving.
:param encoder: The encoder to use for publishing and decoding data,
defaults to pickle.
A zmq message broker must be running for the zmq_manager to work.
you can write your own or adapt one from the following simple broker
Expand All @@ -50,7 +52,8 @@ class ZmqManager(PubSubManager): # pragma: no cover
def __init__(self, url='zmq+tcp://localhost:5555+5556',
channel='socketio',
write_only=False,
logger=None):
logger=None,
encoder=pickle):
if zmq is None:
raise RuntimeError('zmq package is not installed '
'(Run "pip install pyzmq" in your '
Expand All @@ -77,17 +80,18 @@ def __init__(self, url='zmq+tcp://localhost:5555+5556',
self.channel = channel
super(ZmqManager, self).__init__(channel=channel,
write_only=write_only,
logger=logger)
logger=logger,
encoder=encoder)

def _publish(self, data):
pickled_data = pickle.dumps(
encoded_data = self.encoder.dumps(
{
'type': 'message',
'channel': self.channel,
'data': data
}
)
return self.sink.send(pickled_data)
return self.sink.send(encoded_data)

def zmq_listen(self):
while True:
Expand All @@ -98,10 +102,7 @@ def zmq_listen(self):
def _listen(self):
for message in self.zmq_listen():
if isinstance(message, bytes):
try:
message = pickle.loads(message)
except Exception:
pass
yield message
if isinstance(message, dict) and \
message['type'] == 'message' and \
message['channel'] == self.channel and \
Expand Down
48 changes: 46 additions & 2 deletions tests/common/test_pubsub_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import functools
import logging
import pickle
import json
import marshal
import unittest
from unittest import mock

Expand Down Expand Up @@ -365,8 +368,6 @@ def test_background_thread(self):
self.pm._handle_close_room = mock.MagicMock()

def messages():
import pickle

yield {'method': 'emit', 'value': 'foo'}
yield {'missing': 'method'}
yield '{"method": "callback", "value": "bar"}'
Expand Down Expand Up @@ -394,3 +395,46 @@ def messages():
self.pm._handle_close_room.assert_called_once_with(
{'method': 'close_room', 'value': 'baz'}
)

def test_background_thread_with_encoder(self):
mock_server = mock.MagicMock()
pm = pubsub_manager.PubSubManager(encoder=marshal)
pm.set_server(mock_server)
pm._publish = mock.MagicMock()
pm._handle_emit = mock.MagicMock()
pm._handle_callback = mock.MagicMock()
pm._handle_disconnect = mock.MagicMock()
pm._handle_close_room = mock.MagicMock()

pm.initialize()

def messages():
yield {'method': 'emit', 'value': 'foo'}
yield marshal.dumps({'method': 'callback', 'value': 'bar'})
yield json.dumps(
{'method': 'disconnect', 'sid': '123', 'namespace': '/foo'}
)
yield pickle.dumps({'method': 'close_room', 'value': 'baz'})
yield {'method': 'bogus'}
yield 'bad json'
yield b'bad encoding'

pm._listen = mock.MagicMock(side_effect=messages)

try:
pm._thread()
except StopIteration:
pass

pm._handle_emit.assert_called_once_with(
{'method': 'emit', 'value': 'foo'}
)
pm._handle_callback.assert_called_once_with(
{'method': 'callback', 'value': 'bar'}
)
pm._handle_disconnect.assert_called_once_with(
{'method': 'disconnect', 'sid': '123', 'namespace': '/foo'}
)
pm._handle_close_room.assert_called_once_with(
{'method': 'close_room', 'value': 'baz'}
)

0 comments on commit af08e53

Please sign in to comment.