From 8075fc9af0cd5165e4b7661f1c6796cffedc30c7 Mon Sep 17 00:00:00 2001 From: James Fisher <85769594+jamesfisher-gis@users.noreply.github.com> Date: Tue, 11 Jun 2024 03:20:22 -0400 Subject: [PATCH] Aggregation Extension (#684) * initial commit * aggregation extension and tests * clean up * update changelog * Search and Filter extension * AggregationCollection * AggregationCollection classes * test classes * AggregationCollection literal * aggregation post model * docstring fix * linting * TypedDict import * move aggregation client and types into extensions * linting --- CHANGES.md | 4 + stac_fastapi/api/stac_fastapi/api/config.py | 1 + .../stac_fastapi/extensions/core/__init__.py | 2 + .../extensions/core/aggregation/__init__.py | 5 + .../core/aggregation/aggregation.py | 111 +++++++++++++++ .../extensions/core/aggregation/client.py | 131 ++++++++++++++++++ .../extensions/core/aggregation/request.py | 24 ++++ .../extensions/core/aggregation/types.py | 36 +++++ .../extensions/tests/test_aggregation.py | 102 ++++++++++++++ stac_fastapi/types/stac_fastapi/types/core.py | 38 +++++ 10 files changed, 454 insertions(+) create mode 100644 stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/__init__.py create mode 100644 stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/aggregation.py create mode 100644 stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/client.py create mode 100644 stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py create mode 100644 stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/types.py create mode 100644 stac_fastapi/extensions/tests/test_aggregation.py diff --git a/CHANGES.md b/CHANGES.md index a75f1da8a..d6499fb83 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,10 @@ ## [Unreleased] - TBD +### Added + +* Add base support for the Aggregation extension [#684](https://github.com/stac-utils/stac-fastapi/pull/684) + ### Changed * moved `AsyncBaseFiltersClient` and `BaseFiltersClient` classes in `stac_fastapi.extensions.core.filter.client` submodule ([#704](https://github.com/stac-utils/stac-fastapi/pull/704)) diff --git a/stac_fastapi/api/stac_fastapi/api/config.py b/stac_fastapi/api/stac_fastapi/api/config.py index 3918421ff..20a7b4af5 100644 --- a/stac_fastapi/api/stac_fastapi/api/config.py +++ b/stac_fastapi/api/stac_fastapi/api/config.py @@ -18,6 +18,7 @@ class ApiExtensions(enum.Enum): query = "query" sort = "sort" transaction = "transaction" + aggregation = "aggregation" class AddOns(enum.Enum): diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py index 74f15ed0a..7e29e1fd2 100644 --- a/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/__init__.py @@ -1,5 +1,6 @@ """stac_api.extensions.core module.""" +from .aggregation import AggregationExtension from .context import ContextExtension from .fields import FieldsExtension from .filter import FilterExtension @@ -9,6 +10,7 @@ from .transaction import TransactionExtension __all__ = ( + "AggregationExtension", "ContextExtension", "FieldsExtension", "FilterExtension", diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/__init__.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/__init__.py new file mode 100644 index 000000000..2a7fc7a71 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/__init__.py @@ -0,0 +1,5 @@ +"""Aggregation extension module.""" + +from .aggregation import AggregationExtension + +__all__ = ["AggregationExtension"] diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/aggregation.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/aggregation.py new file mode 100644 index 000000000..c6e892914 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/aggregation.py @@ -0,0 +1,111 @@ +"""Aggregation Extension.""" +from enum import Enum +from typing import List, Union + +import attr +from fastapi import APIRouter, FastAPI + +from stac_fastapi.api.models import CollectionUri, EmptyRequest +from stac_fastapi.api.routes import create_async_endpoint +from stac_fastapi.types.extension import ApiExtension + +from .client import AsyncBaseAggregationClient, BaseAggregationClient +from .request import AggregationExtensionGetRequest, AggregationExtensionPostRequest + + +class AggregationConformanceClasses(str, Enum): + """Conformance classes for the Aggregation extension. + + See + https://github.com/stac-api-extensions/aggregation + """ + + AGGREGATION = "https://api.stacspec.org/v0.3.0/aggregation" + + +@attr.s +class AggregationExtension(ApiExtension): + """Aggregation Extension. + + The purpose of the Aggregation Extension is to provide an endpoint similar to + the Search endpoint (/search), but which will provide aggregated information + on matching Items rather than the Items themselves. This is highly influenced + by the Elasticsearch and OpenSearch aggregation endpoint, but with a more + regular structure for responses. + + The Aggregation extension adds several endpoints which allow the retrieval of + available aggregation fields and aggregation buckets based on a seearch query: + GET /aggregations + POST /aggregations + GET /collections/{collection_id}/aggregations + POST /collections/{collection_id}/aggregations + GET /aggregate + POST /aggregate + GET /collections/{collection_id}/aggregate + POST /collections/{collection_id}/aggregate + + https://github.com/stac-api-extensions/aggregation/blob/main/README.md + + Attributes: + conformance_classes: Conformance classes provided by the extension + """ + + GET = AggregationExtensionGetRequest + POST = AggregationExtensionPostRequest + + client: Union[AsyncBaseAggregationClient, BaseAggregationClient] = attr.ib( + factory=BaseAggregationClient + ) + + conformance_classes: List[str] = attr.ib( + default=[AggregationConformanceClasses.AGGREGATION] + ) + router: APIRouter = attr.ib(factory=APIRouter) + + def register(self, app: FastAPI) -> None: + """Register the extension with a FastAPI application. + + Args: + app: target FastAPI application. + + Returns: + None + """ + self.router.prefix = app.state.router_prefix + self.router.add_api_route( + name="Aggregations", + path="/aggregations", + methods=["GET", "POST"], + endpoint=create_async_endpoint(self.client.get_aggregations, EmptyRequest), + ) + self.router.add_api_route( + name="Collection Aggregations", + path="/collections/{collection_id}/aggregations", + methods=["GET", "POST"], + endpoint=create_async_endpoint(self.client.get_aggregations, CollectionUri), + ) + self.router.add_api_route( + name="Aggregate", + path="/aggregate", + methods=["GET"], + endpoint=create_async_endpoint(self.client.aggregate, self.GET), + ) + self.router.add_api_route( + name="Aggregate", + path="/aggregate", + methods=["POST"], + endpoint=create_async_endpoint(self.client.aggregate, self.POST), + ) + self.router.add_api_route( + name="Collection Aggregate", + path="/collections/{collection_id}/aggregate", + methods=["GET"], + endpoint=create_async_endpoint(self.client.aggregate, self.GET), + ) + self.router.add_api_route( + name="Collection Aggregate", + path="/collections/{collection_id}/aggregate", + methods=["POST"], + endpoint=create_async_endpoint(self.client.aggregate, self.POST), + ) + app.include_router(self.router, tags=["Aggregation Extension"]) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/client.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/client.py new file mode 100644 index 000000000..23d90fb28 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/client.py @@ -0,0 +1,131 @@ +"""Aggregation extensions clients.""" + +import abc +from typing import List, Optional, Union + +import attr +from geojson_pydantic.geometries import Geometry +from stac_pydantic.shared import BBox + +from stac_fastapi.types.rfc3339 import DateTimeType + +from .types import Aggregation, AggregationCollection + + +@attr.s +class BaseAggregationClient(abc.ABC): + """Defines a pattern for implementing the STAC aggregation extension.""" + + # BUCKET = Bucket + # AGGREGAION = Aggregation + # AGGREGATION_COLLECTION = AggregationCollection + + def get_aggregations( + self, collection_id: Optional[str] = None, **kwargs + ) -> AggregationCollection: + """Get the aggregations available for the given collection_id. + + If collection_id is None, returns the available aggregations over all + collections. + """ + return AggregationCollection( + type="AggregationCollection", + aggregations=[Aggregation(name="total_count", data_type="integer")], + links=[ + { + "rel": "root", + "type": "application/json", + "href": "https://example.org/", + }, + { + "rel": "self", + "type": "application/json", + "href": "https://example.org/aggregations", + }, + ], + ) + + def aggregate( + self, collection_id: Optional[str] = None, **kwargs + ) -> AggregationCollection: + """Return the aggregation buckets for a given search result""" + return AggregationCollection( + type="AggregationCollection", + aggregations=[], + links=[ + { + "rel": "root", + "type": "application/json", + "href": "https://example.org/", + }, + { + "rel": "self", + "type": "application/json", + "href": "https://example.org/aggregations", + }, + ], + ) + + +@attr.s +class AsyncBaseAggregationClient(abc.ABC): + """Defines an async pattern for implementing the STAC aggregation extension.""" + + # BUCKET = Bucket + # AGGREGAION = Aggregation + # AGGREGATION_COLLECTION = AggregationCollection + + async def get_aggregations( + self, collection_id: Optional[str] = None, **kwargs + ) -> AggregationCollection: + """Get the aggregations available for the given collection_id. + + If collection_id is None, returns the available aggregations over all + collections. + """ + return AggregationCollection( + type="AggregationCollection", + aggregations=[Aggregation(name="total_count", data_type="integer")], + links=[ + { + "rel": "root", + "type": "application/json", + "href": "https://example.org/", + }, + { + "rel": "self", + "type": "application/json", + "href": "https://example.org/aggregations", + }, + ], + ) + + async def aggregate( + self, + collection_id: Optional[str] = None, + aggregations: Optional[Union[str, List[str]]] = None, + collections: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + bbox: Optional[BBox] = None, + intersects: Optional[Geometry] = None, + datetime: Optional[DateTimeType] = None, + limit: Optional[int] = 10, + **kwargs, + ) -> AggregationCollection: + """Return the aggregation buckets for a given search result""" + return AggregationCollection( + type="AggregationCollection", + aggregations=[], + links=[ + { + "rel": "root", + "type": "application/json", + "href": "https://example.org/", + }, + { + "rel": "self", + "type": "application/json", + "href": "https://example.org/aggregations", + }, + ], + ) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py new file mode 100644 index 000000000..fcab3323f --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/request.py @@ -0,0 +1,24 @@ +"""Request model for the Aggregation extension.""" + +from typing import List, Optional, Union + +import attr + +from stac_fastapi.extensions.core.filter.request import ( + FilterExtensionGetRequest, + FilterExtensionPostRequest, +) +from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest + + +@attr.s +class AggregationExtensionGetRequest(BaseSearchGetRequest, FilterExtensionGetRequest): + """Aggregation Extension GET request model.""" + + aggregations: Optional[str] = attr.ib(default=None) + + +class AggregationExtensionPostRequest(BaseSearchPostRequest, FilterExtensionPostRequest): + """Aggregation Extension POST request model.""" + + aggregations: Optional[Union[str, List[str]]] = attr.ib(default=None) diff --git a/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/types.py b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/types.py new file mode 100644 index 000000000..428b65225 --- /dev/null +++ b/stac_fastapi/extensions/stac_fastapi/extensions/core/aggregation/types.py @@ -0,0 +1,36 @@ +"""Aggregation Extension types.""" + +from typing import Any, Dict, List, Literal, Optional, Union + +from pydantic import Field +from typing_extensions import TypedDict + +from stac_fastapi.types.rfc3339 import DateTimeType + + +class Bucket(TypedDict, total=False): + """A STAC aggregation bucket.""" + + key: str + data_type: str + frequency: Optional[Dict] = None + _from: Optional[Union[int, float]] = Field(alias="from", default=None) + to: Optional[Optional[Union[int, float]]] = None + + +class Aggregation(TypedDict, total=False): + """A STAC aggregation.""" + + name: str + data_type: str + buckets: Optional[List[Bucket]] = None + overflow: Optional[int] = None + value: Optional[Union[str, int, DateTimeType]] = None + + +class AggregationCollection(TypedDict, total=False): + """STAC Item Aggregation Collection.""" + + type: Literal["AggregationCollection"] + aggregations: List[Aggregation] + links: List[Dict[str, Any]] diff --git a/stac_fastapi/extensions/tests/test_aggregation.py b/stac_fastapi/extensions/tests/test_aggregation.py new file mode 100644 index 000000000..c96e316ae --- /dev/null +++ b/stac_fastapi/extensions/tests/test_aggregation.py @@ -0,0 +1,102 @@ +from typing import Iterator + +import pytest +from starlette.testclient import TestClient + +from stac_fastapi.api.app import StacApi +from stac_fastapi.extensions.core import AggregationExtension +from stac_fastapi.extensions.core.aggregation.client import BaseAggregationClient +from stac_fastapi.extensions.core.aggregation.types import ( + Aggregation, + AggregationCollection, +) +from stac_fastapi.types.config import ApiSettings +from stac_fastapi.types.core import BaseCoreClient + + +class DummyCoreClient(BaseCoreClient): + def all_collections(self, *args, **kwargs): + raise NotImplementedError + + def get_collection(self, *args, **kwargs): + raise NotImplementedError + + def get_item(self, *args, **kwargs): + raise NotImplementedError + + def get_search(self, *args, **kwargs): + raise NotImplementedError + + def post_search(self, *args, **kwargs): + raise NotImplementedError + + def item_collection(self, *args, **kwargs): + raise NotImplementedError + + +def test_get_aggregations(client: TestClient) -> None: + response = client.get("/aggregations") + assert response.is_success, response.text + assert response.json()["aggregations"] == [ + {"name": "total_count", "data_type": "integer"} + ] + assert AggregationCollection( + type="AggregationCollection", + aggregations=[Aggregation(**response.json()["aggregations"][0])], + ) + + +def test_get_aggregate(client: TestClient) -> None: + response = client.get("/aggregate") + assert response.is_success, response.text + assert response.json()["aggregations"] == [] + assert AggregationCollection( + type="AggregationCollection", aggregations=response.json()["aggregations"] + ) + + +def test_post_aggregations(client: TestClient) -> None: + response = client.post("/aggregations") + assert response.is_success, response.text + assert response.json()["aggregations"] == [ + {"name": "total_count", "data_type": "integer"} + ] + assert AggregationCollection( + type="AggregationCollection", + aggregations=[Aggregation(**response.json()["aggregations"][0])], + ) + + +def test_post_aggregate(client: TestClient) -> None: + response = client.post("/aggregate", content="{}") + assert response.is_success, response.text + assert response.json()["aggregations"] == [] + assert AggregationCollection( + type="AggregationCollection", aggregations=response.json()["aggregations"] + ) + + +@pytest.fixture +def client( + core_client: DummyCoreClient, aggregations_client: BaseAggregationClient +) -> Iterator[TestClient]: + settings = ApiSettings() + api = StacApi( + settings=settings, + client=core_client, + extensions=[ + AggregationExtension(client=aggregations_client), + ], + ) + with TestClient(api.app) as client: + yield client + + +@pytest.fixture +def core_client() -> DummyCoreClient: + return DummyCoreClient() + + +@pytest.fixture +def aggregations_client() -> BaseAggregationClient: + return BaseAggregationClient() diff --git a/stac_fastapi/types/stac_fastapi/types/core.py b/stac_fastapi/types/stac_fastapi/types/core.py index 4cdda49e0..003a765ed 100644 --- a/stac_fastapi/types/stac_fastapi/types/core.py +++ b/stac_fastapi/types/stac_fastapi/types/core.py @@ -398,6 +398,25 @@ def landing_page(self, **kwargs) -> stac.LandingPage: } ) + # Add Aggregation links + if self.extension_is_enabled("AggregationExtension"): + landing_page["links"].extend( + [ + { + "rel": "aggregate", + "type": "application/json", + "title": "Aggregate", + "href": urljoin(base_url, "aggregate"), + }, + { + "rel": "aggregations", + "type": "application/json", + "title": "Aggregations", + "href": urljoin(base_url, "aggregations"), + }, + ] + ) + # Add Collections links collections = self.all_collections(request=kwargs["request"]) @@ -602,6 +621,25 @@ async def landing_page(self, **kwargs) -> stac.LandingPage: } ) + # Add Aggregation links + if self.extension_is_enabled("AggregationExtension"): + landing_page["links"].extend( + [ + { + "rel": "aggregate", + "type": "application/json", + "title": "Aggregate", + "href": urljoin(base_url, "aggregate"), + }, + { + "rel": "aggregations", + "type": "application/json", + "title": "Aggregations", + "href": urljoin(base_url, "aggregations"), + }, + ] + ) + # Add Collections links collections = await self.all_collections(request=kwargs["request"])