Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add security module: Verify messages on fetch #113

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/aleph/sdk/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ async def get_posts(
post_filter: Optional[PostFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
verify_signatures: bool = False,
) -> PostsResponse:
"""
Fetch a list of posts from the network.
Expand All @@ -83,25 +84,35 @@ async def get_posts(
:param post_filter: Filter to apply to the posts (Default: None)
:param ignore_invalid_messages: Ignore invalid messages (Default: True)
:param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
:param verify_signatures: Verify the signatures of the messages (Default: False)
"""
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")

async def get_posts_iterator(
self,
post_filter: Optional[PostFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
verify_signatures: bool = False,
) -> AsyncIterable[PostMessage]:
"""
Fetch all filtered posts, returning an async iterator and fetching them page by page. Might return duplicates
but will always return all posts.

:param post_filter: Filter to apply to the posts (Default: None)
:param ignore_invalid_messages: Ignore invalid messages (Default: True)
:param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
:param verify_signatures: Verify the signatures of the messages (Default: False)
"""
page = 1
resp = None
while resp is None or len(resp.posts) > 0:
resp = await self.get_posts(
page=page,
post_filter=post_filter,
ignore_invalid_messages=ignore_invalid_messages,
invalid_messages_log_level=invalid_messages_log_level,
verify_signatures=verify_signatures,
)
page += 1
for post in resp.posts:
Expand Down Expand Up @@ -178,6 +189,7 @@ async def get_messages(
message_filter: Optional[MessageFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
verify_signatures: bool = False,
) -> MessagesResponse:
"""
Fetch a list of messages from the network.
Expand All @@ -187,25 +199,35 @@ async def get_messages(
:param message_filter: Filter to apply to the messages
:param ignore_invalid_messages: Ignore invalid messages (Default: True)
:param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
:param verify_signatures: Verify the signatures of the messages (Default: False)
"""
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")

async def get_messages_iterator(
self,
message_filter: Optional[MessageFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
verify_signatures: bool = False,
) -> AsyncIterable[AlephMessage]:
"""
Fetch all filtered messages, returning an async iterator and fetching them page by page. Might return duplicates
but will always return all messages.

:param message_filter: Filter to apply to the messages
:param ignore_invalid_messages: Ignore invalid messages (Default: True)
:param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
:param verify_signatures: Whether to verify the signatures of the messages (Default: False)
"""
page = 1
resp = None
while resp is None or len(resp.messages) > 0:
resp = await self.get_messages(
page=page,
message_filter=message_filter,
ignore_invalid_messages=ignore_invalid_messages,
invalid_messages_log_level=invalid_messages_log_level,
verify_signatures=verify_signatures,
)
page += 1
for message in resp.messages:
Expand All @@ -216,24 +238,28 @@ async def get_message(
self,
item_hash: str,
message_type: Optional[Type[GenericMessage]] = None,
verify_signature: bool = False,
) -> GenericMessage:
"""
Get a single message from its `item_hash` and perform some basic validation.

:param item_hash: Hash of the message to fetch
:param message_type: Type of message to fetch
:param verify_signature: Whether to verify the signature of the message (Default: False)
"""
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")

@abstractmethod
def watch_messages(
self,
message_filter: Optional[MessageFilter] = None,
verify_signatures: bool = False,
) -> AsyncIterable[AlephMessage]:
"""
Iterate over current and future matching messages asynchronously.

:param message_filter: Filter to apply to the messages
:param verify_signatures: Whether to verify the signatures of the messages (Default: False)
"""
raise NotImplementedError("Did you mean to import `AlephHttpClient`?")

Expand Down
19 changes: 17 additions & 2 deletions src/aleph/sdk/client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..exceptions import FileTooLarge, ForgottenMessageError, MessageNotFoundError
from ..query.filters import MessageFilter, PostFilter
from ..query.responses import MessagesResponse, Post, PostsResponse
from ..security import verify_message_signature
from ..types import GenericMessage
from ..utils import (
Writable,
Expand Down Expand Up @@ -117,6 +118,7 @@ async def get_posts(
post_filter: Optional[PostFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
verify_signatures: bool = False,
) -> PostsResponse:
ignore_invalid_messages = (
True if ignore_invalid_messages is None else ignore_invalid_messages
Expand Down Expand Up @@ -145,12 +147,15 @@ async def get_posts(
posts: List[Post] = []
for post_raw in posts_raw:
try:
posts.append(Post.parse_obj(post_raw))
post = Post.parse_obj(post_raw)
posts.append(post)
except ValidationError as e:
if not ignore_invalid_messages:
raise e
if invalid_messages_log_level:
logger.log(level=invalid_messages_log_level, msg=e)
if verify_signatures:
verify_message_signature(post)
return PostsResponse(
posts=posts,
pagination_page=response_json["pagination_page"],
Expand Down Expand Up @@ -266,6 +271,7 @@ async def get_messages(
message_filter: Optional[MessageFilter] = None,
ignore_invalid_messages: Optional[bool] = True,
invalid_messages_log_level: Optional[int] = logging.NOTSET,
verify_signatures: bool = False,
) -> MessagesResponse:
ignore_invalid_messages = (
True if ignore_invalid_messages is None else ignore_invalid_messages
Expand Down Expand Up @@ -312,6 +318,8 @@ async def get_messages(
raise e
if invalid_messages_log_level:
logger.log(level=invalid_messages_log_level, msg=e)
if verify_signatures:
verify_message_signature(message)

return MessagesResponse(
messages=messages,
Expand All @@ -325,6 +333,7 @@ async def get_message(
self,
item_hash: str,
message_type: Optional[Type[GenericMessage]] = None,
verify_signature: bool = False,
) -> GenericMessage:
async with self.http_session.get(f"/api/v0/messages/{item_hash}") as resp:
try:
Expand All @@ -339,6 +348,8 @@ async def get_message(
f"The requested message {message_raw['item_hash']} has been forgotten by {', '.join(message_raw['forgotten_by'])}"
)
message = parse_message(message_raw["message"])
if verify_signature:
verify_message_signature(message)
if message_type:
expected_type = get_message_type_value(message_type)
if message.type != expected_type:
Expand Down Expand Up @@ -374,6 +385,7 @@ async def get_message_error(
async def watch_messages(
self,
message_filter: Optional[MessageFilter] = None,
verify_signatures: bool = False,
) -> AsyncIterable[AlephMessage]:
message_filter = message_filter or MessageFilter()
params = message_filter.as_http_params()
Expand All @@ -389,6 +401,9 @@ async def watch_messages(
break
else:
data = json.loads(msg.data)
yield parse_message(data)
message = parse_message(data)
if verify_signatures:
verify_message_signature(message)
yield message
elif msg.type == aiohttp.WSMsgType.ERROR:
break
64 changes: 64 additions & 0 deletions src/aleph/sdk/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from importlib import import_module
from typing import Callable, Dict, Optional, Union

from aleph_message.models import AlephMessage, Chain

from aleph.sdk.chains.common import get_verification_buffer
from aleph.sdk.query.responses import Post


def _try_import_verify_signature(
chain: str,
) -> Optional[
Callable[[Union[bytes, str], Union[bytes, str], Union[bytes, str]], None]
]:
"""Try to import a chain signature validator."""
try:
return import_module(f"aleph.sdk.chains.{chain}").verify_signature
except (ImportError, AttributeError):
return None


# This is a dict containing all currently available signature validators,
# indexed by their Chain abbreviation.
#
# Ex.: validators["SOL"] -> aleph.sdk.chains.solana.verify_signature()
VALIDATORS: Dict[
Chain,
Optional[Callable[[Union[bytes, str], Union[bytes, str], Union[bytes, str]], None]],
] = {
key: _try_import_verify_signature(value)
for key, value in {
# TODO: Add AVAX
Chain.ETH: "ethereum",
Chain.SOL: "sol",
Chain.CSDK: "cosmos",
Chain.DOT: "substrate",
Chain.NULS2: "nuls2",
Chain.TEZOS: "tezos",
}.items()
}


def verify_message_signature(message: Union[AlephMessage, Post]) -> None:
"""Verify the signature of a message, raise an error if invalid or unsupported.
A BadSignatureError is raised when the signature is incorrect.
A ValueError is raised when the chain is not supported or required dependencies are missing.
"""
if message.chain not in VALIDATORS:
raise ValueError(f"Chain {message.chain} is not supported.")

validator = VALIDATORS[message.chain]
if validator is None:
raise ValueError(
f"Chain {message.chain} is not installed. Install it with `aleph-sdk-python[{message.chain}]`."
)

signature = message.signature
public_key = message.sender
message = get_verification_buffer(message.dict())

# to please mypy
assert isinstance(signature, (str, bytes))

validator(signature, public_key, message)
4 changes: 4 additions & 0 deletions tests/unit/test_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test_validators_loaded():
import aleph.sdk.security as security

assert any([validator is not None for validator in security.validators.values()])
Loading