Skip to content

Commit

Permalink
Switch Flask to Fastapi in YOLO server api
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong1120 committed Nov 7, 2024
1 parent 9752df4 commit 96bd833
Show file tree
Hide file tree
Showing 18 changed files with 1,025 additions and 790 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
或者

```sh
gunicorn -w 1 -b 0.0.0.0:8000 "examples.object_detection__server_api.app:app"
uvicorn examples.YOLO_server.app:sio_app --host 127.0.0.1 --port 8000
```

2. **向 API 發送請求:**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ This section an example implementation of a YOLO Server API, designed to facilit
or

```sh
gunicorn -w 1 -b 0.0.0.0:8000 "examples.object_detection_server_api.app:app"
uvicorn examples.YOLO_server.app:sio_app --host 127.0.0.1 --port 8000
```

4. **Send a request to the API:**
Expand Down
104 changes: 104 additions & 0 deletions examples/YOLO_server/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from __future__ import annotations

from contextlib import asynccontextmanager
from typing import AsyncGenerator

import redis.asyncio as redis
import socketio
from apscheduler.schedulers.background import BackgroundScheduler
from fastapi import FastAPI
from fastapi_jwt import JwtAccessBearer
from fastapi_limiter import FastAPILimiter

from .auth import auth_router
from .config import Settings
from .detection import detection_router
from .model_downloader import models_router
from .models import Base, engine
from .security import update_secret_key

# Instantiate the FastAPI app
app = FastAPI()

# Initialise database tables
Base.metadata.create_all(bind=engine)

# Register API routers for different functionalities
app.include_router(auth_router)
app.include_router(detection_router)
app.include_router(models_router)

# Set up JWT authentication
jwt_access = JwtAccessBearer(secret_key=Settings().authjwt_secret_key)

# Configure background scheduler to refresh secret keys every 30 days
scheduler = BackgroundScheduler()
scheduler.add_job(
func=lambda: update_secret_key(app),
trigger='interval',
days=30,
)
scheduler.start()

# Initialise Socket.IO server for real-time events
sio = socketio.AsyncServer(async_mode='asgi')
sio_app = socketio.ASGIApp(sio, app)


# Define Socket.IO events
@sio.event
async def connect(sid: str, environ: dict) -> None:
"""
Handles client connection event to the Socket.IO server.
Args:
sid (str): The session ID of the connected client.
environ (dict): The environment dictionary for the connection.
"""
print('Client connected:', sid)


@sio.event
async def disconnect(sid: str) -> None:
"""
Handles client disconnection from the Socket.IO server.
Args:
sid (str): The session ID of the disconnected client.
"""
print('Client disconnected:', sid)


# Define lifespan event to manage the application startup and shutdown
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""
Context manager to handle application startup and shutdown tasks.
Args:
app (FastAPI): The FastAPI application instance.
"""
# Initialise Redis connection pool for rate limiting and store it in app.state
app.state.redis_pool = await redis.from_url(
'redis://localhost:6379',
encoding='utf-8',
decode_responses=True,
password='_sua6oub4Ss',
)
await FastAPILimiter.init(app.state.redis_pool)

# Yield control to allow application operation
yield

# Shutdown the scheduler and Redis connection pool upon application termination
scheduler.shutdown()
await app.state.redis_pool.close()

# Assign lifespan context to the FastAPI app
app.router.lifespan_context = lifespan


# Main entry point for running the app
if __name__ == '__main__':
import uvicorn
uvicorn.run(sio_app, host='0.0.0.0', port=5000)
53 changes: 53 additions & 0 deletions examples/YOLO_server/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi_jwt import JwtAccessBearer
from pydantic import BaseModel
from sqlalchemy.orm import Session

from .cache import user_cache
from .config import Settings
from .models import get_db
from .models import User

auth_router = APIRouter()
jwt_access = JwtAccessBearer(secret_key=Settings().authjwt_secret_key)


class UserLogin(BaseModel):
username: str
password: str


@auth_router.post('/token')
def create_token(user: UserLogin, db: Session = Depends(get_db)):
print(f"db_user.__dict__ = {user.__dict__}")
db_user = user_cache.get(user.username)
print(db_user)
if not db_user:
db_user = db.query(User).filter(User.username == user.username).first()
if db_user:
user_cache[user.username] = db_user

if not db_user or not db_user.check_password(user.password):
raise HTTPException(
status_code=401, detail='Wrong username or password',
)

if not db_user.is_active:
raise HTTPException(
status_code=403, detail='User account is inactive',
)

if db_user.role not in ['admin', 'model_manager', 'user', 'guest']:
raise HTTPException(
status_code=403, detail='User does not have the required role')

# access_token = jwt_access.create_access_token(subject=user.username)
access_token = jwt_access.create_access_token(
subject={'username': user.username, 'role': db_user.role},
)

return {'access_token': access_token}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Initialise a simple cache to store user data.
from __future__ import annotations

user_cache: dict = {}
user_cache: dict = {}
40 changes: 40 additions & 0 deletions examples/YOLO_server/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

import os

from dotenv import load_dotenv
from pydantic_settings import BaseSettings

# Load environment variables from a .env file
load_dotenv()


class Settings(BaseSettings):
"""
A class to represent the application settings.
Attributes:
authjwt_secret_key (str): The secret key for JWT authentication.
sqlalchemy_database_uri (str): The URI for the SQLAlchemy database
connection.
sqlalchemy_track_modifications (bool): Flag to track modifications in
SQLAlchemy.
"""

authjwt_secret_key: str = os.getenv(
'JWT_SECRET_KEY',
'your_fallback_secret_key',
)
sqlalchemy_database_uri: str = os.getenv(
'DATABASE_URL',
'mysql://user:password@localhost/dbname',
)
sqlalchemy_track_modifications: bool = False

def __init__(self) -> None:
"""
Initialise the Settings instance with environment variables.
If the environment variables are not set, fallback values will be used.
"""
super().__init__() # Ensure the BaseSettings initialisation is called
Loading

0 comments on commit 96bd833

Please sign in to comment.