Skip to content

Commit

Permalink
Validate model filename as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
robinjhuang committed Aug 8, 2024
1 parent b464717 commit cf774c6
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 3 deletions.
2 changes: 1 addition & 1 deletion model_filemanager/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# model_manager/__init__.py
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_model_subdirectory, validate_filename
48 changes: 48 additions & 0 deletions model_filemanager/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
False
)

if not validate_filename(model_name):
return DownloadModelStatus(
DownloadStatusType.ERROR,
0,
"Invalid model name",
False
)

file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
existing_file = await check_file_exists(file_path, model_name, progress_callback, relative_path)
if existing_file:
Expand Down Expand Up @@ -99,6 +107,14 @@ def create_model_path(model_name: str, model_directory: str, models_base_dir: st
full_model_dir = os.path.join(models_base_dir, model_directory)
os.makedirs(full_model_dir, exist_ok=True)
file_path = os.path.join(full_model_dir, model_name)

# Ensure the resulting path is still within the base directory
abs_file_path = os.path.abspath(file_path)
abs_base_dir = os.path.abspath(str(models_base_dir))
if os.path.commonprefix([abs_file_path, abs_base_dir]) != abs_base_dir:
raise Exception(f"Invalid model directory: {model_directory}/{model_name}")


relative_path = '/'.join([model_directory, model_name])
return file_path, relative_path

Expand Down Expand Up @@ -187,3 +203,35 @@ def validate_model_subdirectory(model_subdirectory: str) -> bool:
return False

return True

def validate_filename(filename):
"""
Validate a filename to ensure it's safe and doesn't contain any path traversal attempts.
Args:
filename (str): The filename to validate
Returns:
bool: True if the filename is valid, False otherwise
"""
# Check if the filename is empty, None, or just whitespace
if not filename or not filename.strip():
return False

# Check for any directory traversal attempts or invalid characters
if any(char in filename for char in ['..', '/', '\\', '\n', '\r', '\t', '\0']):
return False

# Check if the filename starts with a dot (hidden file)
if filename.startswith('.'):
return False

# Use a whitelist of allowed characters
if not re.match(r'^[a-zA-Z0-9_\-. ]+$', filename):
return False

# Ensure the filename isn't too long
if len(filename) > 255:
return False

return True
26 changes: 24 additions & 2 deletions tests-unit/prompt_server_test/download_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
import os
from unittest.mock import AsyncMock, patch, MagicMock
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus
from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename

class AsyncIteratorMock:
"""
Expand Down Expand Up @@ -298,4 +298,26 @@ def test_subdirectory_with_underscore_and_dash():
assert validate_model_subdirectory("valid_model-name") is True

def test_empty_subdirectory():
assert validate_model_subdirectory("") is False
assert validate_model_subdirectory("") is False

@pytest.mark.parametrize("filename, expected", [
("valid_model.safetensors", True),
("valid_model.sft", True),
("another-valid_model.ckpt", True),
("valid model.safetensors", True), # Test with space
("UPPERCASE_MODEL.SAFETENSORS", True),
("model_with.multiple.dots.pt", True),
("", False), # Empty string
(None, False), # None value
("../../../etc/passwd", False), # Path traversal attempt
("/etc/passwd", False), # Absolute path
("\\windows\\system32\\config\\sam", False), # Windows path
(".hidden_file.pt", False), # Hidden file
("invalid<char>.ckpt", False), # Invalid character
("invalid?.ckpt", False), # Another invalid character
("very" * 100 + ".safetensors", False), # Too long filename
("\nmodel_with_newline.pt", False), # Newline character
("model_with_emoji😊.pt", False), # Emoji in filename
])
def test_validate_filename(filename, expected):
assert validate_filename(filename) == expected

0 comments on commit cf774c6

Please sign in to comment.