Skip to content

Commit

Permalink
code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jell-o-fishi committed Oct 20, 2023
1 parent ffd54e4 commit 1e7eb8e
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 42 deletions.
19 changes: 9 additions & 10 deletions rsocket/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from rsocket.exceptions import RSocketProtocolError, ParseError, RSocketUnknownFrameType
from rsocket.fragment import Fragment
from rsocket.frame_fragmenter import data_to_fragments_if_required
from rsocket.frame_helpers import unpack_position, pack_position, unpack_24bit, pack_24bit, unpack_32bit, \
ensure_bytes, pack_string, unpack_string
from rsocket.frame_helpers import (is_flag_set, unpack_position, pack_position, unpack_24bit, pack_24bit, unpack_32bit,
ensure_bytes, pack_string, unpack_string)
from rsocket.logger import logger

PROTOCOL_MAJOR_VERSION = 1
Expand Down Expand Up @@ -96,9 +96,6 @@ class ParseHelper:
parse_header: Callable = None


from rsocket.frame_helpers import is_flag_set


def parse_header_native(frame: Header, buffer: bytes, offset: int) -> Flags:
frame.length = len(buffer)
frame.stream_id, frame.frame_type, flag_bits = struct.unpack_from('>IBB', buffer, offset)
Expand Down Expand Up @@ -780,11 +777,13 @@ def is_fragmentable_frame(frame: Frame) -> bool:
))


FragmentableFrame = Union[PayloadFrame,
RequestResponseFrame,
RequestChannelFrame,
RequestStreamFrame,
RequestFireAndForgetFrame]
FragmentableFrame = Union[
PayloadFrame,
RequestResponseFrame,
RequestChannelFrame,
RequestStreamFrame,
RequestFireAndForgetFrame
]


def new_frame_fragment(base_frame: FragmentableFrame, fragment: Fragment) -> Frame:
Expand Down
2 changes: 1 addition & 1 deletion rsocket/frame_fragment_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _frame_fragment_builder(self, next_fragment: FragmentableFrame) -> Fragmenta

current_frame_from_fragments = self._frames_by_stream_id.get(next_fragment.stream_id)

if current_frame_from_fragments is not None and type(next_fragment) != PayloadFrame:
if current_frame_from_fragments is not None and not isinstance(next_fragment, PayloadFrame):
raise RSocketFrameFragmentDifferentType()

if current_frame_from_fragments is None:
Expand Down
12 changes: 10 additions & 2 deletions rsocket/handlers/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,10 @@
class Requester:
pass
import abc

from rsocket.frame import Frame


class Requester(metaclass=abc.ABCMeta):

@abc.abstractmethod
def frame_received(self, frame: Frame):
...
16 changes: 8 additions & 8 deletions rsocket/reactivex/from_rsocket_publisher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
import functools
from asyncio import Event, CancelledError, get_event_loop, create_task

import reactivex
from reactivex import Observable, Observer
Expand All @@ -17,8 +17,8 @@ def __init__(self, observer: Observer, limit_rate: int = MAX_REQUEST_N):
self.limit_rate = limit_rate
self.observer = observer
self._received_messages = 0
self.done = asyncio.Event()
self.get_next_n = asyncio.Event()
self.done = Event()
self.get_next_n = Event()
self.subscription = None

def on_subscribe(self, subscription: Subscription):
Expand Down Expand Up @@ -54,7 +54,7 @@ async def _aio_sub(publisher: Publisher, subscriber: RxSubscriber, observer: Obs
publisher.subscribe(subscriber)
await subscriber.done.wait()

except asyncio.CancelledError:
except CancelledError:
if not subscriber.done.is_set():
subscriber.subscription.cancel()
except Exception as exception:
Expand All @@ -67,21 +67,21 @@ async def _trigger_next_request_n(subscriber: RxSubscriber, limit_rate):
await subscriber.get_next_n.wait()
subscriber.subscription.request(limit_rate)
subscriber.get_next_n.clear()
except asyncio.CancelledError:
except CancelledError:
logger().debug('Asyncio task canceled: trigger_next_request_n')


def from_rsocket_publisher(publisher: Publisher, limit_rate: int = MAX_REQUEST_N) -> Observable:
loop = asyncio.get_event_loop()
loop = get_event_loop()

# noinspection PyUnusedLocal
def on_subscribe(observer: Observer, scheduler):
subscriber = RxSubscriber(observer, limit_rate)

get_next_task = asyncio.create_task(
get_next_task = create_task(
_trigger_next_request_n(subscriber, limit_rate)
)
task = asyncio.create_task(
task = create_task(
_aio_sub(publisher, subscriber, observer, loop)
)

Expand Down
5 changes: 2 additions & 3 deletions rsocket/reactivex/reactivex_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
from asyncio import Future
from asyncio import Future, Event
from typing import Optional, cast, Union, Callable

import reactivex
Expand Down Expand Up @@ -30,7 +29,7 @@ def request_channel(self,
request: Payload,
request_limit: int = MAX_REQUEST_N,
observable: Optional[Union[Observable, Callable[[Subject], Observable]]] = None,
sending_done: Optional[asyncio.Event] = None) -> Observable:
sending_done: Optional[Event] = None) -> Observable:
requester_publisher = observable_to_publisher(observable)

response_publisher = self._rsocket.request_channel(
Expand Down
2 changes: 1 addition & 1 deletion rsocket/rsocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class RSocketClient(RSocketBase):
"""
Client side instance of an RSocket connection.
:param transport_provider: Async generator which returns `Transport` to use with this instance. See `Transport` class implementations.
:param transport_provider: Async generator which returns `Transport` to use with this instance.
:param request_queue_size: Number of frames which can be queued while waiting for a lease.
:param fragment_size_bytes: Minimum 64, Maximum depends on transport.
"""
Expand Down
17 changes: 8 additions & 9 deletions rsocket/rx_support/back_pressure_publisher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
from asyncio import Queue
from asyncio import Queue, create_task, sleep
from dataclasses import dataclass
from typing import Optional, Callable, AsyncGenerator, Union

Expand Down Expand Up @@ -50,16 +49,16 @@ async def task_from_awaitable(future):
async def coroutine_from_awaitable(awaitable):
return await awaitable

task = asyncio.create_task(coroutine_from_awaitable(future))
await asyncio.sleep(0) # allow awaitable to be accessed at least once
task = create_task(coroutine_from_awaitable(future))
await sleep(0) # allow awaitable to be accessed at least once
return task


def observable_from_async_generator(iterator, backpressure: Subject) -> Observable:
# noinspection PyUnusedLocal
def on_subscribe(observer: Observer, scheduler):

request_n_queue = asyncio.Queue()
request_n_queue = Queue()

def request_next_n(n):
request_n_queue.put_nowait(n)
Expand Down Expand Up @@ -92,15 +91,15 @@ def cancel_sender():
on_completed=cancel_sender
)

sender = asyncio.create_task(_aio_next())
sender = create_task(_aio_next())

return result

return rx.create(on_subscribe)


async def observable_to_async_event_generator(observable: Observable) -> AsyncGenerator[Notification, None]:
queue = asyncio.Queue()
queue = Queue()

completed = object()

Expand Down Expand Up @@ -131,7 +130,7 @@ def from_async_event_iterator(iterator, backpressure: Subject) -> Observable:
# noinspection PyUnusedLocal
def on_subscribe(observer: Observer, scheduler):

request_n_queue = asyncio.Queue()
request_n_queue = Queue()

async def _aio_next():

Expand All @@ -156,7 +155,7 @@ async def _aio_next():
logger().error(str(exception), exc_info=True)
observer.on_error(exception)

sender = asyncio.create_task(_aio_next())
sender = create_task(_aio_next())

def cancel_sender():
sender.cancel()
Expand Down
16 changes: 8 additions & 8 deletions rsocket/rx_support/from_rsocket_publisher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
import functools
from asyncio import Event, CancelledError, get_event_loop, create_task

import rx
from rx import Observable
Expand All @@ -18,8 +18,8 @@ def __init__(self, observer: Observer, limit_rate: int = MAX_REQUEST_N):
self.limit_rate = limit_rate
self.observer = observer
self._received_messages = 0
self.done = asyncio.Event()
self.get_next_n = asyncio.Event()
self.done = Event()
self.get_next_n = Event()
self.subscription = None

def on_subscribe(self, subscription: Subscription):
Expand Down Expand Up @@ -55,7 +55,7 @@ async def _aio_sub(publisher: Publisher, subscriber: RxSubscriber, observer: Obs
publisher.subscribe(subscriber)
await subscriber.done.wait()

except asyncio.CancelledError:
except CancelledError:
if not subscriber.done.is_set():
subscriber.subscription.cancel()
except Exception as exception:
Expand All @@ -68,21 +68,21 @@ async def _trigger_next_request_n(subscriber: RxSubscriber, limit_rate):
await subscriber.get_next_n.wait()
subscriber.subscription.request(limit_rate)
subscriber.get_next_n.clear()
except asyncio.CancelledError:
except CancelledError:
logger().debug('Asyncio task canceled: trigger_next_request_n')


def from_rsocket_publisher(publisher: Publisher, limit_rate: int = MAX_REQUEST_N) -> Observable:
loop = asyncio.get_event_loop()
loop = get_event_loop()

# noinspection PyUnusedLocal
def on_subscribe(observer: Observer, scheduler):
subscriber = RxSubscriber(observer, limit_rate)

get_next_task = asyncio.create_task(
get_next_task = create_task(
_trigger_next_request_n(subscriber, limit_rate)
)
task = asyncio.create_task(
task = create_task(
_aio_sub(publisher, subscriber, observer, loop)
)

Expand Down

0 comments on commit 1e7eb8e

Please sign in to comment.