Skip to content

Commit

Permalink
Use repo owner/name to replace hard coded provider list
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei committed Jul 15, 2024
1 parent 3a5afe3 commit 6c50d83
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 82 deletions.
92 changes: 51 additions & 41 deletions app/frontend_management.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import argparse
import logging
import os
import re
import tempfile
import zipfile
import logging
from functools import cached_property
from typing import TypedDict
from dataclasses import dataclass
from typing_extensions import NotRequired
from functools import cached_property
from pathlib import Path
from typing import TypedDict

import requests
from typing_extensions import NotRequired


REQUEST_TIMEOUT = 10 # seconds


class Asset(TypedDict):
Expand All @@ -30,10 +33,13 @@ class Release(TypedDict):

@dataclass
class FrontEndProvider:
name: str
owner: str
repo: str

@property
def folder_name(self) -> str:
return f"{self.owner}_{self.repo}"

@property
def release_url(self) -> str:
return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"
Expand All @@ -43,7 +49,7 @@ def all_releases(self) -> list[Release]:
releases = []
api_url = self.release_url
while api_url:
response = requests.get(api_url)
response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
response.raise_for_status() # Raises an HTTPError if the response was an error
releases.extend(response.json())
# GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
Expand All @@ -56,7 +62,7 @@ def all_releases(self) -> list[Release]:
@cached_property
def latest_release(self) -> Release:
latest_release_url = f"{self.release_url}/latest"
response = requests.get(latest_release_url)
response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
response.raise_for_status() # Raises an HTTPError if the response was an error
return response.json()

Expand Down Expand Up @@ -84,7 +90,9 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
# Use a temporary file to download the zip content
with tempfile.TemporaryFile() as tmp_file:
headers = {"Accept": "application/octet-stream"}
response = requests.get(asset_url, headers=headers, allow_redirects=True)
response = requests.get(
asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
)
response.raise_for_status() # Ensure we got a successful response

# Write the content to the temporary file
Expand All @@ -100,25 +108,12 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:

class FrontendManager:
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "legacy@latest"
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")

PROVIDERS = [
FrontEndProvider(
name="main",
owner="Comfy-Org",
repo="ComfyUI_frontend",
),
FrontEndProvider(
name="legacy",
owner="Comfy-Org",
repo="ComfyUI_frontend_legacy",
),
]

@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str]:
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
Args:
value (str): The version string to parse.
Expand All @@ -129,31 +124,26 @@ def parse_version_string(cls, value: str) -> tuple[str, str]:
Raises:
argparse.ArgumentTypeError: If the version string is invalid.
"""
VERSION_PATTERN = (
r"^("
+ "|".join([provider.name for provider in cls.PROVIDERS])
+ r")@(\d+\.\d+\.\d+|latest)$"
)
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(\d+\.\d+\.\d+|latest)$"
match_result = re.match(VERSION_PATTERN, value)
if match_result is None:
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")

return match_result.group(1), match_result.group(2)
return match_result.group(1), match_result.group(2), match_result.group(3)

@classmethod
def add_argument(cls, parser: argparse.ArgumentParser):
parser.add_argument(
"--front-end-version",
type=str,
default=cls.DEFAULT_VERSION_STRING,
help=f"""
help="""
Specifies the version of the frontend to be used. This command needs internet connectivity to query and
download available frontend implementations from GitHub releases.
The version string should be in the format of:
[provider]@[version]
where provider is one of: {", ".join([provider.name for provider in cls.PROVIDERS])}
and version is one of: a valid version number, latest
[repoOwner]/[repoName]@[version]
where version is one of: "latest" or a valid version number (e.g. "1.0.0")
""",
)

Expand All @@ -174,7 +164,7 @@ def is_valid_directory(path: str | None) -> str | None:
)

@classmethod
def init_frontend(cls, version_string: str) -> str:
def init_frontend_unsafe(cls, version_string: str) -> str:
"""
Initializes the frontend for the specified version.
Expand All @@ -185,26 +175,46 @@ def init_frontend(cls, version_string: str) -> str:
str: The path to the initialized frontend.
Raises:
ValueError: If the provider name is not found in the list of providers.
Exception: If there is an error during the initialization process.
main error source might be request timeout or invalid URL.
"""
if version_string == cls.DEFAULT_VERSION_STRING:
return cls.DEFAULT_FRONTEND_PATH

provider_name, version = cls.parse_version_string(version_string)
provider = next(
provider for provider in cls.PROVIDERS if provider.name == provider_name
)
repo_owner, repo_name, version = cls.parse_version_string(version_string)
provider = FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version)

semantic_version = release["tag_name"].lstrip("v")
web_root = str(
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.name / semantic_version
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
)
if not os.path.exists(web_root):
os.makedirs(web_root, exist_ok=True)
logging.info(
f"Downloading frontend({provider_name}) version({semantic_version})"
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
return web_root

@classmethod
def init_frontend(cls, version_string: str) -> str:
"""
Initializes the frontend with the specified version string.
Args:
version_string (str): The version string to initialize the frontend with.
Returns:
str: The path of the initialized frontend.
"""
try:
return cls.init_frontend_unsafe(version_string)
except Exception as e:
logging.error("Failed to initialize frontend: %s", e)
logging.info("Falling back to the default frontend.")
return cls.DEFAULT_FRONTEND_PATH
56 changes: 15 additions & 41 deletions tests-unit/comfy/frontend_manager_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import pytest
import os
import argparse
from unittest.mock import patch
import pytest
from requests.exceptions import HTTPError

from app.frontend_management import (
FrontendManager,
FrontEndProvider,
Release,
FrontendManager,
)


Expand Down Expand Up @@ -39,7 +38,6 @@ def mock_releases():
@pytest.fixture
def mock_provider(mock_releases):
provider = FrontEndProvider(
name="test",
owner="test-owner",
repo="test-repo",
)
Expand Down Expand Up @@ -73,49 +71,25 @@ def test_init_frontend_default():
assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH


def test_init_frontend_provider_version(mock_provider, mock_releases):
version_string = f"{mock_provider.name}@1.0.0"
with patch("app.frontend_management.download_release_asset_zip") as mock_download:
with patch("os.makedirs") as mock_makedirs:
frontend_path = FrontendManager.init_frontend(version_string)
assert frontend_path == os.path.join(
FrontendManager.CUSTOM_FRONTENDS_ROOT, mock_provider.name, "1.0.0"
)
mock_makedirs.assert_called_once_with(frontend_path, exist_ok=True)
mock_download.assert_called_once_with(
mock_releases[0], destination_path=frontend_path
)


def test_init_frontend_provider_latest(mock_provider, mock_releases):
version_string = f"{mock_provider.name}@latest"
with patch("app.frontend_management.download_release_asset_zip") as mock_download:
with patch("os.makedirs") as mock_makedirs:
frontend_path = FrontendManager.init_frontend(version_string)
assert frontend_path == os.path.join(
FrontendManager.CUSTOM_FRONTENDS_ROOT, mock_provider.name, "2.0.0"
)
mock_makedirs.assert_called_once_with(frontend_path, exist_ok=True)
mock_download.assert_called_once_with(
mock_releases[1], destination_path=frontend_path
)

def test_init_frontend_invalid_version():
version_string = "[email protected]"
with pytest.raises(ValueError):
FrontendManager.init_frontend(version_string)
version_string = "test-owner/test-repo@1.100.99"
with pytest.raises(HTTPError):
FrontendManager.init_frontend_unsafe(version_string)


def test_init_frontend_invalid_provider():
version_string = "invalid@latest"
with pytest.raises(argparse.ArgumentTypeError):
FrontendManager.init_frontend(version_string)
version_string = "invalid/invalid@latest"
with pytest.raises(HTTPError):
FrontendManager.init_frontend_unsafe(version_string)


def test_parse_version_string():
version_string = "[email protected]"
provider, version = FrontendManager.parse_version_string(version_string)
assert provider == "test"
version_string = "owner/[email protected]"
repo_owner, repo_name, version = FrontendManager.parse_version_string(
version_string
)
assert repo_owner == "owner"
assert repo_name == "repo"
assert version == "1.0.0"


Expand Down

0 comments on commit 6c50d83

Please sign in to comment.