From 277ab6a6a12baf42f3eeaf753a7ac4a1d153211b Mon Sep 17 00:00:00 2001 From: yihong1120 Date: Sat, 27 Jul 2024 14:59:02 +0800 Subject: [PATCH] Boost security --- .../YOLOv8_server_api/model_downloader.py | 45 +++++++++---------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/examples/YOLOv8_server_api/model_downloader.py b/examples/YOLOv8_server_api/model_downloader.py index 1861e04..f4767c5 100644 --- a/examples/YOLOv8_server_api/model_downloader.py +++ b/examples/YOLOv8_server_api/model_downloader.py @@ -1,10 +1,11 @@ from __future__ import annotations import datetime -import os +from pathlib import Path import requests from flask import Blueprint +from flask import current_app as app from flask import jsonify from flask import send_from_directory from flask_limiter import Limiter @@ -14,14 +15,14 @@ limiter = Limiter(key_func=get_remote_address) # Define the directory where the model files are stored -MODELS_DIRECTORY = 'models/pt/' +MODELS_DIRECTORY = Path('models/pt/') # Define the allowed models -ALLOWED_MODELS = {'yolov8n', 'yolov8s', 'yolov8m', 'yolov8l', 'yolov8x'} +ALLOWED_MODELS = {'best_yolov8l.pt', 'best_yolov8x.pt'} @models_blueprint.route('/models/', methods=['GET']) @limiter.limit('10 per minute') -def download_model(model_name): +def download_model(model_name: str): """ Endpoint to download model files. @@ -35,12 +36,10 @@ def download_model(model_name): if model_name not in ALLOWED_MODELS: return jsonify({'error': 'Model not found.'}), 404 - file_name = f"best_{model_name}.pt" - try: # Define the external URL for model files MODEL_URL = ( - f"http://changdar-server.mooo.com:28000/models/{file_name}" + f"http://changdar-server.mooo.com:28000/models/{model_name}" ) # Check last modified time via a HEAD request to the external server @@ -51,22 +50,23 @@ def download_model(model_name): '%a, %d %b %Y %H:%M:%S GMT', ) - # Use os.path.join to safely construct the file path - local_file_path = os.path.join( - MODELS_DIRECTORY, file_name, - ) + # Use Path to safely construct the file path + local_file_path = MODELS_DIRECTORY / model_name # Ensure the constructed path is within the expected directory - common_path = os.path.commonpath( - [local_file_path, MODELS_DIRECTORY], - ) - if common_path != MODELS_DIRECTORY: + try: + local_file_path = ( + local_file_path.resolve().relative_to( + MODELS_DIRECTORY.resolve(), + ) + ) + except ValueError: return jsonify({'error': 'Invalid model name.'}), 400 # Check local file's last modified time - if os.path.exists(local_file_path): + if local_file_path.exists(): local_last_modified = datetime.datetime.fromtimestamp( - os.path.getmtime(local_file_path), + local_file_path.stat().st_mtime, ) if local_last_modified >= server_last_modified: return jsonify( @@ -89,10 +89,7 @@ def download_model(model_name): }, ), 404 - except requests.RequestException as e: - return jsonify( - { - 'error': 'Failed to fetch model information.', - 'details': str(e), - }, - ), 500 + except requests.RequestException: + # Log the exception details for debugging purposes + app.logger.error('Failed to fetch model information') + return jsonify({'error': 'Failed to fetch model information.'}), 500