diff --git a/src/aleph/sdk/base.py b/src/aleph/sdk/base.py index a5b2c266..9d80bc64 100644 --- a/src/aleph/sdk/base.py +++ b/src/aleph/sdk/base.py @@ -2,7 +2,6 @@ import logging from abc import ABC, abstractmethod -from datetime import datetime from pathlib import Path from typing import ( Any, @@ -26,42 +25,33 @@ from aleph_message.models.execution.program import Encoding from aleph_message.status import MessageStatus -from aleph.sdk.models import PostsResponse -from aleph.sdk.types import GenericMessage, StorageEnum +from .models.message import MessageFilter +from .models.post import PostFilter, PostsResponse +from .types import GenericMessage, StorageEnum DEFAULT_PAGE_SIZE = 200 class BaseAlephClient(ABC): @abstractmethod - async def fetch_aggregate( - self, - address: str, - key: str, - limit: int = 100, - ) -> Dict[str, Dict]: + async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: """ Fetch a value from the aggregate store by owner address and item key. :param address: Address of the owner of the aggregate :param key: Key of the aggregate - :param limit: Maximum number of items to fetch (Default: 100) """ pass @abstractmethod async def fetch_aggregates( - self, - address: str, - keys: Optional[Iterable[str]] = None, - limit: int = 100, + self, address: str, keys: Optional[Iterable[str]] = None ) -> Dict[str, Dict]: """ Fetch key-value pairs from the aggregate store by owner address. 
:param address: Address of the owner of the aggregate :param keys: Keys of the aggregates to fetch (Default: all items) - :param limit: Maximum number of items to fetch (Default: 100) """ pass @@ -70,15 +60,7 @@ async def get_posts( self, pagination: int = DEFAULT_PAGE_SIZE, page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> PostsResponse: @@ -87,15 +69,7 @@ async def get_posts( :param pagination: Number of items to fetch (Default: 200) :param page: Page to fetch, begins at 1 (Default: 1) - :param types: Types of posts to fetch (Default: all types) - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Chains of the posts to fetch (Default: all chains) - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param post_filter: Filter to apply to the posts (Default: None) :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) """ @@ -103,44 +77,20 @@ async def get_posts( async def get_posts_iterator( self, - types: Optional[Iterable[str]] = None, - refs: 
Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ) -> AsyncIterable[PostMessage]: """ Fetch all filtered posts, returning an async iterator and fetching them page by page. Might return duplicates but will always return all posts. - :param types: Types of posts to fetch (Default: all types) - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Chains of the posts to fetch (Default: all chains) - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param post_filter: Filter to apply to the posts (Default: None) """ page = 1 resp = None while resp is None or len(resp.posts) > 0: resp = await self.get_posts( page=page, - types=types, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + post_filter=post_filter, ) page += 1 for post in resp.posts: @@ -165,18 +115,7 @@ async def get_messages( self, pagination: int = DEFAULT_PAGE_SIZE, page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, 
- hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> MessagesResponse: @@ -185,18 +124,7 @@ async def get_messages( :param pagination: Number of items to fetch (Default: 200) :param page: Page to fetch, begins at 1 (Default: 1) - :param message_type: [DEPRECATED] Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param message_types: Filter by message types, can be any combination of "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param content_types: Filter by content type - :param content_keys: Filter by aggregate key - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Filter by sender address chain - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param message_filter: Filter to apply to the messages :param ignore_invalid_messages: Ignore invalid messages (Default: True) :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) """ @@ -204,50 +132,20 @@ async def get_messages( async def get_messages_iterator( self, - message_type: Optional[MessageType] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: 
Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: """ Fetch all filtered messages, returning an async iterator and fetching them page by page. Might return duplicates but will always return all messages. - :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param content_types: Filter by content type - :param content_keys: Filter by content key - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Filter by sender address chain - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from + :param message_filter: Filter to apply to the messages """ page = 1 resp = None while resp is None or len(resp.messages) > 0: resp = await self.get_messages( page=page, - message_type=message_type, - content_types=content_types, - content_keys=content_keys, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + message_filter=message_filter, ) page += 1 for message in resp.messages: @@ -272,34 +170,12 @@ async def get_message( @abstractmethod def watch_messages( self, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: 
Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: """ Iterate over current and future matching messages asynchronously. - :param message_type: [DEPRECATED] Type of message to watch - :param message_types: Types of messages to watch - :param content_types: Content types to watch - :param content_keys: Filter by aggregate key - :param refs: References to watch - :param addresses: Addresses to watch - :param tags: Tags to watch - :param hashes: Hashes to watch - :param channels: Channels to watch - :param chains: Chains to watch - :param start_date: Start date from when to watch - :param end_date: End date until when to watch + :param message_filter: Filter to apply to the messages """ pass @@ -318,7 +194,7 @@ async def create_post( sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: """ - Create a POST message on the Aleph network. It is associated with a channel and owned by an account. + Create a POST message on the aleph.im network. It is associated with a channel and owned by an account. :param post_content: The content of the message :param post_type: An arbitrary content type that helps to describe the post_content @@ -368,7 +244,7 @@ async def create_store( sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: """ - Create a STORE message to store a file on the Aleph network. + Create a STORE message to store a file on the aleph.im network. Can be passed either a file path, an IPFS hash or the file's content as raw bytes. 
@@ -422,7 +298,7 @@ async def create_program( :param persistent: Whether the program should be persistent or not (Default: False) :param encoding: Encoding to use (Default: Encoding.zip) :param volumes: Volumes to mount - :param subscriptions: Patterns of Aleph messages to forward to the program's event receiver + :param subscriptions: Patterns of aleph.im messages to forward to the program's event receiver :param metadata: Metadata to attach to the message """ pass diff --git a/src/aleph/sdk/client.py b/src/aleph/sdk/client.py index f79f0ceb..837811b7 100644 --- a/src/aleph/sdk/client.py +++ b/src/aleph/sdk/client.py @@ -5,8 +5,6 @@ import queue import threading import time -import warnings -from datetime import datetime from io import BytesIO from pathlib import Path from typing import ( @@ -61,7 +59,8 @@ MessageNotFoundError, MultipleMessagesError, ) -from .models import MessagesResponse, Post, PostsResponse +from .models.message import MessageFilter, MessagesResponse +from .models.post import Post, PostFilter, PostsResponse from .utils import check_unix_socket_valid, get_message_type_value logger = logging.getLogger(__name__) @@ -141,18 +140,7 @@ def get_messages( self, pagination: int = 200, page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[List[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: bool = True, invalid_messages_log_level: int = logging.NOTSET, ) -> MessagesResponse: @@ -160,18 +148,7 @@ def get_messages( 
self.async_session.get_messages, pagination=pagination, page=page, - message_type=message_type, - message_types=message_types, - content_types=content_types, - content_keys=content_keys, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + message_filter=message_filter, ignore_invalid_messages=ignore_invalid_messages, invalid_messages_log_level=invalid_messages_log_level, ) @@ -210,29 +187,13 @@ def get_posts( self, pagination: int = 200, page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ) -> PostsResponse: return self._wrap( self.async_session.get_posts, pagination=pagination, page=page, - types=types, - refs=refs, - addresses=addresses, - tags=tags, - hashes=hashes, - channels=channels, - chains=chains, - start_date=start_date, - end_date=end_date, + post_filter=post_filter, ) def download_file(self, file_hash: str) -> bytes: @@ -246,7 +207,7 @@ def download_file_ipfs(self, file_hash: str) -> bytes: def download_file_to_buffer( self, file_hash: str, output_buffer: Writable[bytes] - ) -> bytes: + ) -> None: return self._wrap( self.async_session.download_file_to_buffer, file_hash=file_hash, @@ -255,7 +216,7 @@ def download_file_to_buffer( def download_file_ipfs_to_buffer( self, file_hash: str, output_buffer: Writable[bytes] - ) -> bytes: + ) -> None: return self._wrap( self.async_session.download_file_ipfs_to_buffer, file_hash=file_hash, @@ -264,16 +225,7 @@ def download_file_ipfs_to_buffer( def watch_messages( self, - message_type: Optional[MessageType] = None, - 
content_types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> Iterable[AlephMessage]: """ Iterate over current and future matching messages synchronously. @@ -286,18 +238,7 @@ def watch_messages( args=( output_queue, self.async_session.api_server, - ( - message_type, - content_types, - refs, - addresses, - tags, - hashes, - channels, - chains, - start_date, - end_date, - ), + message_filter, {}, ), ) @@ -528,15 +469,8 @@ async def __aenter__(self) -> "AlephClient": async def __aexit__(self, exc_type, exc_val, exc_tb): await self.http_session.close() - async def fetch_aggregate( - self, - address: str, - key: str, - limit: int = 100, - ) -> Dict[str, Dict]: + async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Dict]: params: Dict[str, Any] = {"keys": key} - if limit: - params["limit"] = limit async with self.http_session.get( f"/api/v0/aggregates/{address}.json", params=params @@ -546,17 +480,12 @@ async def fetch_aggregate( return data.get(key) async def fetch_aggregates( - self, - address: str, - keys: Optional[Iterable[str]] = None, - limit: int = 100, + self, address: str, keys: Optional[Iterable[str]] = None ) -> Dict[str, Dict]: keys_str = ",".join(keys) if keys else "" params: Dict[str, Any] = {} if keys_str: params["keys"] = keys_str - if limit: - params["limit"] = limit async with self.http_session.get( f"/api/v0/aggregates/{address}.json", @@ -570,15 +499,7 @@ async def get_posts( self, pagination: int = 200, page: int = 1, - types: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = 
None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + post_filter: Optional[PostFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> PostsResponse: @@ -591,31 +512,11 @@ async def get_posts( else invalid_messages_log_level ) - params: Dict[str, Any] = dict(pagination=pagination, page=page) - - if types is not None: - params["types"] = ",".join(types) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date + if not post_filter: + post_filter = PostFilter() + params = post_filter.as_http_params() + params["page"] = str(page) + params["pagination"] = str(pagination) async with self.http_session.get("/api/v0/posts.json", params=params) as resp: resp.raise_for_status() @@ -722,18 +623,7 @@ async def get_messages( self, pagination: int = 200, page: int = 1, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: 
Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ignore_invalid_messages: Optional[bool] = True, invalid_messages_log_level: Optional[int] = logging.NOTSET, ) -> MessagesResponse: @@ -746,43 +636,11 @@ async def get_messages( else invalid_messages_log_level ) - params: Dict[str, Any] = dict(pagination=pagination, page=page) - - if message_type is not None: - warnings.warn( - "The message_type parameter is deprecated, please use message_types instead.", - DeprecationWarning, - ) - params["msgType"] = message_type.value - if message_types is not None: - params["msgTypes"] = ",".join([t.value for t in message_types]) - print(params["msgTypes"]) - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if content_keys is not None: - params["contentKeys"] = ",".join(content_keys) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date - + if not message_filter: + message_filter = MessageFilter() + params = message_filter.as_http_params() + params["page"] = str(page) + 
params["pagination"] = str(pagination) async with self.http_session.get( "/api/v0/messages.json", params=params ) as resp: @@ -825,8 +683,10 @@ async def get_message( channel: Optional[str] = None, ) -> GenericMessage: messages_response = await self.get_messages( - hashes=[item_hash], - channels=[channel] if channel else None, + message_filter=MessageFilter( + hashes=[item_hash], + channels=[channel] if channel else None, + ) ) if len(messages_response.messages) < 1: raise MessageNotFoundError(f"No such hash {item_hash}") @@ -846,54 +706,11 @@ async def get_message( async def watch_messages( self, - message_type: Optional[MessageType] = None, - message_types: Optional[Iterable[MessageType]] = None, - content_types: Optional[Iterable[str]] = None, - content_keys: Optional[Iterable[str]] = None, - refs: Optional[Iterable[str]] = None, - addresses: Optional[Iterable[str]] = None, - tags: Optional[Iterable[str]] = None, - hashes: Optional[Iterable[str]] = None, - channels: Optional[Iterable[str]] = None, - chains: Optional[Iterable[str]] = None, - start_date: Optional[Union[datetime, float]] = None, - end_date: Optional[Union[datetime, float]] = None, + message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: - params: Dict[str, Any] = dict() - - if message_type is not None: - warnings.warn( - "The message_type parameter is deprecated, please use message_types instead.", - DeprecationWarning, - ) - params["msgType"] = message_type.value - if message_types is not None: - params["msgTypes"] = ",".join([t.value for t in message_types]) - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if content_keys is not None: - params["contentKeys"] = ",".join(content_keys) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if 
channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date + if not message_filter: + message_filter = MessageFilter() + params = message_filter.as_http_params() async with self.http_session.ws_connect( "/api/ws0/messages", params=params @@ -1059,7 +876,7 @@ async def _handle_broadcast_deprecated_response( async def _broadcast_deprecated(self, message_dict: Mapping[str, Any]) -> None: """ - Broadcast a message on the Aleph network using the deprecated + Broadcast a message on the aleph.im network using the deprecated /ipfs/pubsub/pub/ endpoint. """ @@ -1097,7 +914,7 @@ async def _broadcast( sync: bool, ) -> MessageStatus: """ - Broadcast a message on the Aleph network. + Broadcast a message on the aleph.im network. Uses the POST /messages/ endpoint or the deprecated /ipfs/pubsub/pub/ endpoint if the first method is not available. @@ -1273,7 +1090,7 @@ async def create_program( # Register the different ways to trigger a VM if subscriptions: - # Trigger on HTTP calls and on Aleph message subscriptions. + # Trigger on HTTP calls and on aleph.im message subscriptions. 
triggers = { "http": True, "persistent": persistent, @@ -1309,7 +1126,7 @@ async def create_program( "runtime": { "ref": runtime, "use_latest": True, - "comment": "Official Aleph runtime" + "comment": "Official aleph.im runtime" if runtime == settings.DEFAULT_RUNTIME_ID else "", }, diff --git a/src/aleph/sdk/exceptions.py b/src/aleph/sdk/exceptions.py index 51762925..5f09e1bc 100644 --- a/src/aleph/sdk/exceptions.py +++ b/src/aleph/sdk/exceptions.py @@ -21,7 +21,7 @@ class MultipleMessagesError(QueryError): class BroadcastError(Exception): """ - Data could not be broadcast to the Aleph network. + Data could not be broadcast to the aleph.im network. """ pass @@ -29,7 +29,7 @@ class BroadcastError(Exception): class InvalidMessageError(BroadcastError): """ - The message could not be broadcast because it does not follow the Aleph + The message could not be broadcast because it does not follow the aleph.im message specification. """ diff --git a/src/aleph/sdk/models.py b/src/aleph/sdk/models.py deleted file mode 100644 index f5b1072b..00000000 --- a/src/aleph/sdk/models.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Any, Dict, List, Optional, Union - -from aleph_message.models import AlephMessage, BaseMessage, ChainRef, ItemHash -from pydantic import BaseModel, Field - - -class PaginationResponse(BaseModel): - pagination_page: int - pagination_total: int - pagination_per_page: int - pagination_item: str - - -class MessagesResponse(PaginationResponse): - """Response from an Aleph node API on the path /api/v0/messages.json""" - - messages: List[AlephMessage] - pagination_item = "messages" - - -class Post(BaseMessage): - """ - A post is a type of message that can be updated. Over the get_posts API - we get the latest version of a post. 
- """ - - hash: ItemHash = Field(description="Hash of the content (sha256 by default)") - original_item_hash: ItemHash = Field( - description="Hash of the original content (sha256 by default)" - ) - original_signature: Optional[str] = Field( - description="Cryptographic signature of the original message by the sender" - ) - original_type: str = Field( - description="The original, user-generated 'content-type' of the POST message" - ) - content: Dict[str, Any] = Field( - description="The content.content of the POST message" - ) - type: str = Field(description="The content.type of the POST message") - address: str = Field(description="The address of the sender of the POST message") - ref: Optional[Union[str, ChainRef]] = Field( - description="Other message referenced by this one" - ) - - -class PostsResponse(PaginationResponse): - """Response from an Aleph node API on the path /api/v0/posts.json""" - - posts: List[Post] - pagination_item = "posts" diff --git a/src/aleph/sdk/models/__init__.py b/src/aleph/sdk/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aleph/sdk/models/common.py b/src/aleph/sdk/models/common.py new file mode 100644 index 00000000..c7e0dc30 --- /dev/null +++ b/src/aleph/sdk/models/common.py @@ -0,0 +1,29 @@ +from datetime import datetime +from typing import Iterable, Optional, Union + +from pydantic import BaseModel + + +class PaginationResponse(BaseModel): + pagination_page: int + pagination_total: int + pagination_per_page: int + pagination_item: str + + +def serialize_list(values: Optional[Iterable[str]]) -> Optional[str]: + if values: + return ",".join(values) + else: + return None + + +def _date_field_to_float(date: Optional[Union[datetime, float]]) -> Optional[float]: + if date is None: + return None + elif isinstance(date, float): + return date + elif hasattr(date, "timestamp"): + return date.timestamp() + else: + raise TypeError(f"Invalid type: `{type(date)}`") diff --git a/src/aleph/sdk/models/message.py 
b/src/aleph/sdk/models/message.py new file mode 100644 index 00000000..4ba6a1b2 --- /dev/null +++ b/src/aleph/sdk/models/message.py @@ -0,0 +1,102 @@ +from datetime import datetime +from typing import Dict, Iterable, List, Optional, Union + +from aleph_message.models import AlephMessage, MessageType + +from .common import PaginationResponse, _date_field_to_float, serialize_list + + +class MessagesResponse(PaginationResponse): + """Response from an aleph.im node API on the path /api/v0/messages.json""" + + messages: List[AlephMessage] + pagination_item = "messages" + + +class MessageFilter: + """ + A collection of filters that can be applied on message queries. + :param message_types: Filter by message type + :param content_types: Filter by content type + :param content_keys: Filter by content key + :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) + :param addresses: Addresses of the posts to fetch (Default: all addresses) + :param tags: Tags of the posts to fetch (Default: all tags) + :param hashes: Specific item_hashes to fetch + :param channels: Channels of the posts to fetch (Default: all channels) + :param chains: Filter by sender address chain + :param start_date: Earliest date to fetch messages from + :param end_date: Latest date to fetch messages from + """ + + message_types: Optional[Iterable[MessageType]] + content_types: Optional[Iterable[str]] + content_keys: Optional[Iterable[str]] + refs: Optional[Iterable[str]] + addresses: Optional[Iterable[str]] + tags: Optional[Iterable[str]] + hashes: Optional[Iterable[str]] + channels: Optional[Iterable[str]] + chains: Optional[Iterable[str]] + start_date: Optional[Union[datetime, float]] + end_date: Optional[Union[datetime, float]] + + def __init__( + self, + message_types: Optional[Iterable[MessageType]] = None, + content_types: Optional[Iterable[str]] = None, + content_keys: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: 
Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ): + self.message_types = message_types + self.content_types = content_types + self.content_keys = content_keys + self.refs = refs + self.addresses = addresses + self.tags = tags + self.hashes = hashes + self.channels = channels + self.chains = chains + self.start_date = start_date + self.end_date = end_date + + def as_http_params(self) -> Dict[str, str]: + """Convert the filters into a dict that can be used by an `aiohttp` client + as `params` to build the HTTP query string. + """ + + partial_result = { + "msgTypes": serialize_list( + [type.value for type in self.message_types] + if self.message_types + else None + ), + "contentTypes": serialize_list(self.content_types), + "contentKeys": serialize_list(self.content_keys), + "refs": serialize_list(self.refs), + "addresses": serialize_list(self.addresses), + "tags": serialize_list(self.tags), + "hashes": serialize_list(self.hashes), + "channels": serialize_list(self.channels), + "chains": serialize_list(self.chains), + "startDate": _date_field_to_float(self.start_date), + "endDate": _date_field_to_float(self.end_date), + } + + # Ensure all values are strings. 
+ result: Dict[str, str] = {} + + # Drop empty values + for key, value in partial_result.items(): + if value: + value = str(value) + result[key] = value + + return result diff --git a/src/aleph/sdk/models/post.py b/src/aleph/sdk/models/post.py new file mode 100644 index 00000000..09a301c2 --- /dev/null +++ b/src/aleph/sdk/models/post.py @@ -0,0 +1,122 @@ +from datetime import datetime +from typing import Any, Dict, Iterable, List, Optional, Union + +from aleph_message.models import Chain, ItemHash, ItemType, MessageConfirmation +from pydantic import BaseModel, Field + +from .common import PaginationResponse, _date_field_to_float, serialize_list + + +class Post(BaseModel): + """ + A post is a type of message that can be updated. Over the get_posts API + we get the latest version of a post. + """ + + chain: Chain = Field(description="Blockchain this post is associated with") + item_hash: ItemHash = Field(description="Unique hash for this post") + sender: str = Field(description="Address of the sender") + type: str = Field(description="Type of the POST message") + channel: Optional[str] = Field(description="Channel this post is associated with") + confirmed: bool = Field(description="Whether the post is confirmed or not") + content: Dict[str, Any] = Field(description="The content of the POST message") + item_content: Optional[str] = Field( + description="The POSTs content field as serialized JSON, if of type inline" + ) + item_type: ItemType = Field( + description="Type of the item content, usually 'inline' or 'storage' for POSTs" + ) + signature: Optional[str] = Field( + description="Cryptographic signature of the message by the sender" + ) + size: int = Field(description="Size of the post") + time: float = Field(description="Timestamp of the post") + confirmations: List[MessageConfirmation] = Field( + description="Number of confirmations" + ) + original_item_hash: ItemHash = Field(description="Hash of the original 
content") + original_signature: Optional[str] = Field( + description="Cryptographic signature of the original message" + ) + original_type: str = Field(description="The original type of the message") + hash: ItemHash = Field(description="Hash of the original item") + ref: Optional[Union[str, Any]] = Field( + description="Other message referenced by this one" + ) + + class Config: + allow_extra = False + + +class PostsResponse(PaginationResponse): + """Response from an aleph.im node API on the path /api/v0/posts.json""" + + posts: List[Post] + pagination_item = "posts" + + +class PostFilter: + """ + A collection of filters that can be applied on post queries. + + """ + + types: Optional[Iterable[str]] + refs: Optional[Iterable[str]] + addresses: Optional[Iterable[str]] + tags: Optional[Iterable[str]] + hashes: Optional[Iterable[str]] + channels: Optional[Iterable[str]] + chains: Optional[Iterable[str]] + start_date: Optional[Union[datetime, float]] + end_date: Optional[Union[datetime, float]] + + def __init__( + self, + types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ): + self.types = types + self.refs = refs + self.addresses = addresses + self.tags = tags + self.hashes = hashes + self.channels = channels + self.chains = chains + self.start_date = start_date + self.end_date = end_date + + def as_http_params(self) -> Dict[str, str]: + """Convert the filters into a dict that can be used by an `aiohttp` client + as `params` to build the HTTP query string. 
+ """ + + partial_result = { + "types": serialize_list(self.types), + "refs": serialize_list(self.refs), + "addresses": serialize_list(self.addresses), + "tags": serialize_list(self.tags), + "hashes": serialize_list(self.hashes), + "channels": serialize_list(self.channels), + "chains": serialize_list(self.chains), + "startDate": _date_field_to_float(self.start_date), + "endDate": _date_field_to_float(self.end_date), + } + + # Ensure all values are strings. + result: Dict[str, str] = {} + + # Drop empty values + for key, value in partial_result.items(): + if value: + assert isinstance(value, str), f"Value must be a string: `{value}`" + result[key] = value + + return result diff --git a/tests/integration/itest_forget.py b/tests/integration/itest_forget.py index 29b6c6d9..6b44f76f 100644 --- a/tests/integration/itest_forget.py +++ b/tests/integration/itest_forget.py @@ -1,8 +1,10 @@ from typing import Callable, Dict import pytest +from aleph_message.models import PostMessage from aleph.sdk.client import AuthenticatedAlephClient +from aleph.sdk.models.message import MessageFilter from aleph.sdk.types import Account from .config import REFERENCE_NODE, TARGET_NODE, TEST_CHANNEL @@ -106,11 +108,9 @@ async def test_forget_a_forget_message(fixture_account): async with AuthenticatedAlephClient( account=fixture_account, api_server=TARGET_NODE ) as session: - get_post_response = await session.get_posts(hashes=[post_hash]) - assert len(get_post_response.posts) == 1 - post = get_post_response.posts[0] + get_post_message: PostMessage = await session.get_message(post_hash) - forget_message_hash = post.forgotten_by[0] + forget_message_hash = get_post_message.forgotten_by[0] forget_message, forget_status = await session.forget( hashes=[forget_message_hash], reason="I want to remember this post. 
Maybe I can forget I forgot it?", @@ -120,8 +120,10 @@ async def test_forget_a_forget_message(fixture_account): print(forget_message) get_forget_message_response = await session.get_messages( - hashes=[forget_message_hash], - channels=[TEST_CHANNEL], + message_filter=MessageFilter( + hashes=[forget_message_hash], + channels=[TEST_CHANNEL], + ) ) assert len(get_forget_message_response.messages) == 1 forget_message = get_forget_message_response.messages[0] diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 4f62c0c5..a51b1483 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,8 +1,10 @@ import json from pathlib import Path from tempfile import NamedTemporaryFile +from typing import Any, Callable, Dict, List import pytest as pytest +from aleph_message.models import AggregateMessage, AlephMessage, PostMessage import aleph.sdk.chains.ethereum as ethereum import aleph.sdk.chains.sol as solana @@ -46,7 +48,77 @@ def substrate_account() -> substrate.DOTAccount: @pytest.fixture -def messages(): +def json_messages(): messages_path = Path(__file__).parent / "messages.json" with open(messages_path) as f: return json.load(f) + + +@pytest.fixture +def aleph_messages() -> List[AlephMessage]: + return [ + AggregateMessage.parse_obj( + { + "item_hash": "5b26d949fe05e38f535ef990a89da0473f9d700077cced228f2d36e73fca1fd6", + "type": "AGGREGATE", + "chain": "ETH", + "sender": "0x51A58800b26AA1451aaA803d1746687cB88E0501", + "signature": "0xca5825b6b93390482b436cb7f28b4628f8c9f56dc6af08260c869b79dd6017c94248839bd9fd0ffa1230dc3b1f4f7572a8d1f6fed6c6e1fb4d70ccda0ab5d4f21b", + "item_type": "inline", + "item_content": 
'{"address":"0x51A58800b26AA1451aaA803d1746687cB88E0501","key":"0xce844d79e5c0c325490c530aa41e8f602f0b5999binance","content":{"1692026263168":{"version":"x25519-xsalsa20-poly1305","nonce":"RT4Lbqs7Xzk+op2XC+VpXgwOgg21BotN","ephemPublicKey":"CVW8ECE3m8BepytHMTLan6/jgIfCxGdnKmX47YirF08=","ciphertext":"VuGJ9vMkJSbaYZCCv6Zemx4ixeb+9IW8H1vFB9vLtz1a8d87R4BfYUisLoCQxRkeUXqfW0/KIGQ5idVjr8Yj7QnKglW5AJ8UX7wEWMhiRFLatpWP8P9FI2n8Z7Rblu7Oz/OeKnuljKL3KsalcUQSsFa/1qACsIoycPZ6Wq6t1mXxVxxJWzClLyKRihv1pokZGT9UWxh7+tpoMGlRdYainyAt0/RygFw+r8iCMOilHnyv4ndLkKQJXyttb0tdNr/gr57+9761+trioGSysLQKZQWW6Ih6aE8V9t3BenfzYwiCnfFw3YAAKBPMdm9QdIETyrOi7YhD/w==","sha256":"bbeb499f681aed2bc18b6f3b6a30d25254bd30fbfde43444e9085f3bcd075c3c"}},"time":1692026263.662}', + "content": { + "key": "0xce844d79e5c0c325490c530aa41e8f602f0b5999binance", + "time": 1692026263.662, + "address": "0x51A58800b26AA1451aaA803d1746687cB88E0501", + "content": { + "hello": "world", + }, + }, + "time": 1692026263.662, + "channel": "UNSLASHED", + "size": 734, + "confirmations": [], + "confirmed": False, + } + ), + PostMessage.parse_obj( + { + "item_hash": "70f3798fdc68ce0ee03715a5547ee24e2c3e259bf02e3f5d1e4bf5a6f6a5e99f", + "type": "POST", + "chain": "SOL", + "sender": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4", + "signature": "0x91616ee45cfba55742954ff87ebf86db4988bcc5e3334b49a4caa6436e28e28d4ab38667cbd4bfb8903abf8d71f70d9ceb2c0a8d0a15c04fc1af5657f0050c101b", + "item_type": "storage", + "item_content": None, + "content": { + "time": 1692026021.1257718, + "type": "aleph-network-metrics", + "address": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4", + "ref": "0123456789abcdef", + "content": { + "tags": ["mainnet"], + "hello": "world", + "version": "1.0", + }, + }, + "time": 1692026021.132849, + "channel": "aleph-scoring", + "size": 122537, + "confirmations": [], + "confirmed": False, + } + ), + ] + + +@pytest.fixture +def raw_messages_response(aleph_messages) -> Callable[[int], Dict[str, Any]]: + return lambda page: { + 
"messages": [message.dict() for message in aleph_messages] + if int(page) == 1 + else [], + "pagination_item": "messages", + "pagination_page": int(page), + "pagination_per_page": max(len(aleph_messages), 20), + "pagination_total": len(aleph_messages) if page == 1 else 0, + } diff --git a/tests/unit/test_asynchronous_get.py b/tests/unit/test_asynchronous_get.py index db788e0b..72c47706 100644 --- a/tests/unit/test_asynchronous_get.py +++ b/tests/unit/test_asynchronous_get.py @@ -3,11 +3,12 @@ from unittest.mock import AsyncMock import pytest -from aleph_message.models import MessagesResponse +from aleph_message.models import MessagesResponse, MessageType from aleph.sdk.client import AlephClient from aleph.sdk.conf import settings -from aleph.sdk.models import PostsResponse +from aleph.sdk.models.message import MessageFilter +from aleph.sdk.models.post import PostFilter, PostsResponse def make_mock_session(get_return_value: Dict[str, Any]) -> AlephClient: @@ -67,7 +68,12 @@ async def test_fetch_aggregates(): @pytest.mark.asyncio async def test_get_posts(): async with AlephClient(api_server=settings.API_HOST) as session: - response: PostsResponse = await session.get_posts() + response: PostsResponse = await session.get_posts( + pagination=2, + post_filter=PostFilter( + channels=["TEST"], + ), + ) posts = response.posts assert len(posts) > 1 @@ -78,6 +84,9 @@ async def test_get_messages(): async with AlephClient(api_server=settings.API_HOST) as session: response: MessagesResponse = await session.get_messages( pagination=2, + message_filter=MessageFilter( + message_types=[MessageType.post], + ), ) messages = response.messages diff --git a/tests/unit/test_chain_ethereum.py b/tests/unit/test_chain_ethereum.py index dea58c69..9a602b3d 100644 --- a/tests/unit/test_chain_ethereum.py +++ b/tests/unit/test_chain_ethereum.py @@ -82,8 +82,8 @@ async def test_verify_signature(ethereum_account): @pytest.mark.asyncio -async def 
test_verify_signature_with_processed_message(ethereum_account, messages): - message = messages[1] +async def test_verify_signature_with_processed_message(ethereum_account, json_messages): + message = json_messages[1] verify_signature( message["signature"], message["sender"], get_verification_buffer(message) ) diff --git a/tests/unit/test_chain_solana.py b/tests/unit/test_chain_solana.py index 5088158a..07b67602 100644 --- a/tests/unit/test_chain_solana.py +++ b/tests/unit/test_chain_solana.py @@ -103,8 +103,8 @@ async def test_verify_signature(solana_account): @pytest.mark.asyncio -async def test_verify_signature_with_processed_message(solana_account, messages): - message = messages[0] +async def test_verify_signature_with_processed_message(solana_account, json_messages): + message = json_messages[0] signature = json.loads(message["signature"])["signature"] verify_signature(signature, message["sender"], get_verification_buffer(message)) diff --git a/tests/unit/test_synchronous_get.py b/tests/unit/test_synchronous_get.py index eee26dcf..0788a1ab 100644 --- a/tests/unit/test_synchronous_get.py +++ b/tests/unit/test_synchronous_get.py @@ -2,14 +2,16 @@ from aleph.sdk.client import AlephClient from aleph.sdk.conf import settings +from aleph.sdk.models.message import MessageFilter def test_get_post_messages(): with AlephClient(api_server=settings.API_HOST) as session: - # TODO: Remove deprecated message_type parameter after message_types changes on pyaleph are deployed response: MessagesResponse = session.get_messages( pagination=2, - message_type=MessageType.post, + message_filter=MessageFilter( + message_types=[MessageType.post], + ), ) messages = response.messages