From 26146b7a18495e527932e4bba1c8bf323ab06387 Mon Sep 17 00:00:00 2001
From: mhh
Date: Tue, 5 Sep 2023 14:45:30 +0200
Subject: [PATCH] add MessageCache and DomainNode, based on peewee ORM & SQLite

---
 setup.cfg                   |   3 +
 src/aleph/sdk/conf.py       |   9 +
 src/aleph/sdk/node.py       | 749 ++++++++++++++++++++++++++++++++++++
 tests/unit/conftest.py      |  70 ++++
 tests/unit/test_node.py     | 255 ++++++++++++
 tests/unit/test_node_get.py | 231 +++++++++++
 6 files changed, 1317 insertions(+)
 create mode 100644 src/aleph/sdk/node.py
 create mode 100644 tests/unit/test_node.py
 create mode 100644 tests/unit/test_node_get.py

diff --git a/setup.cfg b/setup.cfg
index 48eb7f9b..fb27281d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -79,6 +79,7 @@ testing =
     black
    isort
     flake8
+    peewee
 mqtt =
     aiomqtt<=0.1.3
     certifi
@@ -103,6 +104,8 @@ ledger =
     ledgereth==0.9.0
 docs =
     sphinxcontrib-plantuml
+cache =
+    peewee
 
 [options.entry_points]
 # Add here console scripts like:
diff --git a/src/aleph/sdk/conf.py b/src/aleph/sdk/conf.py
index 264c8c9f..2ae20e45 100644
--- a/src/aleph/sdk/conf.py
+++ b/src/aleph/sdk/conf.py
@@ -33,6 +33,15 @@ class Settings(BaseSettings):
 
     CODE_USES_SQUASHFS: bool = which("mksquashfs") is not None  # True if command exists
 
+    CACHE_DATABASE_PATH: Path = Field(
+        default=Path(":memory:"),  # ":memory:" keeps the cache database in RAM only
+        description="Path to the cache database",
+    )
+    CACHE_FILES_PATH: Path = Field(
+        default=Path("cache", "files"),
+        description="Path to the cache files",
+    )
+
     class Config:
         env_prefix = "ALEPH_"
         case_sensitive = False
diff --git a/src/aleph/sdk/node.py b/src/aleph/sdk/node.py
new file mode 100644
index 00000000..a9548e67
--- /dev/null
+++ b/src/aleph/sdk/node.py
@@ -0,0 +1,749 @@
+import asyncio
+import json
+import logging
+import typing
+from datetime import datetime
+from functools import partial
+from pathlib import Path
+from typing import (
+    Any,
+    AsyncIterable,
+    Coroutine,
+    Dict,
+    Generic,
+    Iterable,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
+
+from aleph_message import MessagesResponse, parse_message
+from aleph_message.models import (
+    AlephMessage,
+    Chain,
+    ItemHash,
+    MessageConfirmation,
+    MessageType,
+)
+from aleph_message.models.execution.base import Encoding
+from aleph_message.status import MessageStatus
+from peewee import (
+    BooleanField,
+    CharField,
+    FloatField,
+    IntegerField,
+    Model,
+    SqliteDatabase,
+)
+from playhouse.shortcuts import model_to_dict
+from playhouse.sqlite_ext import JSONField
+from pydantic import BaseModel
+
+from aleph.sdk import AuthenticatedAlephClient
+from aleph.sdk.base import AlephClientBase, AuthenticatedAlephClientBase
+from aleph.sdk.conf import settings
+from aleph.sdk.exceptions import MessageNotFoundError
+from aleph.sdk.models import PostsResponse
+from aleph.sdk.types import GenericMessage, StorageEnum
+
+db = SqliteDatabase(settings.CACHE_DATABASE_PATH)
+T = TypeVar("T", bound=BaseModel)
+
+
+class JSONDictEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, BaseModel):
+            return obj.dict()
+        return json.JSONEncoder.default(self, obj)
+
+
+pydantic_json_dumps = partial(json.dumps, cls=JSONDictEncoder)
+
+
+class PydanticField(JSONField, Generic[T]):
+    """
+    A field for storing pydantic model types as JSON in a database. Uses JSON for serialization.
+    """
+
+    type: T
+
+    def __init__(self, *args, **kwargs):
+        self.type = kwargs.pop("type")
+        super().__init__(*args, **kwargs)
+
+    def db_value(self, value: Optional[T]) -> Optional[str]:
+        if value is None:
+            return None
+        return value.json()
+
+    def python_value(self, value: Optional[str]) -> Optional[T]:
+        if value is None:
+            return None
+        return self.type.parse_raw(value)
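+
+
+# Illustrative sketch (``ExampleModel`` is hypothetical, not part of this
+# patch): a PydanticField is declared like any other peewee field, with the
+# wrapped model class passed through the ``type`` keyword argument.
+#
+#     class ExampleModel(Model):
+#         confirmation = PydanticField(type=MessageConfirmation, null=True)
+#
+# ``db_value()`` then persists ``value.json()`` as TEXT, and ``python_value()``
+# restores the model via ``MessageConfirmation.parse_raw(...)``.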
+ """ + + type: T + + def __init__(self, *args, **kwargs): + self.type = kwargs.pop("type") + super().__init__(*args, **kwargs) + + def db_value(self, value: Optional[T]) -> Optional[str]: + if value is None: + return None + return value.json() + + def python_value(self, value: Optional[str]) -> Optional[T]: + if value is None: + return None + return self.type.parse_raw(value) + + +class MessageModel(Model): + """ + A simple database model for storing AlephMessage objects. + """ + + item_hash = CharField(primary_key=True) + chain = CharField(5) + type = CharField(9) + sender = CharField() + channel = CharField(null=True) + confirmations: PydanticField[MessageConfirmation] = PydanticField( + type=MessageConfirmation, null=True + ) + confirmed = BooleanField(null=True) + signature = CharField(null=True) + size = IntegerField(null=True) + time = FloatField() + item_type = CharField(7) + item_content = CharField(null=True) + hash_type = CharField(6, null=True) + content = JSONField(json_dumps=pydantic_json_dumps) + forgotten_by = CharField(null=True) + tags = JSONField(json_dumps=pydantic_json_dumps, null=True) + key = CharField(null=True) + ref = CharField(null=True) + content_type = CharField(null=True) + + class Meta: + database = db + + +def message_to_model(message: AlephMessage) -> Dict: + return { + "item_hash": str(message.item_hash), + "chain": message.chain, + "type": message.type, + "sender": message.sender, + "channel": message.channel, + "confirmations": message.confirmations[0] if message.confirmations else None, + "confirmed": message.confirmed, + "signature": message.signature, + "size": message.size, + "time": message.time, + "item_type": message.item_type, + "item_content": message.item_content, + "hash_type": message.hash_type, + "content": message.content, + "forgotten_by": message.forgotten_by[0] if message.forgotten_by else None, + "tags": message.content.content.get("tags", None) + if hasattr(message.content, "content") + else None, + "key": message.content.key if hasattr(message.content, "key") else None, + "ref": message.content.ref if hasattr(message.content, "ref") else None, + "content_type": message.content.type + if hasattr(message.content, "type") + else None, + } + + +def model_to_message(item: Any) -> AlephMessage: + item.confirmations = [item.confirmations] if item.confirmations else [] + item.forgotten_by = [item.forgotten_by] if item.forgotten_by else None + + to_exclude = [ + MessageModel.tags, + MessageModel.ref, + MessageModel.key, + MessageModel.content_type, + ] + + item_dict = model_to_dict(item, exclude=to_exclude) + return parse_message(item_dict) + + +def query_field(field_name, field_values: Iterable[str]): + field = getattr(MessageModel, field_name) + values = list(field_values) + + if len(values) == 1: + return field == values[0] + return field.in_(values) + + +def get_message_query( + message_type: Optional[MessageType] = None, + content_keys: Optional[Iterable[str]] = None, + content_types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, +): + query = MessageModel.select().order_by(MessageModel.time.desc()) + conditions = [] + if message_type: + conditions.append(query_field("type", [message_type.value])) 
+
+
+def get_message_query(
+    message_type: Optional[MessageType] = None,
+    content_keys: Optional[Iterable[str]] = None,
+    content_types: Optional[Iterable[str]] = None,
+    refs: Optional[Iterable[str]] = None,
+    addresses: Optional[Iterable[str]] = None,
+    tags: Optional[Iterable[str]] = None,
+    hashes: Optional[Iterable[str]] = None,
+    channels: Optional[Iterable[str]] = None,
+    chains: Optional[Iterable[str]] = None,
+    start_date: Optional[Union[datetime, float]] = None,
+    end_date: Optional[Union[datetime, float]] = None,
+):
+    query = MessageModel.select().order_by(MessageModel.time.desc())
+    conditions = []
+    if message_type:
+        conditions.append(query_field("type", [message_type.value]))
+    if content_keys:
+        conditions.append(query_field("key", content_keys))
+    if content_types:
+        conditions.append(query_field("content_type", content_types))
+    if refs:
+        conditions.append(query_field("ref", refs))
+    if addresses:
+        conditions.append(query_field("sender", addresses))
+    if tags:
+        for tag in tags:
+            conditions.append(MessageModel.tags.contains(tag))
+    if hashes:
+        conditions.append(query_field("item_hash", hashes))
+    if channels:
+        conditions.append(query_field("channel", channels))
+    if chains:
+        conditions.append(query_field("chain", chains))
+    if start_date:
+        conditions.append(MessageModel.time >= start_date)
+    if end_date:
+        conditions.append(MessageModel.time <= end_date)
+
+    if conditions:
+        query = query.where(*conditions)
+    return query
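+
+
+# Composition sketch (illustrative values): the returned query is a regular
+# peewee ModelSelect, so it can be iterated or refined further.
+#
+#     query = get_message_query(
+#         message_type=MessageType.post,
+#         addresses=["0x51A58800b26AA1451aaA803d1746687cB88E0501"],
+#         start_date=datetime(2023, 1, 1).timestamp(),
+#     )
+#     messages = [model_to_message(item) for item in query]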
+ """ + + async def _listen(): + async for message in message_stream: + self.add(message) + print(f"Added message {message.item_hash} to cache") + + return _listen() + + async def fetch_aggregate( + self, address: str, key: str, limit: int = 100 + ) -> Dict[str, Dict]: + item = ( + MessageModel.select() + .where(MessageModel.type == MessageType.aggregate.value) + .where(MessageModel.sender == address) + .where(MessageModel.key == key) + .order_by(MessageModel.time.desc()) + .first() + ) + return item.content["content"] + + async def fetch_aggregates( + self, address: str, keys: Optional[Iterable[str]] = None, limit: int = 100 + ) -> Dict[str, Dict]: + query = ( + MessageModel.select() + .where(MessageModel.type == MessageType.aggregate.value) + .where(MessageModel.sender == address) + .order_by(MessageModel.time.desc()) + ) + if keys: + query = query.where(MessageModel.key.in_(keys)) + query = query.limit(limit) + return {item.key: item.content["content"] for item in list(query)} + + async def get_posts( + self, + pagination: int = 200, + page: int = 1, + types: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ignore_invalid_messages: bool = True, + invalid_messages_log_level: int = logging.NOTSET, + ) -> PostsResponse: + query = get_message_query( + message_type=MessageType.post, + content_types=types, + refs=refs, + addresses=addresses, + tags=tags, + hashes=hashes, + channels=channels, + chains=chains, + start_date=start_date, + end_date=end_date, + ) + + query = query.paginate(page, pagination) + + posts = [model_to_message(item) for item in list(query)] + + return PostsResponse( + posts=posts, + pagination_page=page, + pagination_per_page=pagination, + pagination_total=query.count(), + pagination_item="posts", + ) + + async def download_file(self, file_hash: str) -> bytes: + raise NotImplementedError + + async def get_messages( + self, + pagination: int = 200, + page: int = 1, + message_type: Optional[MessageType] = None, + content_types: Optional[Iterable[str]] = None, + content_keys: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ignore_invalid_messages: bool = True, + invalid_messages_log_level: int = logging.NOTSET, + ) -> MessagesResponse: + """ + Get many messages from the cache. 
+ """ + query = get_message_query( + message_type=message_type, + content_keys=content_keys, + content_types=content_types, + refs=refs, + addresses=addresses, + tags=tags, + hashes=hashes, + channels=channels, + chains=chains, + start_date=start_date, + end_date=end_date, + ) + + query = query.paginate(page, pagination) + + messages = [model_to_message(item) for item in list(query)] + + return MessagesResponse( + messages=messages, + pagination_page=page, + pagination_per_page=pagination, + pagination_total=query.count(), + pagination_item="messages", + ) + + async def get_message( + self, + item_hash: str, + message_type: Optional[Type[GenericMessage]] = None, + channel: Optional[str] = None, + ) -> GenericMessage: + """ + Get a single message from the cache. + """ + query = MessageModel.select().where(MessageModel.item_hash == item_hash) + + if message_type: + query = query.where(MessageModel.type == message_type.value) + if channel: + query = query.where(MessageModel.channel == channel) + + item = query.first() + + if item: + return model_to_message(item) + + raise MessageNotFoundError(f"No such hash {item_hash}") + + async def watch_messages( + self, + message_type: Optional[MessageType] = None, + content_types: Optional[Iterable[str]] = None, + content_keys: Optional[Iterable[str]] = None, + refs: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + hashes: Optional[Iterable[str]] = None, + channels: Optional[Iterable[str]] = None, + chains: Optional[Iterable[str]] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ) -> AsyncIterable[AlephMessage]: + """ + Watch messages from the cache. + """ + query = get_message_query( + message_type=message_type, + content_keys=content_keys, + content_types=content_types, + refs=refs, + addresses=addresses, + tags=tags, + hashes=hashes, + channels=channels, + chains=chains, + start_date=start_date, + end_date=end_date, + ) + + async for item in query: + yield model_to_message(item) + + +class DomainNode(MessageCache, AuthenticatedAlephClientBase): + """ + A Domain Node is a queryable proxy for Aleph Messages that are stored in a database cache and/or in the Aleph + network. + + It synchronizes with the network on a subset of the messages by listening to the network and storing the + messages in the cache. The user may define the subset by specifying a channels, tags, senders, chains, + message types, and/or a time window. 
+ """ + + def __init__( + self, + session: AuthenticatedAlephClient, + channels: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + chains: Optional[Iterable[Chain]] = None, + message_type: Optional[MessageType] = None, + ): + super().__init__() + self.session = session + self.channels = channels + self.tags = tags + self.addresses = addresses + self.chains = chains + self.message_type = message_type + + # start listening to the network and storing messages in the cache + asyncio.get_event_loop().create_task( + self.listen_to( + self.session.watch_messages( + channels=self.channels, + tags=self.tags, + addresses=self.addresses, + chains=self.chains, + message_type=self.message_type, + ) + ) + ) + + # synchronize with past messages + asyncio.get_event_loop().run_until_complete( + self.synchronize( + channels=self.channels, + tags=self.tags, + addresses=self.addresses, + chains=self.chains, + message_type=self.message_type, + ) + ) + + async def __aenter__(self) -> "DomainNode": + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + ... + + async def synchronize( + self, + channels: Optional[Iterable[str]] = None, + tags: Optional[Iterable[str]] = None, + addresses: Optional[Iterable[str]] = None, + chains: Optional[Iterable[Chain]] = None, + message_type: Optional[MessageType] = None, + start_date: Optional[Union[datetime, float]] = None, + end_date: Optional[Union[datetime, float]] = None, + ): + """ + Synchronize with past messages. + """ + chunk_size = 200 + messages = [] + async for message in self.session.get_messages_iterator( + channels=channels, + tags=tags, + addresses=addresses, + chains=chains, + message_type=message_type, + start_date=start_date, + end_date=end_date, + ): + messages.append(message) + if len(messages) >= chunk_size: + self.add(messages) + messages = [] + if messages: + self.add(messages) + + async def download_file(self, file_hash: str) -> bytes: + """ + Opens a file that has been locally stored by its hash. 
+ """ + try: + with open(self._file_path(file_hash), "rb") as f: + return f.read() + except FileNotFoundError: + file = await self.session.download_file(file_hash) + self._file_path(file_hash).parent.mkdir(parents=True, exist_ok=True) + with open(self._file_path(file_hash), "wb") as f: + f.write(file) + return file + + @staticmethod + def _file_path(file_hash: str) -> Path: + return settings.CACHE_FILES_PATH / Path(file_hash) + + async def create_post( + self, + post_content: Any, + post_type: str, + ref: Optional[str] = None, + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + storage_engine: StorageEnum = StorageEnum.storage, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + resp, status = await self.session.create_post( + post_content=post_content, + post_type=post_type, + ref=ref, + address=address, + channel=channel, + inline=inline, + storage_engine=storage_engine, + sync=sync, + ) + # WARNING: this can cause inconsistencies if the message is dropped/rejected by the aleph node + if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]: + self.add(resp) + return resp, status + + async def create_aggregate( + self, + key: str, + content: Mapping[str, Any], + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + resp, status = await self.session.create_aggregate( + key=key, + content=content, + address=address, + channel=channel, + inline=inline, + sync=sync, + ) + if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]: + self.add(resp) + return resp, status + + async def create_store( + self, + address: Optional[str] = None, + file_content: Optional[bytes] = None, + file_path: Optional[Union[str, Path]] = None, + file_hash: Optional[str] = None, + guess_mime_type: bool = False, + ref: Optional[str] = None, + storage_engine: StorageEnum = StorageEnum.storage, + extra_fields: Optional[dict] = None, + channel: Optional[str] = None, + sync: bool = False, + ) -> Tuple[AlephMessage, MessageStatus]: + resp, status = await self.session.create_store( + address=address, + file_content=file_content, + file_path=file_path, + file_hash=file_hash, + guess_mime_type=guess_mime_type, + ref=ref, + storage_engine=storage_engine, + extra_fields=extra_fields, + channel=channel, + sync=sync, + ) + if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]: + self.add(resp) + return resp, status + + async def create_program( + self, + program_ref: str, + entrypoint: str, + runtime: str, + environment_variables: Optional[Mapping[str, str]] = None, + storage_engine: StorageEnum = StorageEnum.storage, + channel: Optional[str] = None, + address: Optional[str] = None, + sync: bool = False, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + persistent: bool = False, + encoding: Encoding = Encoding.zip, + volumes: Optional[List[Mapping]] = None, + subscriptions: Optional[List[Mapping]] = None, + metadata: Optional[Mapping[str, Any]] = None, + ) -> Tuple[AlephMessage, MessageStatus]: + resp, status = await self.session.create_program( + program_ref=program_ref, + entrypoint=entrypoint, + runtime=runtime, + environment_variables=environment_variables, + storage_engine=storage_engine, + channel=channel, + address=address, + sync=sync, + memory=memory, + vcpus=vcpus, + timeout_seconds=timeout_seconds, + persistent=persistent, + encoding=encoding, + volumes=volumes, + 
+
+
+class DomainNode(MessageCache, AuthenticatedAlephClientBase):
+    """
+    A Domain Node is a queryable proxy for Aleph Messages that are stored in a database cache and/or in the
+    Aleph network.
+
+    It synchronizes with the network on a subset of the messages by listening to the network and storing the
+    messages in the cache. The user may define the subset by specifying channels, tags, senders, chains,
+    message types, and/or a time window.
+    """
+
+    def __init__(
+        self,
+        session: AuthenticatedAlephClient,
+        channels: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[Chain]] = None,
+        message_type: Optional[MessageType] = None,
+    ):
+        super().__init__()
+        self.session = session
+        self.channels = channels
+        self.tags = tags
+        self.addresses = addresses
+        self.chains = chains
+        self.message_type = message_type
+
+        # start listening to the network and storing messages in the cache
+        asyncio.get_event_loop().create_task(
+            self.listen_to(
+                self.session.watch_messages(
+                    channels=self.channels,
+                    tags=self.tags,
+                    addresses=self.addresses,
+                    chains=self.chains,
+                    message_type=self.message_type,
+                )
+            )
+        )
+
+        # synchronize with past messages
+        asyncio.get_event_loop().run_until_complete(
+            self.synchronize(
+                channels=self.channels,
+                tags=self.tags,
+                addresses=self.addresses,
+                chains=self.chains,
+                message_type=self.message_type,
+            )
+        )
+
+    async def __aenter__(self) -> "DomainNode":
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        ...
+
+    async def synchronize(
+        self,
+        channels: Optional[Iterable[str]] = None,
+        tags: Optional[Iterable[str]] = None,
+        addresses: Optional[Iterable[str]] = None,
+        chains: Optional[Iterable[Chain]] = None,
+        message_type: Optional[MessageType] = None,
+        start_date: Optional[Union[datetime, float]] = None,
+        end_date: Optional[Union[datetime, float]] = None,
+    ):
+        """
+        Synchronize with past messages.
+        """
+        chunk_size = 200
+        messages = []
+        async for message in self.session.get_messages_iterator(
+            channels=channels,
+            tags=tags,
+            addresses=addresses,
+            chains=chains,
+            message_type=message_type,
+            start_date=start_date,
+            end_date=end_date,
+        ):
+            messages.append(message)
+            if len(messages) >= chunk_size:
+                self.add(messages)
+                messages = []
+        if messages:
+            self.add(messages)
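+
+    # Usage sketch (assuming an AuthenticatedAlephClient ``client``; the
+    # channel name is illustrative):
+    #
+    #     node = DomainNode(session=client, channels=["MY_CHANNEL"])
+    #     response = await node.get_messages()    # answered from the local cache
+    #     post, status = await node.create_post(  # forwarded to the network
+    #         post_content={"hello": "world"}, post_type="test"
+    #     )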
"70f3798fdc68ce0ee03715a5547ee24e2c3e259bf02e3f5d1e4bf5a6f6a5e99f", + "type": "POST", + "chain": "SOL", + "sender": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4", + "signature": "0x91616ee45cfba55742954ff87ebf86db4988bcc5e3334b49a4caa6436e28e28d4ab38667cbd4bfb8903abf8d71f70d9ceb2c0a8d0a15c04fc1af5657f0050c101b", + "item_type": "storage", + "item_content": None, + "content": { + "time": 1692026021.1257718, + "type": "aleph-network-metrics", + "address": "0x4D52380D3191274a04846c89c069E6C3F2Ed94e4", + "ref": "0123456789abcdef", + "content": { + "tags": ["mainnet"], + "hello": "world", + "version": "1.0", + }, + }, + "time": 1692026021.132849, + "channel": "aleph-scoring", + "size": 122537, + "confirmations": [], + "confirmed": False, + } + ), + ] + + +@pytest.fixture +def raw_messages_response(messages): + return { + "messages": [message.dict() for message in messages], + "pagination_item": "messages", + "pagination_page": 1, + "pagination_per_page": 20, + "pagination_total": 2, + } diff --git a/tests/unit/test_node.py b/tests/unit/test_node.py new file mode 100644 index 00000000..0b844e50 --- /dev/null +++ b/tests/unit/test_node.py @@ -0,0 +1,255 @@ +import json +import os +from pathlib import Path +from typing import Any, Dict +from unittest.mock import AsyncMock, MagicMock + +import pytest as pytest +from aleph_message.models import ( + AggregateMessage, + ForgetMessage, + MessageType, + PostMessage, + ProgramMessage, + StoreMessage, +) +from aleph_message.status import MessageStatus + +from aleph.sdk import AuthenticatedAlephClient +from aleph.sdk.conf import settings +from aleph.sdk.node import DomainNode +from aleph.sdk.types import Account, StorageEnum + + +class MockPostResponse: + def __init__(self, response_message: Any, sync: bool): + self.response_message = response_message + self.sync = sync + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + ... + + @property + def status(self): + return 200 if self.sync else 202 + + def raise_for_status(self): + if self.status not in [200, 202]: + raise Exception("Bad status code") + + async def json(self): + message_status = "processed" if self.sync else "pending" + return { + "message_status": message_status, + "publication_status": {"status": "success", "failed": []}, + "hash": "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy", + "message": self.response_message, + } + + async def text(self): + return json.dumps(await self.json()) + + +class MockGetResponse: + def __init__(self, response): + self.response = response + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + ... 
+
+    async def create_post(
+        self,
+        post_content: Any,
+        post_type: str,
+        ref: Optional[str] = None,
+        address: Optional[str] = None,
+        channel: Optional[str] = None,
+        inline: bool = True,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        resp, status = await self.session.create_post(
+            post_content=post_content,
+            post_type=post_type,
+            ref=ref,
+            address=address,
+            channel=channel,
+            inline=inline,
+            storage_engine=storage_engine,
+            sync=sync,
+        )
+        # WARNING: this can cause inconsistencies if the message is dropped/rejected by the aleph node
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+        return resp, status
+
+    async def create_aggregate(
+        self,
+        key: str,
+        content: Mapping[str, Any],
+        address: Optional[str] = None,
+        channel: Optional[str] = None,
+        inline: bool = True,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        resp, status = await self.session.create_aggregate(
+            key=key,
+            content=content,
+            address=address,
+            channel=channel,
+            inline=inline,
+            sync=sync,
+        )
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+        return resp, status
+
+    async def create_store(
+        self,
+        address: Optional[str] = None,
+        file_content: Optional[bytes] = None,
+        file_path: Optional[Union[str, Path]] = None,
+        file_hash: Optional[str] = None,
+        guess_mime_type: bool = False,
+        ref: Optional[str] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        extra_fields: Optional[dict] = None,
+        channel: Optional[str] = None,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        resp, status = await self.session.create_store(
+            address=address,
+            file_content=file_content,
+            file_path=file_path,
+            file_hash=file_hash,
+            guess_mime_type=guess_mime_type,
+            ref=ref,
+            storage_engine=storage_engine,
+            extra_fields=extra_fields,
+            channel=channel,
+            sync=sync,
+        )
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+        return resp, status
+
+    async def create_program(
+        self,
+        program_ref: str,
+        entrypoint: str,
+        runtime: str,
+        environment_variables: Optional[Mapping[str, str]] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        channel: Optional[str] = None,
+        address: Optional[str] = None,
+        sync: bool = False,
+        memory: Optional[int] = None,
+        vcpus: Optional[int] = None,
+        timeout_seconds: Optional[float] = None,
+        persistent: bool = False,
+        encoding: Encoding = Encoding.zip,
+        volumes: Optional[List[Mapping]] = None,
+        subscriptions: Optional[List[Mapping]] = None,
+        metadata: Optional[Mapping[str, Any]] = None,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        resp, status = await self.session.create_program(
+            program_ref=program_ref,
+            entrypoint=entrypoint,
+            runtime=runtime,
+            environment_variables=environment_variables,
+            storage_engine=storage_engine,
+            channel=channel,
+            address=address,
+            sync=sync,
+            memory=memory,
+            vcpus=vcpus,
+            timeout_seconds=timeout_seconds,
+            persistent=persistent,
+            encoding=encoding,
+            volumes=volumes,
+            subscriptions=subscriptions,
+            metadata=metadata,
+        )
+        if status in [MessageStatus.PENDING, MessageStatus.PROCESSED]:
+            self.add(resp)
+        return resp, status
+
+    async def forget(
+        self,
+        hashes: List[str],
+        reason: Optional[str],
+        storage_engine: StorageEnum = StorageEnum.storage,
+        channel: Optional[str] = None,
+        address: Optional[str] = None,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        resp, status = await self.session.forget(
+            hashes=hashes,
+            reason=reason,
+            storage_engine=storage_engine,
+            channel=channel,
+            address=address,
+            sync=sync,
+        )
+        del self[resp.item_hash]
+        return resp, status
+
+    async def submit(
+        self,
+        content: Dict[str, Any],
+        message_type: MessageType,
+        channel: Optional[str] = None,
+        storage_engine: StorageEnum = StorageEnum.storage,
+        allow_inlining: bool = True,
+        sync: bool = False,
+    ) -> Tuple[AlephMessage, MessageStatus]:
+        resp, status = await self.session.submit(
+            content=content,
+            message_type=message_type,
+            channel=channel,
+            storage_engine=storage_engine,
+            allow_inlining=allow_inlining,
+            sync=sync,
+        )
+        # WARNING: this can cause inconsistencies if the message is dropped/rejected by the aleph node
+        if status in [MessageStatus.PROCESSED, MessageStatus.PENDING]:
+            self.add(resp)
+        return resp, status
program_ref="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe", + entrypoint="main:app", + runtime="facefacefacefacefacefacefacefacefacefacefacefacefacefacefaceface", + channel="TEST", + metadata={"tags": ["test"]}, + ) + + assert mock_node_with_post_success.session.http_session.post.called_once + assert isinstance(program_message, ProgramMessage) + + +@pytest.mark.asyncio +async def test_forget(mock_node_with_post_success): + async with mock_node_with_post_success as node: + forget_message, message_status = await node.forget( + hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"], + reason="GDPR", + channel="TEST", + ) + + assert mock_node_with_post_success.session.http_session.post.called_once + assert isinstance(forget_message, ForgetMessage) + + +@pytest.mark.asyncio +async def test_download_file(mock_node_with_post_success): + mock_node_with_post_success.session.download_file = AsyncMock() + mock_node_with_post_success.session.download_file.return_value = b"HELLO" + + # remove file locally + if os.path.exists(settings.CACHE_FILES_PATH / Path("QmAndSoOn")): + os.remove(settings.CACHE_FILES_PATH / Path("QmAndSoOn")) + + # fetch from mocked response + async with mock_node_with_post_success as node: + file_content = await node.download_file( + file_hash="QmAndSoOn", + ) + + assert mock_node_with_post_success.session.http_session.get.called_once + assert file_content == b"HELLO" + + # fetch cached + async with mock_node_with_post_success as node: + file_content = await node.download_file( + file_hash="QmAndSoOn", + ) + + assert file_content == b"HELLO" + + +@pytest.mark.asyncio +async def test_submit_message(mock_node_with_post_success): + content = {"Hello": "World"} + async with mock_node_with_post_success as node: + message, status = await node.submit( + content={ + "address": "0x1234567890123456789012345678901234567890", + "time": 1234567890, + "type": "TEST", + "content": content, + }, + message_type=MessageType.post, + ) + + assert mock_node_with_post_success.session.http_session.post.called_once + assert message.content.content == content + assert status == MessageStatus.PENDING diff --git a/tests/unit/test_node_get.py b/tests/unit/test_node_get.py new file mode 100644 index 00000000..48bff3b8 --- /dev/null +++ b/tests/unit/test_node_get.py @@ -0,0 +1,231 @@ +import json +from hashlib import sha256 +from typing import List + +import pytest +from aleph_message.models import ( + AlephMessage, + Chain, + MessageType, + PostContent, + PostMessage, +) + +from aleph.sdk.chains.ethereum import get_fallback_account +from aleph.sdk.exceptions import MessageNotFoundError +from aleph.sdk.node import MessageCache + + +@pytest.mark.asyncio +async def test_base(messages): + # test add_many + cache = MessageCache() + cache.add(messages) + assert len(cache) == len(messages) + + item_hashes = [message.item_hash for message in messages] + cached_messages = cache.get(item_hashes) + assert len(cached_messages) == len(messages) + + for message in messages: + assert cache[message.item_hash] == message + + for message in messages: + assert message.item_hash in cache + + for message in cache: + del cache[message.item_hash] + assert message.item_hash not in cache + + assert len(cache) == 0 + del cache + + +class TestMessageQueries: + messages: List[AlephMessage] + cache: MessageCache + + @pytest.fixture(autouse=True) + def class_setup(self, messages): + self.messages = messages + self.cache = MessageCache() + self.cache.add(self.messages) + + def class_teardown(self): + del self.cache + 
+
+
+def test_node_init(mock_session_with_two_messages):
+    # DomainNode drives the event loop itself during __init__, so this test
+    # must stay synchronous
+    node = DomainNode(session=mock_session_with_two_messages)
+    assert node.session == mock_session_with_two_messages
+    assert len(node) >= 2
+
+
+@pytest.fixture
+def mock_node_with_post_success(mock_session_with_two_messages) -> DomainNode:
+    node = DomainNode(session=mock_session_with_two_messages)
+    return node
+
+
+@pytest.mark.asyncio
+async def test_create_post(mock_node_with_post_success):
+    async with mock_node_with_post_success as session:
+        content = {"Hello": "World"}
+
+        post_message, message_status = await session.create_post(
+            post_content=content,
+            post_type="TEST",
+            channel="TEST",
+            sync=False,
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(post_message, PostMessage)
+    assert message_status == MessageStatus.PENDING
+
+
+@pytest.mark.asyncio
+async def test_create_aggregate(mock_node_with_post_success):
+    async with mock_node_with_post_success as session:
+        aggregate_message, message_status = await session.create_aggregate(
+            key="hello",
+            content={"Hello": "world"},
+            channel="TEST",
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(aggregate_message, AggregateMessage)
+
+
+@pytest.mark.asyncio
+async def test_create_store(mock_node_with_post_success):
+    mock_ipfs_push_file = AsyncMock()
+    mock_ipfs_push_file.return_value = "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"
+
+    mock_node_with_post_success.ipfs_push_file = mock_ipfs_push_file
+
+    async with mock_node_with_post_success as node:
+        _ = await node.create_store(
+            file_content=b"HELLO",
+            channel="TEST",
+            storage_engine=StorageEnum.ipfs,
+        )
+
+        _ = await node.create_store(
+            file_hash="QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy",
+            channel="TEST",
+            storage_engine=StorageEnum.ipfs,
+        )
+
+    mock_storage_push_file = AsyncMock()
+    mock_storage_push_file.return_value = (
+        "QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"
+    )
+    mock_node_with_post_success.storage_push_file = mock_storage_push_file
+    async with mock_node_with_post_success as node:
+        store_message, message_status = await node.create_store(
+            file_content=b"HELLO",
+            channel="TEST",
+            storage_engine=StorageEnum.storage,
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(store_message, StoreMessage)
+
+
+@pytest.mark.asyncio
+async def test_create_program(mock_node_with_post_success):
+    async with mock_node_with_post_success as node:
+        program_message, message_status = await node.create_program(
+            program_ref="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe",
+            entrypoint="main:app",
+            runtime="facefacefacefacefacefacefacefacefacefacefacefacefacefacefaceface",
+            channel="TEST",
+            metadata={"tags": ["test"]},
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(program_message, ProgramMessage)
+
+
+@pytest.mark.asyncio
+async def test_forget(mock_node_with_post_success):
+    async with mock_node_with_post_success as node:
+        forget_message, message_status = await node.forget(
+            hashes=["QmRTV3h1jLcACW4FRfdisokkQAk4E4qDhUzGpgdrd4JAFy"],
+            reason="GDPR",
+            channel="TEST",
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert isinstance(forget_message, ForgetMessage)
+
+
+@pytest.mark.asyncio
+async def test_download_file(mock_node_with_post_success):
+    mock_node_with_post_success.session.download_file = AsyncMock()
+    mock_node_with_post_success.session.download_file.return_value = b"HELLO"
+
+    # remove file locally
+    if os.path.exists(settings.CACHE_FILES_PATH / Path("QmAndSoOn")):
+        os.remove(settings.CACHE_FILES_PATH / Path("QmAndSoOn"))
+
+    # fetch from mocked response
+    async with mock_node_with_post_success as node:
+        file_content = await node.download_file(
+            file_hash="QmAndSoOn",
+        )
+
+    assert mock_node_with_post_success.session.download_file.called
+    assert file_content == b"HELLO"
+
+    # fetch cached
+    async with mock_node_with_post_success as node:
+        file_content = await node.download_file(
+            file_hash="QmAndSoOn",
+        )
+
+    assert file_content == b"HELLO"
+
+
+@pytest.mark.asyncio
+async def test_submit_message(mock_node_with_post_success):
+    content = {"Hello": "World"}
+    async with mock_node_with_post_success as node:
+        message, status = await node.submit(
+            content={
+                "address": "0x1234567890123456789012345678901234567890",
+                "time": 1234567890,
+                "type": "TEST",
+                "content": content,
+            },
+            message_type=MessageType.post,
+        )
+
+    assert mock_node_with_post_success.session.http_session.post.called
+    assert message.content.content == content
+    assert status == MessageStatus.PENDING
diff --git a/tests/unit/test_node_get.py b/tests/unit/test_node_get.py
new file mode 100644
index 00000000..48bff3b8
--- /dev/null
+++ b/tests/unit/test_node_get.py
@@ -0,0 +1,231 @@
+import json
+from hashlib import sha256
+from typing import List
+
+import pytest
+from aleph_message.models import (
+    AlephMessage,
+    Chain,
+    MessageType,
+    PostContent,
+    PostMessage,
+)
+
+from aleph.sdk.chains.ethereum import get_fallback_account
+from aleph.sdk.exceptions import MessageNotFoundError
+from aleph.sdk.node import MessageCache
+
+
+@pytest.mark.asyncio
+async def test_base(messages):
+    # test add_many
+    cache = MessageCache()
+    cache.add(messages)
+    assert len(cache) == len(messages)
+
+    item_hashes = [message.item_hash for message in messages]
+    cached_messages = cache.get(item_hashes)
+    assert len(cached_messages) == len(messages)
+
+    for message in messages:
+        assert cache[message.item_hash] == message
+
+    for message in messages:
+        assert message.item_hash in cache
+
+    for message in cache:
+        del cache[message.item_hash]
+        assert message.item_hash not in cache
+
+    assert len(cache) == 0
+    del cache
+
+
+class TestMessageQueries:
+    messages: List[AlephMessage]
+    cache: MessageCache
+
+    @pytest.fixture(autouse=True)
+    def class_setup(self, messages):
+        self.messages = messages
+        self.cache = MessageCache()
+        self.cache.add(self.messages)
+
+    def class_teardown(self):
+        del self.cache
+
+    @pytest.mark.asyncio
+    async def test_iterate(self):
+        assert len(self.cache) == len(self.messages)
+        for message in self.cache:
+            assert message in self.messages
+
+    @pytest.mark.asyncio
+    async def test_addresses(self):
+        items = (
+            await self.cache.get_messages(addresses=[self.messages[0].sender])
+        ).messages
+        assert items[0] == self.messages[0]
+
+    @pytest.mark.asyncio
+    async def test_tags(self):
+        assert (
+            len((await self.cache.get_messages(tags=["thistagdoesnotexist"])).messages)
+            == 0
+        )
+
+    @pytest.mark.asyncio
+    async def test_message_type(self):
+        assert (await self.cache.get_messages(message_type=MessageType.post)).messages[
+            0
+        ] == self.messages[1]
+
+    @pytest.mark.asyncio
+    async def test_refs(self):
+        assert (
+            await self.cache.get_messages(refs=[self.messages[1].content.ref])
+        ).messages[0] == self.messages[1]
+
+    @pytest.mark.asyncio
+    async def test_hashes(self):
+        assert (
+            await self.cache.get_messages(hashes=[self.messages[0].item_hash])
+        ).messages[0] == self.messages[0]
+
+    @pytest.mark.asyncio
+    async def test_pagination(self):
+        assert len((await self.cache.get_messages(pagination=1)).messages) == 1
+
+    @pytest.mark.asyncio
+    async def test_content_types(self):
+        assert (
+            await self.cache.get_messages(content_types=[self.messages[1].content.type])
+        ).messages[0] == self.messages[1]
+
+    @pytest.mark.asyncio
+    async def test_channels(self):
+        assert (
+            await self.cache.get_messages(channels=[self.messages[1].channel])
+        ).messages[0] == self.messages[1]
+
+    @pytest.mark.asyncio
+    async def test_chains(self):
+        assert (
+            await self.cache.get_messages(chains=[self.messages[1].chain])
+        ).messages[0] == self.messages[1]
+
+    @pytest.mark.asyncio
+    async def test_content_keys(self):
+        assert (
+            await self.cache.get_messages(content_keys=[self.messages[0].content.key])
+        ).messages[0] == self.messages[0]
+
+
+class TestPostQueries:
+    messages: List[AlephMessage]
+    cache: MessageCache
+
+    @pytest.fixture(autouse=True)
+    def class_setup(self, messages):
+        self.messages = messages
+        self.cache = MessageCache()
+        self.cache.add(self.messages)
+
+    def class_teardown(self):
+        del self.cache
+
+    @pytest.mark.asyncio
+    async def test_addresses(self):
+        items = (await self.cache.get_posts(addresses=[self.messages[1].sender])).posts
+        assert items[0] == self.messages[1]
+
+    @pytest.mark.asyncio
+    async def test_tags(self):
+        assert (
+            len((await self.cache.get_posts(tags=["thistagdoesnotexist"])).posts) == 0
+        )
+
+    @pytest.mark.asyncio
+    async def test_types(self):
+        assert (
+            len((await self.cache.get_posts(types=["thistypedoesnotexist"])).posts) == 0
+        )
+
+    @pytest.mark.asyncio
+    async def test_channels(self):
+        assert (await self.cache.get_posts(channels=[self.messages[1].channel])).posts[
+            0
+        ] == self.messages[1]
+
+    @pytest.mark.asyncio
+    async def test_chains(self):
+        assert (await self.cache.get_posts(chains=[self.messages[1].chain])).posts[
+            0
+        ] == self.messages[1]
+
+
+@pytest.mark.asyncio
+async def test_message_cache_listener():
+    async def mock_message_stream():
+        for i in range(3):
+            content = PostContent(
+                content={"hello": f"world{i}"},
+                type="test",
+                address=get_fallback_account().get_address(),
+                time=0,
+            )
+            message = PostMessage(
+                sender=get_fallback_account().get_address(),
+                item_hash=sha256(json.dumps(content.dict()).encode()).hexdigest(),
+                chain=Chain.ETH.value,
+                type=MessageType.post.value,
+                item_type="inline",
+                time=0,
+                content=content,
+                item_content=json.dumps(content.dict()),
+            )
+            yield message
+
+    cache = MessageCache()
+    # test listener
+    coro = cache.listen_to(mock_message_stream())
+    await coro
+    assert len(cache) >= 3
+
+
+@pytest.mark.asyncio
+async def test_fetch_aggregate(messages):
+    cache = MessageCache()
+    cache.add(messages)
+
+    aggregate = await cache.fetch_aggregate(messages[0].sender, messages[0].content.key)
+
+    assert aggregate == messages[0].content.content
+
+
+@pytest.mark.asyncio
+async def test_fetch_aggregates(messages):
+    cache = MessageCache()
+    cache.add(messages)
+
+    aggregates = await cache.fetch_aggregates(messages[0].sender)
+
+    assert aggregates == {messages[0].content.key: messages[0].content.content}
+
+
+@pytest.mark.asyncio
+async def test_get_message(messages):
+    cache = MessageCache()
+    cache.add(messages)
+
+    message: AlephMessage = await cache.get_message(messages[0].item_hash)
+
+    assert message == messages[0]
+
+
+@pytest.mark.asyncio
+async def test_get_message_fail():
+    cache = MessageCache()
+
+    with pytest.raises(MessageNotFoundError):
+        await cache.get_message("0x1234567890123456789012345678901234567890")