Skip to content

Commit

Permalink
Merge branch 'main' into vedat/fea-fal-file-no-scale
Browse files Browse the repository at this point in the history
  • Loading branch information
badayvedat authored Nov 13, 2024
2 parents 882171d + 0c8034e commit a4b480f
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions projects/fal/src/fal/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Literal, TypeVar

import fastapi
import grpc.aio as async_grpc
import httpx
from fastapi import FastAPI
from isolate.server import definitions

import fal.api
Expand Down Expand Up @@ -304,7 +304,7 @@ def collect_routes(self) -> dict[RouteSignature, Callable[..., Any]]:
}

@asynccontextmanager
async def lifespan(self, app: FastAPI):
async def lifespan(self, app: fastapi.FastAPI):
await _call_any_fn(self.setup)
try:
yield
Expand All @@ -320,7 +320,7 @@ def setup(self):
def teardown(self):
"""Teardown the application after serving."""

def _add_extra_middlewares(self, app: FastAPI):
def _add_extra_middlewares(self, app: fastapi.FastAPI):
@app.middleware("http")
async def provide_hints_headers(request, call_next):
response = await call_next(request)
Expand Down Expand Up @@ -361,21 +361,42 @@ async def set_global_object_preference(request, call_next):

@app.middleware("http")
async def set_request_id(request, call_next):
# NOTE: Setting request_id is not supported for websocket/realtime endpoints

if self.isolate_channel is None:
grpc_port = os.environ.get("NOMAD_ALLOC_PORT_grpc")
self.isolate_channel = await open_isolate_channel(
f"localhost:{grpc_port}"
)

request_id = request.headers.get(REQUEST_ID_KEY)
if request_id is not None:
await _set_logger_labels(
{"fal_request_id": request_id}, channel=self.isolate_channel
)
try:
if request_id is None:
# Cut it short
return await call_next(request)
finally:
await _set_logger_labels({}, channel=self.isolate_channel)

await _set_logger_labels(
{"fal_request_id": request_id}, channel=self.isolate_channel
)

async def _unset_at_end():
await _set_logger_labels({}, channel=self.isolate_channel) # type: ignore

try:
response: fastapi.responses.Response = await call_next(request)
except BaseException:
await _unset_at_end()
raise
else:
# We need to wait for the entire response to be sent before
# we can set the logger labels back to the default.
background_tasks = fastapi.BackgroundTasks()
background_tasks.add_task(_unset_at_end)
if response.background:
# We normally have no background tasks, but we should handle it
background_tasks.add_task(response.background)
response.background = background_tasks

return response

@app.exception_handler(RequestCancelledException)
async def value_error_exception_handler(
Expand All @@ -388,7 +409,7 @@ async def value_error_exception_handler(
# the connection without receiving a response
return JSONResponse({"detail": str(exc)}, 499)

def _add_extra_routes(self, app: FastAPI):
def _add_extra_routes(self, app: fastapi.FastAPI):
@app.get("/health")
def health():
return self.health()
Expand Down

0 comments on commit a4b480f

Please sign in to comment.