Skip to content

Commit

Permalink
remove FieldsExtension check in StacApi (#725)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentsarago authored Jul 5, 2024
1 parent 3c58f0f commit dbd0464
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 42 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
* Removed the Filter Extension dependency from `AggregationExtensionPostRequest` and `AggregationExtensionGetRequest` [#716](https://github.com/stac-utils/stac-fastapi/pull/716)
* Removed `pagination_extension` attribute in `stac_fastapi.api.app.StacApi`
* Removed use of `pagination_extension` in `register_get_item_collection` function (User now need to construct the request model and pass it using `items_get_request_model` attribute)
* Removed use of `FieldsExtension` in `stac_fastapi.api.app.StacApi`. If users use `FieldsExtension`, they would have to handle overpassing the model validation step by returning a `JSONResponse` from the `post_search` and `get_search` client methods.

### Changed

Expand Down
64 changes: 64 additions & 0 deletions docs/src/migrations/v3.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,67 @@ stac=StacApi(
items_get_request_model=items_get_request_model,
)
```


## Fields extension and model validation

When using the `Fields` extension, the `/search` endpoint should be able to return `**invalid** STAC Items. This creates an issue when *model validation* is enabled at the application level.

Previously when adding the `FieldsExtension` to the extensions list and if setting output model validation, we were turning off the validation for both GET/POST `/search` endpoints. This was by-passing validation even when users were not using the `fields` options in requests.

In `stac-fastapi` v3.0, implementers will have to by-pass the *validation step* at `Client` level by returning `JSONResponse` from the `post_search` and `get_search` client methods.

```python
# before
class BadCoreClient(BaseCoreClient):
def post_search(
self, search_request: BaseSearchPostRequest, **kwargs
) -> stac.ItemCollection:
return {"not": "a proper stac item"}

def get_search(
self,
collections: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
bbox: Optional[List[NumType]] = None,
intersects: Optional[str] = None,
datetime: Optional[Union[str, datetime]] = None,
limit: Optional[int] = 10,
**kwargs,
) -> stac.ItemCollection:
return {"not": "a proper stac item"}

# now
class BadCoreClient(BaseCoreClient):
def post_search(
self, search_request: BaseSearchPostRequest, **kwargs
) -> stac.ItemCollection:
resp = {"not": "a proper stac item"}

# if `fields` extension is enabled, then we return a JSONResponse
# to avoid Item validation
if getattr(search_request, "fields", None):
return JSONResponse(content=resp)

return resp

def get_search(
self,
collections: Optional[List[str]] = None,
ids: Optional[List[str]] = None,
bbox: Optional[List[NumType]] = None,
intersects: Optional[str] = None,
datetime: Optional[Union[str, datetime]] = None,
limit: Optional[int] = 10,
**kwargs,
) -> stac.ItemCollection:
resp = {"not": "a proper stac item"}

# if `fields` extension is enabled, then we return a JSONResponse
# to avoid Item validation
if "fields" in kwargs:
return JSONResponse(content=resp)

return resp

```
21 changes: 6 additions & 15 deletions stac_fastapi/api/stac_fastapi/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@
)
from stac_fastapi.api.openapi import update_openapi
from stac_fastapi.api.routes import Scope, add_route_dependencies, create_async_endpoint

# TODO: make this module not depend on `stac_fastapi.extensions`
from stac_fastapi.extensions.core import FieldsExtension
from stac_fastapi.types.config import ApiSettings, Settings
from stac_fastapi.types.core import AsyncBaseCoreClient, BaseCoreClient
from stac_fastapi.types.extension import ApiExtension
Expand Down Expand Up @@ -225,15 +222,12 @@ def register_post_search(self):
Returns:
None
"""
fields_ext = self.get_extension(FieldsExtension)
self.router.add_api_route(
name="Search",
path="/search",
response_model=(
(api.ItemCollection if not fields_ext else None)
if self.settings.enable_response_models
else None
),
response_model=api.ItemCollection
if self.settings.enable_response_models
else None,
responses={
200: {
"content": {
Expand All @@ -257,15 +251,12 @@ def register_get_search(self):
Returns:
None
"""
fields_ext = self.get_extension(FieldsExtension)
self.router.add_api_route(
name="Search",
path="/search",
response_model=(
(api.ItemCollection if not fields_ext else None)
if self.settings.enable_response_models
else None
),
response_model=api.ItemCollection
if self.settings.enable_response_models
else None,
responses={
200: {
"content": {
Expand Down
35 changes: 12 additions & 23 deletions stac_fastapi/api/stac_fastapi/api/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Api request/response models."""

import importlib.util
from dataclasses import dataclass, make_dataclass
from typing import List, Optional, Type, Union

Expand All @@ -19,6 +18,12 @@
str_to_interval,
)

try:
import orjson # noqa
from fastapi.responses import ORJSONResponse as JSONResponse
except ImportError: # pragma: nocover
from starlette.responses import JSONResponse


def create_request_model(
model_name="SearchGetRequest",
Expand Down Expand Up @@ -120,29 +125,13 @@ def __post_init__(self):
self.datetime = str_to_interval(self.datetime) # type: ignore


# Test for ORJSON and use it rather than stdlib JSON where supported
if importlib.util.find_spec("orjson") is not None:
from fastapi.responses import ORJSONResponse

class GeoJSONResponse(ORJSONResponse):
"""JSON with custom, vendor content-type."""

media_type = "application/geo+json"

class JSONSchemaResponse(ORJSONResponse):
"""JSON with custom, vendor content-type."""

media_type = "application/schema+json"

else:
from starlette.responses import JSONResponse
class GeoJSONResponse(JSONResponse):
"""JSON with custom, vendor content-type."""

class GeoJSONResponse(JSONResponse):
"""JSON with custom, vendor content-type."""
media_type = "application/geo+json"

media_type = "application/geo+json"

class JSONSchemaResponse(JSONResponse):
"""JSON with custom, vendor content-type."""
class JSONSchemaResponse(JSONResponse):
"""JSON with custom, vendor content-type."""

media_type = "application/schema+json"
media_type = "application/schema+json"
31 changes: 27 additions & 4 deletions stac_fastapi/api/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from stac_fastapi.api import app
from stac_fastapi.api.models import (
APIRequest,
JSONResponse,
create_get_request_model,
create_post_request_model,
)
Expand Down Expand Up @@ -206,7 +207,14 @@ class BadCoreClient(BaseCoreClient):
def post_search(
self, search_request: BaseSearchPostRequest, **kwargs
) -> stac.ItemCollection:
return {"not": "a proper stac item"}
resp = {"not": "a proper stac item"}

# if `fields` extension is enabled, then we return a JSONResponse
# to avoid Item validation
if getattr(search_request, "fields", None):
return JSONResponse(content=resp)

return resp

def get_search(
self,
Expand All @@ -218,7 +226,14 @@ def get_search(
limit: Optional[int] = 10,
**kwargs,
) -> stac.ItemCollection:
return {"not": "a proper stac item"}
resp = {"not": "a proper stac item"}

# if `fields` extension is enabled, then we return a JSONResponse
# to avoid Item validation
if "fields" in kwargs:
return JSONResponse(content=resp)

return resp

def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac.Item:
raise NotImplementedError
Expand All @@ -240,6 +255,7 @@ def item_collection(
) -> stac.ItemCollection:
raise NotImplementedError

# With FieldsExtension
test_app = app.StacApi(
settings=ApiSettings(enable_response_models=validate),
client=BadCoreClient(),
Expand All @@ -264,14 +280,18 @@ def item_collection(
},
)

# With or without validation, /search endpoints will always return 200
# because we have the `FieldsExtension` enabled, so the endpoint
# will avoid the model validation (by returning JSONResponse)
assert get_search.status_code == 200, get_search.text
assert post_search.status_code == 200, post_search.text

# Without FieldsExtension
test_app = app.StacApi(
settings=ApiSettings(enable_response_models=validate),
client=BadCoreClient(),
search_get_request_model=create_get_request_model([FieldsExtension()]),
search_post_request_model=create_post_request_model([FieldsExtension()]),
search_get_request_model=create_get_request_model([]),
search_post_request_model=create_post_request_model([]),
extensions=[],
)

Expand All @@ -290,7 +310,10 @@ def item_collection(
},
},
)

if validate:
# NOTE: the `fields` options will be ignored by fastAPI because it's
# not part of the request model, so the client should not by-pass the validation
assert get_search.status_code == 500, (
get_search.json()["code"] == "ResponseValidationError"
)
Expand Down

0 comments on commit dbd0464

Please sign in to comment.