Skip to content

Commit

Permalink
Feature: return errors for invalid post API queries (#344)
Browse files Browse the repository at this point in the history
Problem: the posts.json endpoint is too permissive and allows users
to specify invalid hashes, time filters, pagination, etc.

Solution: detect these cases and return a 422 error.

Replaced the validation code by a Pydantic model.

Breaking changes:
* The "endDate" field is now treated as exclusive.

Moreover, a 422 error code will now be returned in the following
situations, where the previous implementation would simply return a 200:

* if an invalid item hash (i.e. not a hexadecimal SHA-256, CIDv0 or CIDv1)
  is specified in the "hashes" or "contentHashes" field.
* if the "endDate" field is lower than the "startDate" field.
* if "endDate" or "startDate" are negative.
* if pagination parameters ("page" and "pagination") are negative, or if "page" is 0.
  • Loading branch information
odesenfans committed Jan 9, 2023
1 parent 0067b93 commit d9fca73
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 69 deletions.
4 changes: 2 additions & 2 deletions src/aleph/web/controllers/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from aleph.model.messages import CappedMessage, Message
from aleph.web.controllers.utils import (
DEFAULT_MESSAGES_PER_PAGE,
DEFAULT_PAGE,
LIST_FIELD_SEPARATOR,
Pagination,
cond_output,
Expand All @@ -20,8 +22,6 @@
LOGGER = logging.getLogger(__name__)


DEFAULT_MESSAGES_PER_PAGE = 20
DEFAULT_PAGE = 1
DEFAULT_WS_HISTORY = 10


Expand Down
196 changes: 132 additions & 64 deletions src/aleph/web/controllers/posts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,129 @@
from typing import Optional, List, Mapping, Any

from aiohttp import web
from aleph_message.models import ItemHash
from pydantic import BaseModel, Field, root_validator, validator, ValidationError

from aleph.model.messages import Message, get_merged_posts
from aleph.web.controllers.utils import Pagination, cond_output, prepare_date_filters
from aleph.web.controllers.utils import (
DEFAULT_MESSAGES_PER_PAGE,
DEFAULT_PAGE,
LIST_FIELD_SEPARATOR,
Pagination,
cond_output,
make_date_filters,
)


class PostQueryParams(BaseModel):
    """Validated query parameters for the posts listing endpoint.

    List-valued parameters may be supplied as a single string whose values are
    joined by ``LIST_FIELD_SEPARATOR``; such strings are split into lists
    before field validation runs. Invalid input raises a pydantic
    ``ValidationError``.
    """

    addresses: Optional[List[str]] = Field(
        default=None, description="Accepted values for the 'content.address' field."
    )
    # NOTE: the generated filter matches these values against both
    # 'item_hash' and 'tx_hash' (see to_filter_list).
    hashes: Optional[List[ItemHash]] = Field(
        default=None, description="Accepted values for the 'item_hash' field."
    )
    refs: Optional[List[str]] = Field(
        default=None, description="Accepted values for the 'content.ref' field."
    )
    # The public query parameter has always been named "types"; without the
    # alias it would be silently ignored and the results left unfiltered.
    post_types: Optional[List[str]] = Field(
        default=None,
        alias="types",
        description="Accepted values for the 'content.type' field.",
    )
    tags: Optional[List[str]] = Field(
        default=None, description="Accepted values for the 'content.content.tags' field."
    )
    channels: Optional[List[str]] = Field(
        default=None, description="Accepted values for the 'channel' field."
    )
    start_date: float = Field(
        default=0,
        ge=0,
        alias="startDate",
        description="Start date timestamp. If specified, only messages with "
        "a time field greater or equal to this value will be returned.",
    )
    end_date: float = Field(
        default=0,
        ge=0,
        alias="endDate",
        description="End date timestamp. If specified, only messages with "
        "a time field lower than this value will be returned.",
    )
    pagination: int = Field(
        default=DEFAULT_MESSAGES_PER_PAGE,
        ge=0,
        description="Maximum number of messages to return. Specifying 0 removes this limit.",
    )
    page: int = Field(
        default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1."
    )

    class Config:
        # Keep accepting Python field names (e.g. "post_types") in addition to
        # the aliases, so that adding the "types" alias stays backward-compatible.
        allow_population_by_field_name = True

    @root_validator
    def validate_field_dependencies(cls, values):
        """Reject queries whose end date is lower than their start date.

        The check only applies when both dates are set to a non-zero value;
        0 is the default of both fields.
        """
        start_date = values.get("start_date")
        end_date = values.get("end_date")
        if start_date and end_date and (end_date < start_date):
            raise ValueError("end date cannot be lower than start date.")
        return values

    @validator(
        "addresses",
        "hashes",
        "refs",
        "post_types",
        "channels",
        "tags",
        pre=True,
    )
    def split_str(cls, v):
        """Split a separator-joined string into a list; pass other values through."""
        if isinstance(v, str):
            return v.split(LIST_FIELD_SEPARATOR)
        return v

    def to_filter_list(self) -> List[Mapping[str, Any]]:
        """Return one MongoDB filter document per query parameter that was set."""
        filters: List[Mapping[str, Any]] = []

        if self.addresses is not None:
            filters.append(
                {"content.address": {"$in": self.addresses}},
            )
        if self.post_types is not None:
            filters.append({"content.type": {"$in": self.post_types}})
        if self.refs is not None:
            filters.append({"content.ref": {"$in": self.refs}})
        if self.tags is not None:
            filters.append({"content.content.tags": {"$elemMatch": {"$in": self.tags}}})
        if self.hashes is not None:
            # A requested hash may match either the item hash or the tx hash.
            filters.append(
                {
                    "$or": [
                        {"item_hash": {"$in": self.hashes}},
                        {"tx_hash": {"$in": self.hashes}},
                    ]
                }
            )
        if self.channels is not None:
            filters.append({"channel": {"$in": self.channels}})

        # The date filter is only added when the helper returns a truthy value.
        date_filters = make_date_filters(
            start=self.start_date, end=self.end_date, filter_key="time"
        )
        if date_filters:
            filters.append(date_filters)

        return filters

    def to_mongodb_filters(self) -> Mapping[str, Any]:
        """Return a single MongoDB filter document combining all active filters."""
        filters = self.to_filter_list()
        return self._make_and_filter(filters)

    @staticmethod
    def _make_and_filter(filters: List[Mapping[str, Any]]) -> Mapping[str, Any]:
        """Combine filter documents into one.

        Returns an empty mapping for no filters, the filter itself when there
        is exactly one, and a ``$and`` document otherwise.
        """
        and_filter: Mapping[str, Any] = {}
        if filters:
            and_filter = {"$and": filters} if len(filters) > 1 else filters[0]

        return and_filter


async def view_posts_list(request):
Expand All @@ -8,72 +132,16 @@ async def view_posts_list(request):
"""

find_filters = {}
filters = [
# {'type': request.query.get('msgType', 'POST')}
]

query_string = request.query_string
addresses = request.query.get("addresses", None)
if addresses is not None:
addresses = addresses.split(",")

refs = request.query.get("refs", None)
if refs is not None:
refs = refs.split(",")

post_types = request.query.get("types", None)
if post_types is not None:
post_types = post_types.split(",")

tags = request.query.get("tags", None)
if tags is not None:
tags = tags.split(",")

hashes = request.query.get("hashes", None)
if hashes is not None:
hashes = hashes.split(",")

channels = request.query.get("channels", None)
if channels is not None:
channels = channels.split(",")

date_filters = prepare_date_filters(request, "time")

if addresses is not None:
filters.append({"content.address": {"$in": addresses}})

if post_types is not None:
filters.append({"content.type": {"$in": post_types}})

if refs is not None:
filters.append({"content.ref": {"$in": refs}})

if tags is not None:
filters.append({"content.content.tags": {"$elemMatch": {"$in": tags}}})

if hashes is not None:
filters.append(
{"$or": [{"item_hash": {"$in": hashes}}, {"tx_hash": {"$in": hashes}}]}
)

if channels is not None:
filters.append({"channel": {"$in": channels}})

if date_filters is not None:
filters.append(date_filters)

if len(filters) > 0:
find_filters = {"$and": filters} if len(filters) > 1 else filters[0]
try:
query_params = PostQueryParams.parse_obj(request.query)
except ValidationError as e:
raise web.HTTPUnprocessableEntity(body=e.json(indent=4))

(
pagination_page,
pagination_per_page,
pagination_skip,
) = Pagination.get_pagination_params(request)
if pagination_per_page is None:
pagination_per_page = 0
if pagination_skip is None:
pagination_skip = 0
pagination_page = query_params.page
pagination_per_page = query_params.pagination
pagination_skip = (query_params.page - 1) * query_params.pagination

posts = [
msg
Expand Down
6 changes: 3 additions & 3 deletions src/aleph/web/controllers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from aiohttp import web
from bson import json_util

PER_PAGE = 20
PER_PAGE_SUMMARY = 50
DEFAULT_MESSAGES_PER_PAGE = 20
DEFAULT_PAGE = 1
LIST_FIELD_SEPARATOR = ","


Expand All @@ -15,7 +15,7 @@ class Pagination(object):
def get_pagination_params(request):
pagination_page = int(request.match_info.get("page", "1"))
pagination_page = int(request.query.get("page", pagination_page))
pagination_param = int(request.query.get("pagination", PER_PAGE))
pagination_param = int(request.query.get("pagination", DEFAULT_MESSAGES_PER_PAGE))
with_pagination = pagination_param != 0

if pagination_page < 1:
Expand Down

0 comments on commit d9fca73

Please sign in to comment.