From cf774c60b63d25d15a2d7c624df77a038bbc30e4 Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Thu, 8 Aug 2024 16:18:36 -0700 Subject: [PATCH] Validate model filename as well. --- model_filemanager/__init__.py | 2 +- model_filemanager/download_models.py | 48 +++++++++++++++++++ .../download_models_test.py | 26 +++++++++- 3 files changed, 73 insertions(+), 3 deletions(-) diff --git a/model_filemanager/__init__.py b/model_filemanager/__init__.py index c35e26374d2..e318351c051 100644 --- a/model_filemanager/__init__.py +++ b/model_filemanager/__init__.py @@ -1,2 +1,2 @@ # model_manager/__init__.py -from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory +from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index e802b8a9575..621cbeae64a 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -71,6 +71,14 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht False ) + if not validate_filename(model_name): + return DownloadModelStatus( + DownloadStatusType.ERROR, + 0, + "Invalid model name", + False + ) + file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir) existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path) if existing_file: @@ -99,6 +107,14 @@ def create_model_path(model_name: str, model_directory: str, models_base_dir: st full_model_dir = os.path.join(models_base_dir, model_directory) os.makedirs(full_model_dir, exist_ok=True) file_path = os.path.join(full_model_dir, model_name) + + # Ensure the resulting path is still within the base directory + abs_file_path = os.path.abspath(file_path) + abs_base_dir = os.path.abspath(str(models_base_dir)) + if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir: + raise Exception(f"Invalid model directory: {model_directory}/{model_name}") + + relative_path = '/'.join([model_directory, model_name]) return file_path, relative_path @@ -187,3 +203,35 @@ def validate_model_subdirectory(model_subdirectory: str) -> bool: return False return True + +def validate_filename(filename): + """ + Validate a filename to ensure it's safe and doesn't contain any path traversal attempts. + + Args: + filename (str): The filename to validate + + Returns: + bool: True if the filename is valid, False otherwise + """ + # Check if the filename is empty, None, or just whitespace + if not filename or not filename.strip(): + return False + + # Check for any directory traversal attempts or invalid characters + if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']): + return False + + # Check if the filename starts with a dot (hidden file) + if filename.startswith('.'): + return False + + # Use a whitelist of allowed characters + if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename): + return False + + # Ensure the filename isn't too long + if len(filename) > 255: + return False + + return True diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index 26dd94d4cce..09d8fdcb4f7 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -4,7 +4,7 @@ import itertools import os from unittest.mock import AsyncMock, patch, MagicMock -from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus +from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename class AsyncIteratorMock: """ @@ -298,4 +298,26 @@ def test_subdirectory_with_underscore_and_dash(): assert validate_model_subdirectory("valid_model-name") is True def test_empty_subdirectory(): - assert validate_model_subdirectory("") is False \ No newline at end of file + assert validate_model_subdirectory("") is False + +@pytest.mark.parametrize("filename, expected", [ + ("valid_model.safetensors", True), + ("valid_model.sft", True), + ("another-valid_model.ckpt", True), + ("valid model.safetensors", True), # Test with space + ("UPPERCASE_MODEL.SAFETENSORS", True), + ("model_with.multiple.dots.pt", True), + ("", False), # Empty string + (None, False), # None value + ("../../../etc/passwd", False), # Path traversal attempt + ("/etc/passwd", False), # Absolute path + ("\\windows\\system32\\config\\sam", False), # Windows path + (".hidden_file.pt", False), # Hidden file + ("invalid.ckpt", False), # Invalid character + ("invalid?.ckpt", False), # Another invalid character + ("very" * 100 + ".safetensors", False), # Too long filename + ("\nmodel_with_newline.pt", False), # Newline character + ("model_with_emoji😊.pt", False), # Emoji in filename +]) +def test_validate_filename(filename, expected): + assert validate_filename(filename) == expected