diff --git a/CHANGES.md b/CHANGES.md index ea020992..011a7fde 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -9,6 +9,7 @@ ### Removed * Removed the Filter Extension depenency from `AggregationExtensionPostRequest` and `AggregationExtensionGetRequest` [#716](https://github.com/stac-utils/stac-fastapi/pull/716) +* Removed `add_middleware` method in `StacApi` object and let starlette handle the middleware stack creation [721](https://github.com/stac-utils/stac-fastapi/pull/721) ## [3.0.0a3] - 2024-06-13 diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index 5fe7f9d0..b4f5125f 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -439,11 +439,6 @@ def add_route_dependencies( """ return add_route_dependencies(self.app.router.routes, scopes, dependencies) - def add_middleware(self, middleware: Middleware): - """Add a middleware class to the application.""" - self.app.user_middleware.insert(0, middleware) - self.app.middleware_stack = self.app.build_middleware_stack() - def __attrs_post_init__(self): """Post-init hook. @@ -484,7 +479,7 @@ def __attrs_post_init__(self): # add middlewares for middleware in self.middlewares: - self.add_middleware(middleware) + self.app.user_middleware.insert(0, middleware) # customize route dependencies for scopes, dependencies in self.route_dependencies: diff --git a/stac_fastapi/api/tests/test_middleware.py b/stac_fastapi/api/tests/test_middleware.py index 041dc410..00e7f803 100644 --- a/stac_fastapi/api/tests/test_middleware.py +++ b/stac_fastapi/api/tests/test_middleware.py @@ -1,6 +1,8 @@ from unittest import mock import pytest +from fastapi import Request +from fastapi.responses import JSONResponse from starlette.applications import Starlette from starlette.testclient import TestClient @@ -166,3 +168,31 @@ def test_cors_middleware(test_client): resp = test_client.get("/_mgmt/ping", headers={"Origin": "http://netloc"}) assert resp.status_code == 200 assert resp.headers["access-control-allow-origin"] == "*" + + +def test_middleware_stack(): + stac_api = StacApi( + settings=ApiSettings(), client=mock.create_autospec(BaseCoreClient) + ) + + def exception_handler(request: Request, exc: Exception) -> JSONResponse: + return JSONResponse( + status_code=400, + content={"customerrordetail": "yoo", "body": "yo"}, + ) + + class CustomException(Exception): + "Custom Exception" + + pass + + stac_api.app.add_exception_handler(CustomException, exception_handler) + + @stac_api.app.get("/error") + def error_endpoint(): + raise CustomException("got you!") + + with TestClient(stac_api.app) as client: + resp = client.get("/error") + assert resp.status_code == 400 + assert resp.json()["customerrordetail"] == "yoo"