Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace attr with dataclass + fastapi.Query() for GET models #714

Merged
merged 4 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions stac_fastapi/api/stac_fastapi/api/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Api request/response models."""

import importlib.util
from dataclasses import dataclass, make_dataclass
from typing import List, Optional, Type, Union

import attr
from fastapi import Path
from fastapi import Path, Query
from pydantic import BaseModel, create_model
from stac_pydantic.shared import BBox
from typing_extensions import Annotated

from stac_fastapi.types.extension import ApiExtension
from stac_fastapi.types.rfc3339 import DateTimeType
Expand Down Expand Up @@ -37,11 +38,11 @@ def create_request_model(

mixins = mixins or []

models = [base_model] + extension_models + mixins
models = extension_models + mixins + [base_model]
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved

# Handle GET requests
if all([issubclass(m, APIRequest) for m in models]):
return attr.make_class(model_name, attrs={}, bases=tuple(models))
return make_dataclass(model_name, [], bases=tuple(models))

# Handle POST requests
elif all([issubclass(m, BaseModel) for m in models]):
Expand Down Expand Up @@ -80,34 +81,43 @@ def create_post_request_model(
)


@attr.s # type:ignore
@dataclass
class CollectionUri(APIRequest):
"""Get or delete collection."""

collection_id: str = attr.ib(default=Path(..., description="Collection ID"))
collection_id: Annotated[str, Path(description="Collection ID")]


@attr.s
class ItemUri(CollectionUri):
@dataclass
class ItemUri(APIRequest):
"""Get or delete item."""

item_id: str = attr.ib(default=Path(..., description="Item ID"))
collection_id: Annotated[str, Path(description="Collection ID")]
item_id: Annotated[str, Path(description="Item ID")]


@attr.s
@dataclass
class EmptyRequest(APIRequest):
"""Empty request."""

...


@attr.s
class ItemCollectionUri(CollectionUri):
@dataclass
class ItemCollectionUri(APIRequest):
"""Get item collection."""

limit: int = attr.ib(default=10)
bbox: Optional[BBox] = attr.ib(default=None, converter=str2bbox)
datetime: Optional[DateTimeType] = attr.ib(default=None, converter=str_to_interval)
collection_id: Annotated[str, Path(description="Collection ID")]
limit: Annotated[int, Query()] = 10
bbox: Annotated[Optional[BBox], Query()] = None
datetime: Annotated[Optional[DateTimeType], Query()] = None

def __post_init__(self):
"""convert attributes."""
if self.bbox:
self.bbox = str2bbox(self.bbox) # type: ignore
if self.datetime:
self.datetime = str_to_interval(self.datetime) # type: ignore


class POSTTokenPagination(BaseModel):
Expand All @@ -116,11 +126,11 @@ class POSTTokenPagination(BaseModel):
token: Optional[str] = None


@attr.s
@dataclass
class GETTokenPagination(APIRequest):
"""Token pagination for GET requests."""

token: Optional[str] = attr.ib(default=None)
token: Annotated[Optional[str], Query()] = None


class POSTPagination(BaseModel):
Expand All @@ -129,11 +139,11 @@ class POSTPagination(BaseModel):
page: Optional[str] = None


@attr.s
@dataclass
class GETPagination(APIRequest):
"""Page based pagination for GET requests."""

page: Optional[str] = attr.ib(default=None)
page: Annotated[Optional[str], Query()] = None


# Test for ORJSON and use it rather than stdlib JSON where supported
Expand Down
30 changes: 26 additions & 4 deletions stac_fastapi/api/tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json

import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from pydantic import ValidationError

from stac_fastapi.api.models import create_get_request_model, create_post_request_model
Expand All @@ -26,13 +28,33 @@ def test_create_get_request_model():
datetime="2020-01-01T00:00:00Z",
limit=10,
filter="test==test",
# FIXME: https://github.com/stac-utils/stac-fastapi/issues/638
# hyphen aliases are not properly working
# **{"filter-crs": "epsg:4326", "filter-lang": "cql2-text"},
filter_crs="epsg:4326",
filter_lang="cql2-text",
)

assert model.collections == ["test1", "test2"]
# assert model.filter_crs == "epsg:4326"
assert model.filter_crs == "epsg:4326"

app = FastAPI()

@app.get("/test")
def route(model=Depends(request_model)):
return model

with TestClient(app) as client:
resp = client.get(
"/test",
params={
"collections": "test1,test2",
"filter-crs": "epsg:4326",
"filter-lang": "cql2-text",
},
)
assert resp.status_code == 200
response_dict = resp.json()
assert response_dict["collections"] == ["test1", "test2"]
assert response_dict["filter_crs"] == "epsg:4326"
assert response_dict["filter_lang"] == "cql2-text"


@pytest.mark.parametrize(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
"""Request model for the Aggregation extension."""

from dataclasses import dataclass
from typing import List, Optional, Union

import attr
from fastapi import Query
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from stac_fastapi.extensions.core.filter.request import (
FilterExtensionGetRequest,
FilterExtensionPostRequest,
)
from stac_fastapi.types.search import BaseSearchGetRequest, BaseSearchPostRequest
from stac_fastapi.types.search import APIRequest


@attr.s
class AggregationExtensionGetRequest(BaseSearchGetRequest, FilterExtensionGetRequest):
@dataclass
class AggregationExtensionGetRequest(APIRequest):
"""Aggregation Extension GET request model."""

aggregations: Optional[str] = attr.ib(default=None)
aggregations: Annotated[Optional[str], Query()] = None


class AggregationExtensionPostRequest(BaseSearchPostRequest, FilterExtensionPostRequest):
class AggregationExtensionPostRequest(BaseModel):
"""Aggregation Extension POST request model."""

aggregations: Optional[Union[str, List[str]]] = attr.ib(default=None)
aggregations: Optional[Union[str, List[str]]] = Field(default=None)
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Request models for the fields extension."""

import warnings
from dataclasses import dataclass
from typing import Dict, Optional, Set

import attr
from fastapi import Query
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from stac_fastapi.types.search import APIRequest, str2list

Expand Down Expand Up @@ -68,11 +70,16 @@ def filter_fields(self) -> Dict:
}


@attr.s
@dataclass
class FieldsExtensionGetRequest(APIRequest):
"""Additional fields for the GET request."""

fields: Optional[str] = attr.ib(default=None, converter=str2list)
fields: Annotated[Optional[str], Query()] = None

def __post_init__(self):
"""convert attributes."""
if self.fields:
self.fields = str2list(self.fields) # type: ignore


class FieldsExtensionPostRequest(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
"""Filter extension request models."""

from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional

import attr
from fastapi import Query
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from stac_fastapi.types.search import APIRequest

FilterLang = Literal["cql-json", "cql2-json", "cql2-text"]


@attr.s
@dataclass
class FilterExtensionGetRequest(APIRequest):
"""Filter extension GET request model."""

filter: Optional[str] = attr.ib(default=None)
filter_crs: Optional[str] = Field(alias="filter-crs", default=None)
filter_lang: Optional[FilterLang] = Field(alias="filter-lang", default="cql2-text")
filter: Annotated[Optional[str], Query()] = None
filter_crs: Annotated[Optional[str], Query(alias="filter-crs")] = None
filter_lang: Annotated[Optional[FilterLang], Query(alias="filter-lang")] = "cql2-text"


class FilterExtensionPostRequest(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
"""Request model for the Query extension."""

from dataclasses import dataclass
from typing import Any, Dict, Optional

import attr
from fastapi import Query
from pydantic import BaseModel
from typing_extensions import Annotated

from stac_fastapi.types.search import APIRequest


@attr.s
@dataclass
class QueryExtensionGetRequest(APIRequest):
"""Query Extension GET request model."""

query: Optional[str] = attr.ib(default=None)
query: Annotated[Optional[str], Query()] = None


class QueryExtensionPostRequest(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
# encoding: utf-8
"""Request model for the Sort Extension."""

from dataclasses import dataclass
from typing import List, Optional

import attr
from fastapi import Query
from pydantic import BaseModel
from stac_pydantic.api.extensions.sort import SortExtension as PostSortModel
from typing_extensions import Annotated

from stac_fastapi.types.search import APIRequest, str2list


@attr.s
@dataclass
class SortExtensionGetRequest(APIRequest):
"""Sortby Parameter for GET requests."""

sortby: Optional[str] = attr.ib(default=None, converter=str2list)
sortby: Annotated[Optional[str], Query()] = None

def __post_init__(self):
"""convert attributes."""
if self.sortby:
self.sortby = str2list(self.sortby) # type: ignore


class SortExtensionPostRequest(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Transaction extension."""

from dataclasses import dataclass
from typing import List, Optional, Type, Union

import attr
from fastapi import APIRouter, Body, FastAPI
from stac_pydantic import Collection, Item, ItemCollection
from stac_pydantic.shared import MimeTypes
from starlette.responses import JSONResponse, Response
from typing_extensions import Annotated

from stac_fastapi.api.models import CollectionUri, ItemUri
from stac_fastapi.api.routes import create_async_endpoint
Expand All @@ -15,25 +17,25 @@
from stac_fastapi.types.extension import ApiExtension


@attr.s
@dataclass
class PostItem(CollectionUri):
"""Create Item."""

item: Union[Item, ItemCollection] = attr.ib(default=Body(None))
item: Annotated[Union[Item, ItemCollection], Body()] = None


@attr.s
@dataclass
class PutItem(ItemUri):
"""Update Item."""

item: Item = attr.ib(default=Body(None))
item: Annotated[Item, Body()] = None


@attr.s
@dataclass
class PutCollection(CollectionUri):
"""Update Collection."""

collection: Collection = attr.ib(default=Body(None))
collection: Annotated[Collection, Body()] = None


@attr.s
Expand Down
Loading
Loading