Skip to content

Commit

Permalink
fix test and modify logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Boris-vodka committed Sep 11, 2024
1 parent b6d6ff8 commit e42f80e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
12 changes: 7 additions & 5 deletions src/scholarag/app/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ async def set_cache(
body: Any,
settings: Settings,
request: Request,
redis: AsyncRedis[Any],
redis: "AsyncRedis[Any]",
request_key: str,
) -> None:
"""Format and set the response in Redis cache.
Expand All @@ -178,7 +178,6 @@ async def set_cache(
"headers": dict(endpoint_response.headers),
"media_type": endpoint_response.media_type,
}
response["headers"]["X-fastapi-cache"] = "Miss"

request_body = await request.body()
cached = {
Expand Down Expand Up @@ -298,7 +297,10 @@ async def get_and_set_cache(
else:
response = await call_next(request)
if request.headers.get("cache-control") not in ("no-cache", "no-store"):
# FastAPI uses 'StremingResponse' everywhere under the hood.
# moved from set_cache, easier for streaming.
response_header = {**response.headers, "X-fastapi-cache": "Miss"}

# FastAPI uses '_StremingResponse' everywhere under the hood.
if str(request.url).rpartition("/")[-1] == "streamed_generative":

async def stream_response() -> AsyncGenerator[bytes, None]:
Expand All @@ -318,7 +320,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
return StreamingResponse(
content=stream_response(),
status_code=response.status_code,
headers=response.headers,
headers=response_header,
media_type=response.media_type,
)
else:
Expand All @@ -337,7 +339,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
return Response(
content=body,
status_code=response.status_code,
headers=response.headers,
headers=response_header,
media_type=response.media_type,
)
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/app/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
select_relevant_settings,
strip_path_prefix,
)
from starlette.datastructures import MutableHeaders
from starlette.datastructures import URL, MutableHeaders
from starlette.status import HTTP_401_UNAUTHORIZED
from starlette.types import Message

Expand Down Expand Up @@ -214,7 +214,7 @@ async def test_get_and_set_cache_with_cache_key_not_in_db(response_body):
"headers": {},
},
)

request._url = URL("http://testserver/test")
body = """{"param": "This is request param"}""".encode("utf-8")

async def get_request_body():
Expand Down

0 comments on commit e42f80e

Please sign in to comment.