diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 45f5a9c..2c8d79b 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -17,16 +17,13 @@ create_request_model, ) from stac_fastapi.extensions.core import ( - CollectionSearchExtension, FieldsExtension, FilterExtension, SortExtension, TokenPaginationExtension, TransactionExtension, ) -from stac_fastapi.extensions.core.collection_search.request import ( - BaseCollectionSearchGetRequest, -) +from stac_fastapi.extensions.core.collection_search import CollectionSearchExtension from stac_fastapi.extensions.third_party import BulkTransactionExtension from stac_fastapi.pgstac.config import Settings @@ -52,10 +49,6 @@ "bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()), } -collections_extensions_map = { - "collection_search": CollectionSearchExtension(), -} - if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"): _enabled_extensions = enabled_extensions.split(",") extensions = [ @@ -63,14 +56,9 @@ for key, extension in extensions_map.items() if key in _enabled_extensions ] - collection_extensions = [ - extension - for key, extension in collections_extensions_map.items() - if key in _enabled_extensions - ] else: + _enabled_extensions = list(extensions_map.keys()) + ["collection_search"] extensions = list(extensions_map.values()) - collection_extensions = list(collections_extensions_map.values()) if any(isinstance(ext, TokenPaginationExtension) for ext in extensions): @@ -83,13 +71,10 @@ else: items_get_request_model = ItemCollectionUri -if any(isinstance(ext, CollectionSearchExtension) for ext in collection_extensions): - collections_get_request_model = create_request_model( - model_name="CollectionsGetRequest", - base_model=BaseCollectionSearchGetRequest, - extensions=extensions, - request_type="GET", - ) +if "collection_search" in _enabled_extensions: + collection_extension = CollectionSearchExtension() + collections_get_request_model = collection_extension.GET + extensions.append(collection_extension) else: collections_get_request_model = EmptyRequest @@ -98,7 +83,7 @@ api = StacApi( settings=settings, - extensions=extensions + collection_extensions, + extensions=extensions, client=CoreCrudClient(post_request_model=post_request_model), # type: ignore response_class=ORJSONResponse, items_get_request_model=items_get_request_model,