Skip to content

Commit

Permalink
add py.typed, update accountstreamer to alertstreamer
Browse files Browse the repository at this point in the history
  • Loading branch information
Graeme22 committed Jun 10, 2024
1 parent 0182d31 commit cdce51a
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 60 deletions.
37 changes: 17 additions & 20 deletions docs/account-streamer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,29 @@ Here's an example of setting up an account streamer to continuously wait for eve

.. code-block:: python
from tastytrade import Account, AccountStreamer
from tastytrade import Account, AlertStreamer
from tastytrade.streamer import AlertType
async with AccountStreamer(session) as streamer:
accounts = Account.get_accounts(session)
async with AlertStreamer(session) as streamer:
accounts = Account.get_accounts(session)
# updates to balances, orders, and positions
await streamer.subscribe_accounts(accounts)
# changes in public watchlists
await streamer.subscribe_public_watchlists()
# quote alerts configured by the user
await streamer.subscribe_quote_alerts()
# updates to balances, orders, and positions
await streamer.subscribe_accounts(accounts)
# changes in public watchlists
await streamer.subscribe_public_watchlists()
# quote alerts configured by the user
await streamer.subscribe_quote_alerts()
async for data in streamer.listen():
print(data)
async for watchlist in streamer.listen(AlertType.WATCHLIST):
print(f'Watchlist updated: {watchlist}')
Probably the most important information the account streamer handles is order fills. We can listen just for orders like so:

.. code-block:: python
from tastytrade.order import PlacedOrder
async with AlertStreamer(session) as streamer:
accounts = Account.get_accounts(session)
await streamer.subscribe_accounts(accounts)
async def listen_for_orders(session):
async with AccountStreamer(session) as streamer:
accounts = Account.get_accounts(session)
await streamer.subscribe_accounts(accounts)
async for data in streamer.listen():
if isinstance(data, PlacedOrder):
yield return data
async for order in streamer.listen(AlertType.ORDER):
print(f'Order updated: {order}')
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
project = 'tastytrade'
copyright = '2024, Graeme Holliday'
author = 'Graeme Holliday'
release = '7.4'
release = '7.5'

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name='tastytrade',
version='7.4',
version='7.5',
description='An unofficial SDK for Tastytrade!',
long_description=LONG_DESCRIPTION,
long_description_content_type='text/x-rst',
Expand All @@ -23,5 +23,6 @@
'fake_useragent>=1.5.1'
],
packages=find_packages(exclude=['ez_setup', 'tests*']),
package_data={'tastytrade': ['py.typed']},
include_package_data=True
)
6 changes: 3 additions & 3 deletions tastytrade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@

API_URL = 'https://api.tastyworks.com'
CERT_URL = 'https://api.cert.tastyworks.com'
VERSION = '7.4'
VERSION = '7.5'

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

from .account import Account # noqa: E402
from .search import symbol_search # noqa: E402
from .session import CertificationSession, ProductionSession # noqa: E402
from .streamer import AccountStreamer, DXLinkStreamer # noqa: E402
from .streamer import AlertStreamer, DXLinkStreamer # noqa: E402
from .watchlists import PairsWatchlist, Watchlist # noqa: E402

__all__ = [
'Account',
'AccountStreamer',
'AlertStreamer',
'CertificationSession',
'DXLinkStreamer',
'PairsWatchlist',
Expand Down
Empty file added tastytrade/py.typed
Empty file.
108 changes: 75 additions & 33 deletions tastytrade/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,22 @@ class SubscriptionType(str, Enum):
USER_MESSAGE = 'user-message-subscribe'


class AccountStreamer:
class AlertType(str, Enum):
"""
This is an :class:`~enum.Enum` that contains the event types
for the account streamer.
"""
ACCOUNT_BALANCE = 'AccountBalance'
ORDER = 'Order'
ORDER_CHAIN = 'OrderChain'
POSITION = 'CurrentPosition'
QUOTE_ALERT = 'QuoteAlert'
TRADING_STATUS = 'TradingStatus'
UNDERLYING_SUMMARY = 'UnderlyingYearGainSummary'
WATCHLIST = 'PublicWatchlists'


class AlertStreamer:
"""
Used to subscribe to account-level updates (balances, orders, positions),
public watchlist updates, quote alerts, and user-level messages. It should
Expand All @@ -87,9 +102,9 @@ class AccountStreamer:
Example usage::
from tastytrade import Account, AccountStreamer
from tastytrade import Account, AlertStreamer
async with AccountStreamer(session) as streamer:
async with AlertStreamer(session) as streamer:
accounts = Account.get_accounts(session)
# updates to balances, orders, and positions
Expand All @@ -111,7 +126,7 @@ def __init__(self, session: Session):
self.base_url: str = \
CERT_STREAMER_URL if is_certification else STREAMER_URL

self._queue: Queue = Queue()
self._queues: Dict[AlertType, Queue] = defaultdict(Queue)
self._websocket: Optional[WebSocketClientProtocol] = None
self._connect_task = asyncio.create_task(self._connect())

Expand All @@ -126,7 +141,7 @@ async def __aenter__(self):
return self

@classmethod
async def create(cls, session: Session) -> 'AccountStreamer':
async def create(cls, session: Session) -> 'AlertStreamer':
self = cls(session)
return await self.__aenter__()

Expand Down Expand Up @@ -156,46 +171,73 @@ async def _connect(self) -> None:
while True:
raw_message = await self._websocket.recv() # type: ignore
logger.debug('raw message: %s', raw_message)
await self._queue.put(json.loads(raw_message))
data = json.loads(raw_message)
type_str = data.get('type')
if type_str is not None:
await self._map_message(type_str, data['data'])

async def listen(self) -> AsyncIterator[TastytradeJsonDataclass]:
async def listen(
self,
event_type: AlertType
) -> AsyncIterator[
Union[
AccountBalance,
CurrentPosition,
PlacedOrder,
OrderChain,
QuoteAlert,
TradingStatus,
UnderlyingYearGainSummary,
Watchlist
]
]:
"""
Iterate over non-heartbeat messages received from the streamer,
mapping them to their appropriate data class and yielding them.
"""
while True:
data = await self._queue.get()
type_str = data.get('type')
if type_str is not None:
yield self._map_message(type_str, data['data'])
yield await self._queues[event_type].get()

def _map_message(
self,
type_str: str,
data: dict
) -> TastytradeJsonDataclass:
async def _map_message(self, type_str: str, data: dict):
"""
I'm not sure what the user-status messages look like,
so they're absent.
"""
if type_str == 'AccountBalance':
return AccountBalance(**data)
elif type_str == 'CurrentPosition':
return CurrentPosition(**data)
elif type_str == 'Order':
return PlacedOrder(**data)
elif type_str == 'OrderChain':
return OrderChain(**data)
elif type_str == 'QuoteAlert':
return QuoteAlert(**data)
elif type_str == 'TradingStatus':
return TradingStatus(**data)
elif type_str == 'UnderlyingYearGainSummary':
return UnderlyingYearGainSummary(**data)
elif type_str == 'PublicWatchlists':
return Watchlist(**data)
if type_str == AlertType.ACCOUNT_BALANCE:
await self._queues[AlertType.ACCOUNT_BALANCE].put(
AccountBalance(**data)
)
elif type_str == AlertType.POSITION:
await self._queues[AlertType.POSITION].put(
CurrentPosition(**data)
)
elif type_str == AlertType.ORDER:
await self._queues[AlertType.ORDER].put(
PlacedOrder(**data)
)
elif type_str == AlertType.ORDER_CHAIN:
await self._queues[AlertType.ORDER_CHAIN].put(
OrderChain(**data)
)
elif type_str == AlertType.QUOTE_ALERT:
await self._queues[AlertType.QUOTE_ALERT].put(
QuoteAlert(**data)
)
elif type_str == AlertType.TRADING_STATUS:
await self._queues[AlertType.TRADING_STATUS].put(
TradingStatus(**data)
)
elif type_str == AlertType.UNDERLYING_SUMMARY:
await self._queues[AlertType.UNDERLYING_SUMMARY].put(
UnderlyingYearGainSummary(**data)
)
elif type_str == AlertType.WATCHLIST:
await self._queues[AlertType.WATCHLIST].put(
Watchlist(**data)
)
else:
raise TastytradeError(f'Unknown message type: {type_str}\n{data}')
logger.error(f'Unknown message type {type_str}! Please open an '
f'issue.\n{data}')

async def subscribe_accounts(self, accounts: List[Account]) -> None:
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/test_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import pytest

from tastytrade import Account, AccountStreamer, DXLinkStreamer
from tastytrade import Account, AlertStreamer, DXLinkStreamer
from tastytrade.dxfeed import EventType

pytest_plugins = ('pytest_asyncio',)


@pytest.mark.asyncio
async def test_account_streamer(session):
async with AccountStreamer(session) as streamer:
async with AlertStreamer(session) as streamer:
await streamer.subscribe_public_watchlists()
await streamer.subscribe_quote_alerts()
await streamer.subscribe_user_messages(session)
Expand Down

0 comments on commit cdce51a

Please sign in to comment.