-
-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Switch Flask to Fastapi in YOLO server api
- Loading branch information
1 parent
9752df4
commit 7c54585
Showing
18 changed files
with
1,025 additions
and
790 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
from __future__ import annotations | ||
|
||
from contextlib import asynccontextmanager | ||
from typing import AsyncGenerator | ||
|
||
import redis.asyncio as redis | ||
import socketio | ||
from apscheduler.schedulers.background import BackgroundScheduler | ||
from fastapi import FastAPI | ||
from fastapi_jwt import JwtAccessBearer | ||
from fastapi_limiter import FastAPILimiter | ||
|
||
from .auth import auth_router | ||
from .config import Settings | ||
from .detection import detection_router | ||
from .model_downloader import models_router | ||
from .models import Base, engine | ||
from .security import update_secret_key | ||
|
||
# Instantiate the FastAPI app | ||
app = FastAPI() | ||
|
||
# Initialise database tables | ||
Base.metadata.create_all(bind=engine) | ||
|
||
# Register API routers for different functionalities | ||
app.include_router(auth_router) | ||
app.include_router(detection_router) | ||
app.include_router(models_router) | ||
|
||
# Set up JWT authentication | ||
jwt_access = JwtAccessBearer(secret_key=Settings().authjwt_secret_key) | ||
|
||
# Configure background scheduler to refresh secret keys every 30 days | ||
scheduler = BackgroundScheduler() | ||
scheduler.add_job( | ||
func=lambda: update_secret_key(app), | ||
trigger='interval', | ||
days=30, | ||
) | ||
scheduler.start() | ||
|
||
# Initialise Socket.IO server for real-time events | ||
sio = socketio.AsyncServer(async_mode='asgi') | ||
sio_app = socketio.ASGIApp(sio, app) | ||
|
||
|
||
# Define Socket.IO events | ||
@sio.event | ||
async def connect(sid: str, environ: dict) -> None: | ||
""" | ||
Handles client connection event to the Socket.IO server. | ||
Args: | ||
sid (str): The session ID of the connected client. | ||
environ (dict): The environment dictionary for the connection. | ||
""" | ||
print('Client connected:', sid) | ||
|
||
|
||
@sio.event | ||
async def disconnect(sid: str) -> None: | ||
""" | ||
Handles client disconnection from the Socket.IO server. | ||
Args: | ||
sid (str): The session ID of the disconnected client. | ||
""" | ||
print('Client disconnected:', sid) | ||
|
||
|
||
# Define lifespan event to manage the application startup and shutdown | ||
@asynccontextmanager | ||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: | ||
""" | ||
Context manager to handle application startup and shutdown tasks. | ||
Args: | ||
app (FastAPI): The FastAPI application instance. | ||
""" | ||
# Initialise Redis connection pool for rate limiting and store it in app.state | ||
app.state.redis_pool = await redis.from_url( | ||
'redis://localhost:6379', | ||
encoding='utf-8', | ||
decode_responses=True, | ||
password='_sua6oub4Ss', | ||
) | ||
await FastAPILimiter.init(app.state.redis_pool) | ||
|
||
# Yield control to allow application operation | ||
yield | ||
|
||
# Shutdown the scheduler and Redis connection pool upon application termination | ||
scheduler.shutdown() | ||
await app.state.redis_pool.close() | ||
|
||
# Assign lifespan context to the FastAPI app | ||
app.router.lifespan_context = lifespan | ||
|
||
|
||
# Main entry point for running the app | ||
if __name__ == '__main__': | ||
import uvicorn | ||
uvicorn.run(sio_app, host='0.0.0.0', port=5000) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from __future__ import annotations | ||
|
||
from fastapi import APIRouter | ||
from fastapi import Depends | ||
from fastapi import HTTPException | ||
from fastapi_jwt import JwtAccessBearer | ||
from pydantic import BaseModel | ||
from sqlalchemy.orm import Session | ||
|
||
from .cache import user_cache | ||
from .config import Settings | ||
from .models import get_db | ||
from .models import User | ||
|
||
auth_router = APIRouter() | ||
jwt_access = JwtAccessBearer(secret_key=Settings().authjwt_secret_key) | ||
|
||
|
||
class UserLogin(BaseModel): | ||
username: str | ||
password: str | ||
|
||
|
||
@auth_router.post('/token') | ||
def create_token(user: UserLogin, db: Session = Depends(get_db)): | ||
print(f"db_user.__dict__ = {user.__dict__}") | ||
db_user = user_cache.get(user.username) | ||
print(db_user) | ||
if not db_user: | ||
db_user = db.query(User).filter(User.username == user.username).first() | ||
if db_user: | ||
user_cache[user.username] = db_user | ||
|
||
if not db_user or not db_user.check_password(user.password): | ||
raise HTTPException( | ||
status_code=401, detail='Wrong username or password', | ||
) | ||
|
||
if not db_user.is_active: | ||
raise HTTPException( | ||
status_code=403, detail='User account is inactive', | ||
) | ||
|
||
if db_user.role not in ['admin', 'model_manager', 'user', 'guest']: | ||
raise HTTPException( | ||
status_code=403, detail='User does not have the required role') | ||
|
||
# access_token = jwt_access.create_access_token(subject=user.username) | ||
access_token = jwt_access.create_access_token( | ||
subject={'username': user.username, 'role': db_user.role}, | ||
) | ||
|
||
return {'access_token': access_token} |
2 changes: 1 addition & 1 deletion
2
examples/YOLO_server_api/cache.py → examples/YOLO_server/cache.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Initialise a simple cache to store user data. | ||
from __future__ import annotations | ||
|
||
user_cache: dict = {} | ||
user_cache: dict = {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
|
||
from dotenv import load_dotenv | ||
from pydantic_settings import BaseSettings | ||
|
||
# Load environment variables from a .env file | ||
load_dotenv() | ||
|
||
|
||
class Settings(BaseSettings): | ||
""" | ||
A class to represent the application settings. | ||
Attributes: | ||
authjwt_secret_key (str): The secret key for JWT authentication. | ||
sqlalchemy_database_uri (str): The URI for the SQLAlchemy database | ||
connection. | ||
sqlalchemy_track_modifications (bool): Flag to track modifications in | ||
SQLAlchemy. | ||
""" | ||
|
||
authjwt_secret_key: str = os.getenv( | ||
'JWT_SECRET_KEY', | ||
'your_fallback_secret_key', | ||
) | ||
sqlalchemy_database_uri: str = os.getenv( | ||
'DATABASE_URL', | ||
'mysql://user:password@localhost/dbname', | ||
) | ||
sqlalchemy_track_modifications: bool = False | ||
|
||
def __init__(self) -> None: | ||
""" | ||
Initialise the Settings instance with environment variables. | ||
If the environment variables are not set, fallback values will be used. | ||
""" | ||
super().__init__() # Ensure the BaseSettings initialisation is called |
Oops, something went wrong.