From d9fca7320a4f7f6cf08b78c308851fe31409f7c7 Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Fri, 14 Oct 2022 14:21:34 +0200 Subject: [PATCH] Feature: return errors for invalid post API queries (#344) Problem: the posts.json endpoint is too permissive and allows users to specify invalid hashes, time filters, pagination, etc. Solution: detect these cases and return a 422 error. Replaced the validation code by a Pydantic model. Breaking changes: * The "endDate" field is now considered as exclusive. Moreover, a 422 error code will now be returned in the following situations, where the previous implementation would simply return a 200: * if an invalid item hash (=not a hexadecimal sha256, CIDv0 or CIDv1) is specified in the "hashes" or "contentHashes" field. * if the "endDate" field is lower than the "startDate" field. * if "endDate" or "startDate" are negative. * if pagination parameters ("page" and "pagination") are negative. --- src/aleph/web/controllers/messages.py | 4 +- src/aleph/web/controllers/posts.py | 196 +++++++++++++++++--------- src/aleph/web/controllers/utils.py | 6 +- 3 files changed, 137 insertions(+), 69 deletions(-) diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index ef56d3eef..b8630891e 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -11,6 +11,8 @@ from aleph.model.messages import CappedMessage, Message from aleph.web.controllers.utils import ( + DEFAULT_MESSAGES_PER_PAGE, + DEFAULT_PAGE, LIST_FIELD_SEPARATOR, Pagination, cond_output, @@ -20,8 +22,6 @@ LOGGER = logging.getLogger(__name__) -DEFAULT_MESSAGES_PER_PAGE = 20 -DEFAULT_PAGE = 1 DEFAULT_WS_HISTORY = 10 diff --git a/src/aleph/web/controllers/posts.py b/src/aleph/web/controllers/posts.py index 916f036e4..055553dd7 100644 --- a/src/aleph/web/controllers/posts.py +++ b/src/aleph/web/controllers/posts.py @@ -1,5 +1,129 @@ +from typing import Optional, List, Mapping, Any + +from aiohttp import web +from aleph_message.models import ItemHash +from pydantic import BaseModel, Field, root_validator, validator, ValidationError + from aleph.model.messages import Message, get_merged_posts -from aleph.web.controllers.utils import Pagination, cond_output, prepare_date_filters +from aleph.web.controllers.utils import ( + DEFAULT_MESSAGES_PER_PAGE, + DEFAULT_PAGE, + LIST_FIELD_SEPARATOR, + Pagination, + cond_output, + make_date_filters, +) + + +class PostQueryParams(BaseModel): + addresses: Optional[List[str]] = Field( + default=None, description="Accepted values for the 'sender' field." + ) + hashes: Optional[List[ItemHash]] = Field( + default=None, description="Accepted values for the 'item_hash' field." + ) + refs: Optional[List[str]] = Field( + default=None, description="Accepted values for the 'content.ref' field." + ) + post_types: Optional[List[str]] = Field( + default=None, description="Accepted values for the 'content.type' field." + ) + tags: Optional[List[str]] = Field( + default=None, description="Accepted values for the 'content.content.tag' field." + ) + channels: Optional[List[str]] = Field( + default=None, description="Accepted values for the 'channel' field." + ) + start_date: float = Field( + default=0, + ge=0, + alias="startDate", + description="Start date timestamp. If specified, only messages with " + "a time field greater or equal to this value will be returned.", + ) + end_date: float = Field( + default=0, + ge=0, + alias="endDate", + description="End date timestamp. If specified, only messages with " + "a time field lower than this value will be returned.", + ) + pagination: int = Field( + default=DEFAULT_MESSAGES_PER_PAGE, + ge=0, + description="Maximum number of messages to return. Specifying 0 removes this limit.", + ) + page: int = Field( + default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1." + ) + + @root_validator + def validate_field_dependencies(cls, values): + start_date = values.get("start_date") + end_date = values.get("end_date") + if start_date and end_date and (end_date < start_date): + raise ValueError("end date cannot be lower than start date.") + return values + + @validator( + "addresses", + "hashes", + "refs", + "post_types", + "channels", + "tags", + pre=True, + ) + def split_str(cls, v): + if isinstance(v, str): + return v.split(LIST_FIELD_SEPARATOR) + return v + + def to_filter_list(self) -> List[Mapping[str, Any]]: + + filters: List[Mapping[str, Any]] = [] + + if self.addresses is not None: + filters.append( + {"content.address": {"$in": self.addresses}}, + ) + if self.post_types is not None: + filters.append({"content.type": {"$in": self.post_types}}) + if self.refs is not None: + filters.append({"content.ref": {"$in": self.refs}}) + if self.tags is not None: + filters.append({"content.content.tags": {"$elemMatch": {"$in": self.tags}}}) + if self.hashes is not None: + filters.append( + { + "$or": [ + {"item_hash": {"$in": self.hashes}}, + {"tx_hash": {"$in": self.hashes}}, + ] + } + ) + if self.channels is not None: + filters.append({"channel": {"$in": self.channels}}) + + date_filters = make_date_filters( + start=self.start_date, end=self.end_date, filter_key="time" + ) + if date_filters: + filters.append(date_filters) + + return filters + + def to_mongodb_filters(self) -> Mapping[str, Any]: + filters = self.to_filter_list() + return self._make_and_filter(filters) + + @staticmethod + def _make_and_filter(filters: List[Mapping[str, Any]]) -> Mapping[str, Any]: + and_filter: Mapping[str, Any] = {} + if filters: + and_filter = {"$and": filters} if len(filters) > 1 else filters[0] + + return and_filter async def view_posts_list(request): @@ -8,72 +132,16 @@ async def view_posts_list(request): """ find_filters = {} - filters = [ - # {'type': request.query.get('msgType', 'POST')} - ] - query_string = request.query_string - addresses = request.query.get("addresses", None) - if addresses is not None: - addresses = addresses.split(",") - - refs = request.query.get("refs", None) - if refs is not None: - refs = refs.split(",") - - post_types = request.query.get("types", None) - if post_types is not None: - post_types = post_types.split(",") - - tags = request.query.get("tags", None) - if tags is not None: - tags = tags.split(",") - - hashes = request.query.get("hashes", None) - if hashes is not None: - hashes = hashes.split(",") - - channels = request.query.get("channels", None) - if channels is not None: - channels = channels.split(",") - - date_filters = prepare_date_filters(request, "time") - - if addresses is not None: - filters.append({"content.address": {"$in": addresses}}) - - if post_types is not None: - filters.append({"content.type": {"$in": post_types}}) - - if refs is not None: - filters.append({"content.ref": {"$in": refs}}) - - if tags is not None: - filters.append({"content.content.tags": {"$elemMatch": {"$in": tags}}}) - - if hashes is not None: - filters.append( - {"$or": [{"item_hash": {"$in": hashes}}, {"tx_hash": {"$in": hashes}}]} - ) - - if channels is not None: - filters.append({"channel": {"$in": channels}}) - - if date_filters is not None: - filters.append(date_filters) - if len(filters) > 0: - find_filters = {"$and": filters} if len(filters) > 1 else filters[0] + try: + query_params = PostQueryParams.parse_obj(request.query) + except ValidationError as e: + raise web.HTTPUnprocessableEntity(body=e.json(indent=4)) - ( - pagination_page, - pagination_per_page, - pagination_skip, - ) = Pagination.get_pagination_params(request) - if pagination_per_page is None: - pagination_per_page = 0 - if pagination_skip is None: - pagination_skip = 0 + pagination_page = query_params.page + pagination_per_page = query_params.pagination + pagination_skip = (query_params.page - 1) * query_params.pagination posts = [ msg diff --git a/src/aleph/web/controllers/utils.py b/src/aleph/web/controllers/utils.py index f3221ba7f..5a5804ea0 100644 --- a/src/aleph/web/controllers/utils.py +++ b/src/aleph/web/controllers/utils.py @@ -5,8 +5,8 @@ from aiohttp import web from bson import json_util -PER_PAGE = 20 -PER_PAGE_SUMMARY = 50 +DEFAULT_MESSAGES_PER_PAGE = 20 +DEFAULT_PAGE = 1 LIST_FIELD_SEPARATOR = "," @@ -15,7 +15,7 @@ class Pagination(object): def get_pagination_params(request): pagination_page = int(request.match_info.get("page", "1")) pagination_page = int(request.query.get("page", pagination_page)) - pagination_param = int(request.query.get("pagination", PER_PAGE)) + pagination_param = int(request.query.get("pagination", DEFAULT_MESSAGES_PER_PAGE)) with_pagination = pagination_param != 0 if pagination_page < 1: