Skip to content

Commit

Permalink
Boost security
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong1120 committed Jul 27, 2024
1 parent e269fbc commit 277ab6a
Showing 1 changed file with 21 additions and 24 deletions.
45 changes: 21 additions & 24 deletions examples/YOLOv8_server_api/model_downloader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import datetime
import os
from pathlib import Path

import requests
from flask import Blueprint
from flask import current_app as app
from flask import jsonify
from flask import send_from_directory
from flask_limiter import Limiter
Expand All @@ -14,14 +15,14 @@
limiter = Limiter(key_func=get_remote_address)

# Define the directory where the model files are stored
MODELS_DIRECTORY = 'models/pt/'
MODELS_DIRECTORY = Path('models/pt/')
# Define the allowed models
ALLOWED_MODELS = {'yolov8n', 'yolov8s', 'yolov8m', 'yolov8l', 'yolov8x'}
ALLOWED_MODELS = {'best_yolov8l.pt', 'best_yolov8x.pt'}


@models_blueprint.route('/models/<model_name>', methods=['GET'])
@limiter.limit('10 per minute')
def download_model(model_name):
def download_model(model_name: str):
"""
Endpoint to download model files.
Expand All @@ -35,12 +36,10 @@ def download_model(model_name):
if model_name not in ALLOWED_MODELS:
return jsonify({'error': 'Model not found.'}), 404

file_name = f"best_{model_name}.pt"

try:
# Define the external URL for model files
MODEL_URL = (
f"http://changdar-server.mooo.com:28000/models/{file_name}"
f"http://changdar-server.mooo.com:28000/models/{model_name}"
)

# Check last modified time via a HEAD request to the external server
Expand All @@ -51,22 +50,23 @@ def download_model(model_name):
'%a, %d %b %Y %H:%M:%S GMT',
)

# Use os.path.join to safely construct the file path
local_file_path = os.path.join(
MODELS_DIRECTORY, file_name,
)
# Use Path to safely construct the file path
local_file_path = MODELS_DIRECTORY / model_name

# Ensure the constructed path is within the expected directory
common_path = os.path.commonpath(
[local_file_path, MODELS_DIRECTORY],
)
if common_path != MODELS_DIRECTORY:
try:
local_file_path = (
local_file_path.resolve().relative_to(
MODELS_DIRECTORY.resolve(),
)
)
except ValueError:
return jsonify({'error': 'Invalid model name.'}), 400

# Check local file's last modified time
if os.path.exists(local_file_path):
if local_file_path.exists():
local_last_modified = datetime.datetime.fromtimestamp(
os.path.getmtime(local_file_path),
local_file_path.stat().st_mtime,
)
if local_last_modified >= server_last_modified:
return jsonify(
Expand All @@ -89,10 +89,7 @@ def download_model(model_name):
},
), 404

except requests.RequestException as e:
return jsonify(
{
'error': 'Failed to fetch model information.',
'details': str(e),
},
), 500
except requests.RequestException:
# Log the exception details for debugging purposes
app.logger.error('Failed to fetch model information')
return jsonify({'error': 'Failed to fetch model information.'}), 500

0 comments on commit 277ab6a

Please sign in to comment.