From 1753700ce4a06326872255848b283a24636d4907 Mon Sep 17 00:00:00 2001 From: Loner Date: Mon, 19 Feb 2024 19:46:58 +0800 Subject: [PATCH] for cascade (#7) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update requirements.txt the UI launches with one missing module `torchvision`. spits out a `ModuleNotFoundError`. installing `torchvision` module fixed it. * Fix hiding dom widgets. * Fix lowvram mode not working with unCLIP and Revision code. * Fix taesd VAE in lowvram mode. * Add title to the API workflow json. (#2380) * Add `title` to the API workflow json. * API: Move `title` to `_meta` dictionary, imply unused. * Only add _meta title to api prompt when dev mode is enabled in UI. * Fix clip vision lowvram mode not working. * Cleanup. * Load weights that can't be lowvramed to target device. * Use function to calculate model size in model patcher. * Fix VALIDATE_INPUTS getting called multiple times. Allow VALIDATE_INPUTS to only validate specific inputs. * This cache timeout is pretty useless in practice. * Add argument to run the VAE on the CPU. * Reregister nodes when pressing refresh button. * Add a denoise parameter to BasicScheduler node. * Add node id and prompt id to websocket progress packet. * Remove useless code. * Auto detect out_channels from model. * Fix issue when websocket is deleted when data is being sent. * Refactor VAE code. Replace constants with downscale_ratio and latent_channels. * Fix regression. * Add support for the stable diffusion x4 upscaling model. This is an old model. Load the checkpoint like a regular one and use the new SD_4XUpscale_Conditioning node. * Fix model patches not working in custom sampling scheduler nodes. * Implement noise augmentation for SD 4X upscale model. * Add a /free route to unload models or free all memory. A POST request to /free with: {"unload_models":true} will unload models from vram. A POST request to /free with: {"free_memory":true} will unload models and free all cached data from the last run workflow. * StableZero123_Conditioning_Batched node. This node lets you generate a batch of images with different elevations or azimuths by setting the elevation_batch_increment and/or azimuth_batch_increment. It also sets the batch index for the latents so that the same init noise is used on each frame. * Fix BasicScheduler issue with Loras. * fix: `/free` handler function name * Implement attention mask on xformers. * Support attention mask in split attention. * Add attention mask support to sub quad attention. * Update optimized_attention_for_device function for new functions that support masked attention. * Support properly loading images with mode I. * Store user settings/data on the server and multi user support (#2160) * wip per user data * Rename, hide menu * better error rework default user * store pretty * Add userdata endpoints Change nodetemplates to userdata * add multi user message * make normal arg * Fix tests * Ignore user dir * user tests * Changed to default to browser storage and add server-storage arg * fix crash on empty templates * fix settings added before load * ignore parse errors * Fix issue with user manager parent dir not being created. * Support I mode images in LoadImageMask. * Use basic attention implementation for small inputs on old pytorch. * Round up to nearest power of 2 in SAG node to fix some resolution issues. * Fix issue when using multiple t2i adapters with batched images. * Skip SAG when latent is too small. * Add InpaintModelConditioning node. This is an alternative to VAE Encode for inpaint that should work with lower denoise. This is a different take on #2501 * Don't round noise mask. * Resolved crashing nodes caused by `FileNotFoundError` during directory traversal - Implemented a `try-except` block in the `recursive_search` function to handle `FileNotFoundError` gracefully. - When encountering a file or directory path that cannot be accessed (causing `FileNotFoundError`), the code now logs a warning and skips processing for that specific path instead of crashing the node (CheckpointLoaderSimple was usually the first to break). This allows the rest of the directory traversal to proceed without interruption. * Add error, status to /history endpoint * Fix hypertile issue with high depths. * Make server storage the default. Remove --server-storage argument. * Clear status notes on execution start. * Rename status notes to status messages. I think message describes them better. * Fix modifiers triggering key down checks * add setting to change control after generate to run before * export function * Manage group nodes (#2455) * wip group manage * prototyping ui * tweaks * wip * wip * more wip * fixes add deletion * Fix tests * fixes * Remove test code * typo * fix crash when link is invalid * Fix crash on group render * Adds copy image option if browser feature available (#2544) * Adds copy image option if browser feature available * refactor * Make unclip more deterministic. Pass a seed argument note that this might make old unclip images different. * Add error handling to initial fix to keep cache intact * Only auto enable bf16 VAE on nvidia GPUs that actually support it. * Fix logging not checking onChange * Auto queue on change (#2542) * Add toggle to enable auto queue when graph is changed * type fix * better * better alignment * Change undoredo to not ignore inputs when autoqueue in change mode * Fix renaming upload widget (#2554) * Fix renaming upload widget * Allow custom name * Jack/workflow (#3) * modified: web/scripts/app.js modified: web/scripts/utils.js * unformat * fix: workflow id (#4) * Don't use PEP 604 type hints, to stay compatible with Python<3.10. * Add unfinished ImageOnlyCheckpointSave node to save a SVD checkpoint. This node is unfinished, SVD checkpoints saved with this node will work with ComfyUI but not with anything else. * Move some nodes to model_patches section. * Remove useless import. * Fix for the extracting issue on windows. * Fix queue on change to respect auto queue checkbox (#2608) * Fix render on change not respecting auto queue checkbox Fix issue where autoQueueEnabled checkbox is ignored for changes if autoQueueMode is left on `change` * Make check more specific * Cleanup some unused imports. * Fix potential turbo scheduler model patching issue. * Ability to hide menu Responsive setting screen Touch events for zooming/context menu * typo fix - calculate_sigmas_scheduler (#2619) self.scheduler -> scheduler_name Co-authored-by: Lt.Dr.Data * Support refresh on group node combos (#2625) * Support refresh on group node combos * fix check * Sync litegraph with repo. https://github.com/comfyanonymous/litegraph.js/pull/4 * Jack/load custom nodes (#5) * update custom nodes * fix order * Add experimental photomaker nodes. Put the model file in models/photomaker and use PhotoMakerLoader. Then use PhotoMakerEncode with the keyword "photomaker" to apply the image * Remove some unused imports. * Cleanups. * Add a LatentBatchSeedBehavior node. This lets you set it so the latents can use the same seed for the sampling on every image in the batch. * Fix some issues with --gpu-only * Remove useless code. * Add node to set only the conditioning area strength. * Make auto saved workflow stored per tab * fix: inpaint on mask editor bottom area * Put VAE key name in model config. * Fix crash when no widgets on customized group node * Fix scrolling with lots of nodes * Litegraph node search improvements. See: https://github.com/comfyanonymous/litegraph.js/pull/5 * Update readme for new pytorch 2.2 release. * feat: better pen support for mask editor - alt-drag: erase - shift-drag(up/down): zoom in/out * use local storage * add increment-wrap as option to ValueControlWidget when isCombo, which loops back to 0 when at end of list * Fix frontend webp prompt handling * changed default of LatentBatchSeedBehavior to fixed * Always use fp16 for the text encoders. * Mask editor: semitransparent brush, brush color modes * Speed up SDXL on 16xx series with fp16 weights and manual cast. * Don't use is_bf16_supported to check for fp16 support. * Document IS_CHANGED in the example custom node. * Make minimum tile size the size of the overlap. * Support linking converted inputs from api json * Sync litegraph to repo. https://github.com/comfyanonymous/litegraph.js/pull/6 * Don't use numpy for calculating sigmas. * Allow custom samplers to request discard penultimate sigma * Add batch number to filename with %batch_num% Allow configurable addition of batch number to output file name. * Add a way to set different conditioning for the controlnet. * Fix infinite while loop being possible in ddim_scheduler * Add a node to give the controlnet a prompt different from the unet. * Safari: Draws certain elements on CPU. In case of search popup, can cause 10 seconds+ main thread lock due to painting. (#2763) * lets toggle this setting first. * also makes it easier for debug. I'll be honest this is generally preferred behavior as well for me but I ain't no power user shrug. * attempting trick to put the work for filter: brightness on GPU as a first attempt before falling back to not using filter for large lists! * revert litegraph.core.js changes from branch * oops * Prevent hideWidget being called twice for same widget Fix for #2766 * Add ImageFromBatch. * Don't init the CLIP model when the checkpoint has no CLIP weights. * Add a disabled SaveImageWebsocket custom node. This node can be used to efficiently get images without saving them to disk when using ComfyUI as a backend. * Small refactor of is_device_* functions. * Stable Cascade Stage A. * Stable Cascade Stage C. * Stable Cascade Stage B. * StableCascade CLIP model support. * Make --force-fp32 disable loading models in bf16. * Support Stable Cascade Stage B lite. * Make Stable Cascade work on old pytorch 2.0 * Fix clip attention mask issues on some hardware. * Manual cast for bf16 on older GPUs. * Implement shift schedule for cascade stage C. * Properly fix attention masks in CLIP with batches. * Fix attention mask batch size in some attention functions. * fp8 weight support for Stable Cascade. * Fix attention masks properly for multiple batches. * Add ModelSamplingStableCascade to control the shift sampling parameter. shift is 2.0 by default on Stage C and 1.0 by default on Stage B. * Fix gligen lowvram mode. * Support additional PNG info. * Support loading the Stable Cascade effnet and previewer as a VAE. The effnet can be used to encode images for img2img with Stage C. * Forgot to commit this. --------- Co-authored-by: Oleksiy Nehlyadyuk Co-authored-by: comfyanonymous Co-authored-by: shiimizu Co-authored-by: AYF Co-authored-by: ramyma Co-authored-by: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Co-authored-by: TFWol <9045213+TFWol@users.noreply.github.com> Co-authored-by: Kristjan Pärt Co-authored-by: Dr.Lt.Data <128333288+ltdrdata@users.noreply.github.com> Co-authored-by: Lt.Dr.Data Co-authored-by: Meowu <474384902@qq.com> Co-authored-by: pksebben Co-authored-by: Chaoses-Ib Co-authored-by: FizzleDorf <1fizzledorf@gmail.com> Co-authored-by: ultimabear Co-authored-by: blepping Co-authored-by: Imran Azeez Co-authored-by: Jedrzej Kosinski Co-authored-by: Steven Lu Co-authored-by: chrisgoringe --- .gitignore | 3 +- .vscode/settings.json | 9 - README.md | 11 +- app/app_settings.py | 54 ++ app/user_manager.py | 140 +++++ comfy/cli_args.py | 4 + comfy/clip_model.py | 6 +- comfy/clip_vision.py | 10 +- comfy/conds.py | 1 - comfy/controlnet.py | 25 +- comfy/diffusers_load.py | 1 - comfy/gligen.py | 54 +- comfy/latent_formats.py | 12 + comfy/ldm/cascade/common.py | 161 ++++++ comfy/ldm/cascade/stage_a.py | 258 +++++++++ comfy/ldm/cascade/stage_b.py | 257 +++++++++ comfy/ldm/cascade/stage_c.py | 271 ++++++++++ comfy/ldm/cascade/stage_c_coder.py | 96 ++++ comfy/ldm/modules/attention.py | 56 +- .../modules/diffusionmodules/openaimodel.py | 9 +- .../ldm/modules/diffusionmodules/upscaling.py | 16 +- comfy/ldm/modules/diffusionmodules/util.py | 4 +- .../ldm/modules/encoders/noise_aug_modules.py | 8 +- comfy/ldm/modules/sub_quadratic_attention.py | 30 +- comfy/model_base.py | 166 +++++- comfy/model_detection.py | 50 +- comfy/model_management.py | 115 +++- comfy/model_patcher.py | 55 +- comfy/model_sampling.py | 61 ++- comfy/ops.py | 50 +- comfy/sample.py | 1 - comfy/samplers.py | 15 +- comfy/sd.py | 142 +++-- comfy/sd1_clip.py | 5 +- comfy/sdxl_clip.py | 22 + comfy/supported_models.py | 84 ++- comfy/supported_models_base.py | 13 +- comfy/taesd/taesd.py | 5 +- comfy/utils.py | 4 + comfy_extras/nodes_cond.py | 25 + comfy_extras/nodes_custom_sampler.py | 12 +- comfy_extras/nodes_freelunch.py | 4 +- comfy_extras/nodes_hypertile.py | 18 +- comfy_extras/nodes_images.py | 20 + comfy_extras/nodes_latent.py | 24 + comfy_extras/nodes_model_advanced.py | 27 + comfy_extras/nodes_model_merging.py | 83 +-- comfy_extras/nodes_photomaker.py | 187 +++++++ comfy_extras/nodes_post_processing.py | 1 + comfy_extras/nodes_sag.py | 4 +- comfy_extras/nodes_sdupscale.py | 47 ++ comfy_extras/nodes_stable3d.py | 46 +- comfy_extras/nodes_stable_cascade.py | 74 +++ comfy_extras/nodes_video_model.py | 17 + custom_nodes/example_node.py.example | 13 + custom_nodes/websocket_image_save.py.disabled | 49 ++ execution.py | 152 ++++-- folder_paths.py | 22 +- main.py | 28 +- models/photomaker/put_photomaker_models_here | 0 nodes.py | 151 ++++-- requirements.txt | 1 + server.py | 22 +- tests-ui/babel.config.json | 3 +- tests-ui/package-lock.json | 20 + tests-ui/package.json | 1 + tests-ui/tests/users.test.js | 295 +++++++++++ tests-ui/utils/index.js | 16 +- tests-ui/utils/setup.js | 36 +- web/extensions/core/groupNode.js | 195 +++++-- web/extensions/core/groupNodeManage.css | 149 ++++++ web/extensions/core/groupNodeManage.js | 422 +++++++++++++++ web/extensions/core/maskeditor.js | 155 +++++- web/extensions/core/nodeTemplates.js | 64 ++- web/extensions/core/simpleTouchSupport.js | 102 ++++ web/extensions/core/undoRedo.js | 40 +- web/extensions/core/widgetInputs.js | 7 + web/index.html | 30 +- web/jsconfig.json | 3 +- web/lib/litegraph.core.js | 19 +- web/lib/litegraph.css | 13 + web/scripts/api.js | 104 +++- web/scripts/app.js | 497 ++++++++++++------ web/scripts/domWidget.js | 5 +- web/scripts/logging.js | 5 +- web/scripts/pnginfo.js | 4 +- web/scripts/ui.js | 367 ++++--------- web/scripts/ui/dialog.js | 32 ++ web/scripts/ui/draggableList.js | 287 ++++++++++ web/scripts/ui/settings.js | 317 +++++++++++ web/scripts/ui/spinner.css | 34 ++ web/scripts/ui/spinner.js | 9 + web/scripts/ui/toggleSwitch.js | 60 +++ web/scripts/ui/userSelection.css | 135 +++++ web/scripts/ui/userSelection.js | 114 ++++ web/scripts/utils.js | 47 +- web/scripts/widgets.js | 75 ++- web/style.css | 136 ++++- 98 files changed, 6182 insertions(+), 927 deletions(-) delete mode 100644 .vscode/settings.json create mode 100644 app/app_settings.py create mode 100644 app/user_manager.py create mode 100644 comfy/ldm/cascade/common.py create mode 100644 comfy/ldm/cascade/stage_a.py create mode 100644 comfy/ldm/cascade/stage_b.py create mode 100644 comfy/ldm/cascade/stage_c.py create mode 100644 comfy/ldm/cascade/stage_c_coder.py create mode 100644 comfy_extras/nodes_cond.py create mode 100644 comfy_extras/nodes_photomaker.py create mode 100644 comfy_extras/nodes_sdupscale.py create mode 100644 comfy_extras/nodes_stable_cascade.py create mode 100644 custom_nodes/websocket_image_save.py.disabled create mode 100644 models/photomaker/put_photomaker_models_here create mode 100644 tests-ui/tests/users.test.js create mode 100644 web/extensions/core/groupNodeManage.css create mode 100644 web/extensions/core/groupNodeManage.js create mode 100644 web/extensions/core/simpleTouchSupport.js create mode 100644 web/scripts/ui/dialog.js create mode 100644 web/scripts/ui/draggableList.js create mode 100644 web/scripts/ui/settings.js create mode 100644 web/scripts/ui/spinner.css create mode 100644 web/scripts/ui/spinner.js create mode 100644 web/scripts/ui/toggleSwitch.js create mode 100644 web/scripts/ui/userSelection.css create mode 100644 web/scripts/ui/userSelection.js diff --git a/.gitignore b/.gitignore index 43c038e4161..9f0389241ea 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,5 @@ venv/ /web/extensions/* !/web/extensions/logging.js.example !/web/extensions/core/ -/tests-ui/data/object_info.json \ No newline at end of file +/tests-ui/data/object_info.json +/user/ \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 202121e10fc..00000000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "path-intellisense.mappings": { - "../": "${workspaceFolder}/web/extensions/core" - }, - "[python]": { - "editor.defaultFormatter": "ms-python.autopep8" - }, - "python.formatting.provider": "none" -} diff --git a/README.md b/README.md index 167214c05c6..ff3ab64204e 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,8 @@ There is a portable standalone build for Windows that should work for running on Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints +If you have trouble extracting it, right click the file -> properties -> unblock + #### How do I share models between another UI and ComfyUI? See the [Config file](extra_model_paths.yaml.example) to set the search paths for models. In the standalone windows build you can find this file in the ComfyUI directory. Rename this file to extra_model_paths.yaml and edit it with your favorite text editor. @@ -93,16 +95,15 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints Put your VAE in: models/vae -Note: pytorch stable does not support python 3.12 yet. If you have python 3.12 you will have to use the nightly version of pytorch. If you run into issues you should try python 3.11 instead. ### AMD GPUs (Linux only) AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: -```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6``` +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7``` -This is the command to install the nightly with ROCm 5.7 which has a python 3.12 package and might have some performance improvements: +This is the command to install the nightly with ROCm 6.0 which might have some performance improvements: -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.0``` ### NVIDIA @@ -110,7 +111,7 @@ Nvidia users should install stable pytorch using this command: ```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121``` -This is the command to install pytorch nightly instead which has a python 3.12 package and might have performance improvements: +This is the command to install pytorch nightly instead which might have performance improvements: ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121``` diff --git a/app/app_settings.py b/app/app_settings.py new file mode 100644 index 00000000000..8c6edc56c1d --- /dev/null +++ b/app/app_settings.py @@ -0,0 +1,54 @@ +import os +import json +from aiohttp import web + + +class AppSettings(): + def __init__(self, user_manager): + self.user_manager = user_manager + + def get_settings(self, request): + file = self.user_manager.get_request_user_filepath( + request, "comfy.settings.json") + if os.path.isfile(file): + with open(file) as f: + return json.load(f) + else: + return {} + + def save_settings(self, request, settings): + file = self.user_manager.get_request_user_filepath( + request, "comfy.settings.json") + with open(file, "w") as f: + f.write(json.dumps(settings, indent=4)) + + def add_routes(self, routes): + @routes.get("/settings") + async def get_settings(request): + return web.json_response(self.get_settings(request)) + + @routes.get("/settings/{id}") + async def get_setting(request): + value = None + settings = self.get_settings(request) + setting_id = request.match_info.get("id", None) + if setting_id and setting_id in settings: + value = settings[setting_id] + return web.json_response(value) + + @routes.post("/settings") + async def post_settings(request): + settings = self.get_settings(request) + new_settings = await request.json() + self.save_settings(request, {**settings, **new_settings}) + return web.Response(status=200) + + @routes.post("/settings/{id}") + async def post_setting(request): + setting_id = request.match_info.get("id", None) + if not setting_id: + return web.Response(status=400) + settings = self.get_settings(request) + settings[setting_id] = await request.json() + self.save_settings(request, settings) + return web.Response(status=200) \ No newline at end of file diff --git a/app/user_manager.py b/app/user_manager.py new file mode 100644 index 00000000000..209094af15a --- /dev/null +++ b/app/user_manager.py @@ -0,0 +1,140 @@ +import json +import os +import re +import uuid +from aiohttp import web +from comfy.cli_args import args +from folder_paths import user_directory +from .app_settings import AppSettings + +default_user = "default" +users_file = os.path.join(user_directory, "users.json") + + +class UserManager(): + def __init__(self): + global user_directory + + self.settings = AppSettings(self) + if not os.path.exists(user_directory): + os.mkdir(user_directory) + if not args.multi_user: + print("****** User settings have been changed to be stored on the server instead of browser storage. ******") + print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******") + + if args.multi_user: + if os.path.isfile(users_file): + with open(users_file) as f: + self.users = json.load(f) + else: + self.users = {} + else: + self.users = {"default": "default"} + + def get_request_user_id(self, request): + user = "default" + if args.multi_user and "comfy-user" in request.headers: + user = request.headers["comfy-user"] + + if user not in self.users: + raise KeyError("Unknown user: " + user) + + return user + + def get_request_user_filepath(self, request, file, type="userdata", create_dir=True): + global user_directory + + if type == "userdata": + root_dir = user_directory + else: + raise KeyError("Unknown filepath type:" + type) + + user = self.get_request_user_id(request) + path = user_root = os.path.abspath(os.path.join(root_dir, user)) + + # prevent leaving /{type} + if os.path.commonpath((root_dir, user_root)) != root_dir: + return None + + parent = user_root + + if file is not None: + # prevent leaving /{type}/{user} + path = os.path.abspath(os.path.join(user_root, file)) + if os.path.commonpath((user_root, path)) != user_root: + return None + + if create_dir and not os.path.exists(parent): + os.mkdir(parent) + + return path + + def add_user(self, name): + name = name.strip() + if not name: + raise ValueError("username not provided") + user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name) + user_id = user_id + "_" + str(uuid.uuid4()) + + self.users[user_id] = name + + global users_file + with open(users_file, "w") as f: + json.dump(self.users, f) + + return user_id + + def add_routes(self, routes): + self.settings.add_routes(routes) + + @routes.get("/users") + async def get_users(request): + if args.multi_user: + return web.json_response({"storage": "server", "users": self.users}) + else: + user_dir = self.get_request_user_filepath(request, None, create_dir=False) + return web.json_response({ + "storage": "server", + "migrated": os.path.exists(user_dir) + }) + + @routes.post("/users") + async def post_users(request): + body = await request.json() + username = body["username"] + if username in self.users.values(): + return web.json_response({"error": "Duplicate username."}, status=400) + + user_id = self.add_user(username) + return web.json_response(user_id) + + @routes.get("/userdata/{file}") + async def getuserdata(request): + file = request.match_info.get("file", None) + if not file: + return web.Response(status=400) + + path = self.get_request_user_filepath(request, file) + if not path: + return web.Response(status=403) + + if not os.path.exists(path): + return web.Response(status=404) + + return web.FileResponse(path) + + @routes.post("/userdata/{file}") + async def post_userdata(request): + file = request.match_info.get("file", None) + if not file: + return web.Response(status=400) + + path = self.get_request_user_filepath(request, file) + if not path: + return web.Response(status=403) + + body = await request.read() + with open(path, "wb") as f: + f.write(body) + + return web.Response(status=200) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 8de0adb53ee..b4bbfbfab53 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -66,6 +66,8 @@ def __call__(self, parser, namespace, values, option_string=None): fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.") +parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.") + fpte_group = parser.add_mutually_exclusive_group() fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).") fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") @@ -110,6 +112,8 @@ class LatentPreviewMethod(enum.Enum): parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") +parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") + if comfy.options.args_parsing: args = parser.parse_args() else: diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 850b5fdbecb..9b82a246b2c 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -57,7 +57,7 @@ def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)]) def forward(self, x, mask=None, intermediate_output=None): - optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None) + optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) if intermediate_output is not None: if intermediate_output < 0: @@ -97,7 +97,7 @@ def forward(self, input_tokens, attention_mask=None, intermediate_output=None, f x = self.embeddings(input_tokens) mask = None if attention_mask is not None: - mask = 1.0 - attention_mask.to(x.dtype).unsqueeze(1).unsqueeze(1).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) @@ -151,7 +151,7 @@ def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dty def forward(self, pixel_values): embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2) - return torch.cat([self.class_embedding.expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight + return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device) class CLIPVision(torch.nn.Module): diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 4564fcfb2a0..8c77ee7a922 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,7 +1,6 @@ -from .utils import load_torch_file, transformers_convert, common_upscale +from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace import os import torch -import contextlib import json import comfy.ops @@ -41,9 +40,13 @@ def __init__(self, json_config): self.model.eval() self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + def load_sd(self, sd): return self.model.load_state_dict(sd, strict=False) + def get_sd(self): + return self.model.state_dict() + def encode_image(self, image): comfy.model_management.load_model_gpu(self.patcher) pixel_values = clip_preprocess(image.to(self.load_device)).float() @@ -76,6 +79,9 @@ def convert_to_transformers(sd, prefix): sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1) sd = transformers_convert(sd, prefix, "vision_model.", 48) + else: + replace_prefix = {prefix: ""} + sd = state_dict_prefix_replace(sd, replace_prefix) return sd def load_clipvision_from_sd(sd, prefix="", convert_keys=False): diff --git a/comfy/conds.py b/comfy/conds.py index 6cff2518400..23fa48872d6 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -1,4 +1,3 @@ -import enum import torch import math import comfy.utils diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 8404054f38f..416197586a1 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -1,7 +1,6 @@ import torch import math import os -import contextlib import comfy.utils import comfy.model_management import comfy.model_detection @@ -126,7 +125,10 @@ def control_merge(self, control_input, control_output, control_prev, output_dtyp if o[i] is None: o[i] = prev_val else: - o[i] += prev_val + if o[i].shape[0] < prev_val.shape[0]: + o[i] = prev_val + o[i] + else: + o[i] += prev_val return out class ControlNet(ControlBase): @@ -164,7 +166,7 @@ def get_control(self, x_noisy, t, cond, batched_number): if x_noisy.shape[0] != self.cond_hint.shape[0]: self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number) - context = cond['c_crossattn'] + context = cond.get('crossattn_controlnet', cond['c_crossattn']) y = cond.get('y', None) if y is not None: y = y.to(dtype) @@ -316,9 +318,10 @@ def load_controlnet(ckpt_path, model=None): return ControlLora(controlnet_data) controlnet_config = None + supported_inference_dtypes = None + if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format - unet_dtype = comfy.model_management.unet_dtype() - controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype) + controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data) diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config) diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias" @@ -378,12 +381,20 @@ def load_controlnet(ckpt_path, model=None): return net if controlnet_config is None: - unet_dtype = comfy.model_management.unet_dtype() - controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config + model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True) + supported_inference_dtypes = model_config.supported_inference_dtypes + controlnet_config = model_config.unet_config + load_device = comfy.model_management.get_torch_device() + if supported_inference_dtypes is None: + unet_dtype = comfy.model_management.unet_dtype() + else: + unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes) + manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) if manual_cast_dtype is not None: controlnet_config["operations"] = comfy.ops.manual_cast + controlnet_config["dtype"] = unet_dtype controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index c0b420e7966..98b888a1939 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -1,4 +1,3 @@ -import json import os import comfy.sd diff --git a/comfy/gligen.py b/comfy/gligen.py index 8d182839e05..592522767e9 100644 --- a/comfy/gligen.py +++ b/comfy/gligen.py @@ -1,8 +1,9 @@ import torch -from torch import nn, einsum +from torch import nn from .ldm.modules.attention import CrossAttention from inspect import isfunction - +import comfy.ops +ops = comfy.ops.manual_cast def exists(val): return val is not None @@ -22,7 +23,7 @@ def default(val, d): class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) + self.proj = ops.Linear(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) @@ -35,14 +36,14 @@ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = nn.Sequential( - nn.Linear(dim, inner_dim), + ops.Linear(dim, inner_dim), nn.GELU() ) if not glu else GEGLU(dim, inner_dim) self.net = nn.Sequential( project_in, nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) + ops.Linear(inner_dim, dim_out) ) def forward(self, x): @@ -57,11 +58,12 @@ def __init__(self, query_dim, context_dim, n_heads, d_head): query_dim=query_dim, context_dim=context_dim, heads=n_heads, - dim_head=d_head) + dim_head=d_head, + operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -87,17 +89,18 @@ def __init__(self, query_dim, context_dim, n_heads, d_head): # we need a linear projection since we need cat visual feature and obj # feature - self.linear = nn.Linear(context_dim, query_dim) + self.linear = ops.Linear(context_dim, query_dim) self.attn = CrossAttention( query_dim=query_dim, context_dim=query_dim, heads=n_heads, - dim_head=d_head) + dim_head=d_head, + operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -126,14 +129,14 @@ def __init__(self, query_dim, context_dim, n_heads, d_head): # we need a linear projection since we need cat visual feature and obj # feature - self.linear = nn.Linear(context_dim, query_dim) + self.linear = ops.Linear(context_dim, query_dim) self.attn = CrossAttention( - query_dim=query_dim, context_dim=query_dim, dim_head=d_head) + query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops) self.ff = FeedForward(query_dim, glu=True) - self.norm1 = nn.LayerNorm(query_dim) - self.norm2 = nn.LayerNorm(query_dim) + self.norm1 = ops.LayerNorm(query_dim) + self.norm2 = ops.LayerNorm(query_dim) self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) @@ -201,11 +204,11 @@ def __init__(self, in_dim, out_dim, fourier_freqs=8): self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy self.linears = nn.Sequential( - nn.Linear(self.in_dim + self.position_dim, 512), + ops.Linear(self.in_dim + self.position_dim, 512), nn.SiLU(), - nn.Linear(512, 512), + ops.Linear(512, 512), nn.SiLU(), - nn.Linear(512, out_dim), + ops.Linear(512, out_dim), ) self.null_positive_feature = torch.nn.Parameter( @@ -215,16 +218,15 @@ def __init__(self, in_dim, out_dim, fourier_freqs=8): def forward(self, boxes, masks, positive_embeddings): B, N, _ = boxes.shape - dtype = self.linears[0].weight.dtype - masks = masks.unsqueeze(-1).to(dtype) - positive_embeddings = positive_embeddings.to(dtype) + masks = masks.unsqueeze(-1) + positive_embeddings = positive_embeddings # embedding position (it may includes padding as placeholder) - xyxy_embedding = self.fourier_embedder(boxes.to(dtype)) # B*N*4 --> B*N*C + xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C # learnable null embedding - positive_null = self.null_positive_feature.view(1, 1, -1) - xyxy_null = self.null_position_feature.view(1, 1, -1) + positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1) + xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1) # replace padding with learnable null embedding positive_embeddings = positive_embeddings * \ @@ -251,7 +253,7 @@ def _set_position(self, boxes, masks, positive_embeddings): def func(x, extra_options): key = extra_options["transformer_index"] module = self.module_list[key] - return module(x, objs) + return module(x, objs.to(device=x.device, dtype=x.dtype)) return func def set_position(self, latent_image_shape, position_params, device): diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index c209087e0cc..68fd73d0b5d 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -33,3 +33,15 @@ def __init__(self): [-0.3112, -0.2359, -0.2076] ] self.taesd_decoder_name = "taesdxl_decoder" + +class SD_X4(LatentFormat): + def __init__(self): + self.scale_factor = 0.08333 + +class SC_Prior(LatentFormat): + def __init__(self): + self.scale_factor = 1.0 + +class SC_B(LatentFormat): + def __init__(self): + self.scale_factor = 1.0 diff --git a/comfy/ldm/cascade/common.py b/comfy/ldm/cascade/common.py new file mode 100644 index 00000000000..124902c09a4 --- /dev/null +++ b/comfy/ldm/cascade/common.py @@ -0,0 +1,161 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +import torch.nn as nn +from comfy.ldm.modules.attention import optimized_attention + +class Linear(torch.nn.Linear): + def reset_parameters(self): + return None + +class Conv2d(torch.nn.Conv2d): + def reset_parameters(self): + return None + +class OptimizedAttention(nn.Module): + def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.heads = nhead + + self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + + self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device) + + def forward(self, q, k, v): + q = self.to_q(q) + k = self.to_k(k) + v = self.to_v(v) + + out = optimized_attention(q, k, v, self.heads) + + return self.out_proj(out) + +class Attention2D(nn.Module): + def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations) + # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device) + + def forward(self, x, kv, self_attn=False): + orig_shape = x.shape + x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4 + if self_attn: + kv = torch.cat([x, kv], dim=1) + # x = self.attn(x, kv, kv, need_weights=False)[0] + x = self.attn(x, kv, kv) + x = x.permute(0, 2, 1).view(*orig_shape) + return x + + +def LayerNorm2d_op(operations): + class LayerNorm2d(operations.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return LayerNorm2d + +class GlobalResponseNorm(nn.Module): + "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105" + def __init__(self, dim, dtype=None, device=None): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x + + +class ResBlock(nn.Module): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2): + super().__init__() + self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device) + # self.depthwise = SAMBlock(c, num_heads, expansion) + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.channelwise = nn.Sequential( + operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device), + nn.GELU(), + GlobalResponseNorm(c * 4, dtype=dtype, device=device), + nn.Dropout(dropout), + operations.Linear(c * 4, c, dtype=dtype, device=device) + ) + + def forward(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = torch.cat([x, x_skip], dim=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + +class AttnBlock(nn.Module): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.self_attn = self_attn + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations) + self.kv_mapper = nn.Sequential( + nn.SiLU(), + operations.Linear(c_cond, c, dtype=dtype, device=device) + ) + + def forward(self, x, kv): + kv = self.kv_mapper(kv) + x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn) + return x + + +class FeedForwardBlock(nn.Module): + def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None): + super().__init__() + self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.channelwise = nn.Sequential( + operations.Linear(c, c * 4, dtype=dtype, device=device), + nn.GELU(), + GlobalResponseNorm(c * 4, dtype=dtype, device=device), + nn.Dropout(dropout), + operations.Linear(c * 4, c, dtype=dtype, device=device) + ) + + def forward(self, x): + x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + + +class TimestepBlock(nn.Module): + def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None): + super().__init__() + self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)) + + def forward(self, x, t): + t = t.chunk(len(self.conds) + 1, dim=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1) + a, b = a + ac, b + bc + return x * (1 + a) + b diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py new file mode 100644 index 00000000000..260ccfc0b5d --- /dev/null +++ b/comfy/ldm/cascade/stage_a.py @@ -0,0 +1,258 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +from torch import nn +from torch.autograd import Function + +class vector_quantize(Function): + @staticmethod + def forward(ctx, x, codebook): + with torch.no_grad(): + codebook_sqr = torch.sum(codebook ** 2, dim=1) + x_sqr = torch.sum(x ** 2, dim=1, keepdim=True) + + dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0) + _, indices = dist.min(dim=1) + + ctx.save_for_backward(indices, codebook) + ctx.mark_non_differentiable(indices) + + nn = torch.index_select(codebook, 0, indices) + return nn, indices + + @staticmethod + def backward(ctx, grad_output, grad_indices): + grad_inputs, grad_codebook = None, None + + if ctx.needs_input_grad[0]: + grad_inputs = grad_output.clone() + if ctx.needs_input_grad[1]: + # Gradient wrt. the codebook + indices, codebook = ctx.saved_tensors + + grad_codebook = torch.zeros_like(codebook) + grad_codebook.index_add_(0, indices, grad_output) + + return (grad_inputs, grad_codebook) + + +class VectorQuantize(nn.Module): + def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False): + """ + Takes an input of variable size (as long as the last dimension matches the embedding size). + Returns one tensor containing the nearest neigbour embeddings to each of the inputs, + with the same size as the input, vq and commitment components for the loss as a touple + in the second output and the indices of the quantized vectors in the third: + quantized, (vq_loss, commit_loss), indices + """ + super(VectorQuantize, self).__init__() + + self.codebook = nn.Embedding(k, embedding_size) + self.codebook.weight.data.uniform_(-1./k, 1./k) + self.vq = vector_quantize.apply + + self.ema_decay = ema_decay + self.ema_loss = ema_loss + if ema_loss: + self.register_buffer('ema_element_count', torch.ones(k)) + self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight)) + + def _laplace_smoothing(self, x, epsilon): + n = torch.sum(x) + return ((x + epsilon) / (n + x.size(0) * epsilon) * n) + + def _updateEMA(self, z_e_x, indices): + mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() + elem_count = mask.sum(dim=0) + weight_sum = torch.mm(mask.t(), z_e_x) + + self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count) + self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5) + self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum) + + self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) + + def idx2vq(self, idx, dim=-1): + q_idx = self.codebook(idx) + if dim != -1: + q_idx = q_idx.movedim(-1, dim) + return q_idx + + def forward(self, x, get_losses=True, dim=-1): + if dim != -1: + x = x.movedim(dim, -1) + z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x + z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach()) + vq_loss, commit_loss = None, None + if self.ema_loss and self.training: + self._updateEMA(z_e_x.detach(), indices.detach()) + # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss + z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices) + if get_losses: + vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean() + commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean() + + z_q_x = z_q_x.view(x.shape) + if dim != -1: + z_q_x = z_q_x.movedim(-1, dim) + return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1]) + + +class ResBlock(nn.Module): + def __init__(self, c, c_hidden): + super().__init__() + # depthwise/attention + self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.depthwise = nn.Sequential( + nn.ReplicationPad2d(1), + nn.Conv2d(c, c, kernel_size=3, groups=c) + ) + + # channelwise + self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.Sequential( + nn.Linear(c, c_hidden), + nn.GELU(), + nn.Linear(c_hidden, c), + ) + + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) + + # Init weights + def _basic_init(module): + if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def _norm(self, x, norm): + return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + def forward(self, x): + mods = self.gammas + + x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1] + try: + x = x + self.depthwise(x_temp) * mods[2] + except: #operation not implemented for bf16 + x_temp = self.depthwise[0](x_temp.float()).to(x.dtype) + x = x + self.depthwise[1](x_temp) * mods[2] + + x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4] + x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] + + return x + + +class StageA(nn.Module): + def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, + scale_factor=0.43): # 0.3764 + super().__init__() + self.c_latent = c_latent + self.scale_factor = scale_factor + c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))] + + # Encoder blocks + self.in_block = nn.Sequential( + nn.PixelUnshuffle(2), + nn.Conv2d(3 * 4, c_levels[0], kernel_size=1) + ) + down_blocks = [] + for i in range(levels): + if i > 0: + down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) + block = ResBlock(c_levels[i], c_levels[i] * 4) + down_blocks.append(block) + down_blocks.append(nn.Sequential( + nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1 + )) + self.down_blocks = nn.Sequential(*down_blocks) + self.down_blocks[0] + + self.codebook_size = codebook_size + self.vquantizer = VectorQuantize(c_latent, k=codebook_size) + + # Decoder blocks + up_blocks = [nn.Sequential( + nn.Conv2d(c_latent, c_levels[-1], kernel_size=1) + )] + for i in range(levels): + for j in range(bottleneck_blocks if i == 0 else 1): + block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) + up_blocks.append(block) + if i < levels - 1: + up_blocks.append( + nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, + padding=1)) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_block = nn.Sequential( + nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1), + nn.PixelShuffle(2), + ) + + def encode(self, x, quantize=False): + x = self.in_block(x) + x = self.down_blocks(x) + if quantize: + qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1) + return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25 + else: + return x / self.scale_factor + + def decode(self, x): + x = x * self.scale_factor + x = self.up_blocks(x) + x = self.out_block(x) + return x + + def forward(self, x, quantize=False): + qe, x, _, vq_loss = self.encode(x, quantize) + x = self.decode(qe) + return x, vq_loss + + +class Discriminator(nn.Module): + def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6): + super().__init__() + d = max(depth - 3, 3) + layers = [ + nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)), + nn.LeakyReLU(0.2), + ] + for i in range(depth - 1): + c_in = c_hidden // (2 ** max((d - i), 0)) + c_out = c_hidden // (2 ** max((d - 1 - i), 0)) + layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))) + layers.append(nn.InstanceNorm2d(c_out)) + layers.append(nn.LeakyReLU(0.2)) + self.encoder = nn.Sequential(*layers) + self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1) + self.logits = nn.Sigmoid() + + def forward(self, x, cond=None): + x = self.encoder(x) + if cond is not None: + cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1)) + x = torch.cat([x, cond], dim=1) + x = self.shuffle(x) + x = self.logits(x) + return x diff --git a/comfy/ldm/cascade/stage_b.py b/comfy/ldm/cascade/stage_b.py new file mode 100644 index 00000000000..6d2c2223143 --- /dev/null +++ b/comfy/ldm/cascade/stage_b.py @@ -0,0 +1,257 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import math +import numpy as np +import torch +from torch import nn +from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock + +class StageB(nn.Module): + def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280], + nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]], + block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280, + c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True, + t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None): + super().__init__() + self.dtype = dtype + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.effnet_mapper = nn.Sequential( + operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + self.pixels_mapper = nn.Sequential( + operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device), + nn.GELU(), + operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device), + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device), + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + # self.apply(self._init_weights) # General init + # nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings + # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings + # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings + # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings + # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + # nn.init.constant_(self.clf[1].weight, 0) # outputs + # + # # blocks + # for level_block in self.down_blocks + self.up_blocks: + # for block in level_block: + # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + # elif isinstance(block, TimestepBlock): + # for layer in block.modules(): + # if isinstance(layer, nn.Linear): + # nn.init.constant_(layer.weight, 0) + # + # def _init_weights(self, m): + # if isinstance(m, (nn.Conv2d, nn.Linear)): + # torch.nn.init.xavier_uniform_(m.weight) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip): + if len(clip.shape) == 2: + clip = clip.unsqueeze(1) + clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, effnet, clip, pixels=None, **kwargs): + if pixels is None: + pixels = x.new_zeros(x.size(0), 3, 8, 8) + + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1) + clip = self.gen_c_embeddings(clip) + + # Model Blocks + x = self.embedding(x) + x = x + self.effnet_mapper( + nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True)) + x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear', + align_corners=True) + level_outputs = self._down_encode(x, r_embed, clip) + x = self._up_decode(level_outputs, r_embed, clip) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/comfy/ldm/cascade/stage_c.py b/comfy/ldm/cascade/stage_c.py new file mode 100644 index 00000000000..08e33aded22 --- /dev/null +++ b/comfy/ldm/cascade/stage_c.py @@ -0,0 +1,271 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +from torch import nn +import numpy as np +import math +from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock +# from .controlnet import ControlNetDeliverer + +class UpDownBlock2d(nn.Module): + def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None): + super().__init__() + assert mode in ['up', 'down'] + interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear', + align_corners=True) if enabled else nn.Identity() + mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device) + self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation]) + + def forward(self, x): + for block in self.blocks: + x = block(x) + return x + + +class StageC(nn.Module): + def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32], + blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'], + c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3, + dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None, + dtype=None, device=None, operations=None): + super().__init__() + self.dtype = dtype + self.c_r = c_r + self.t_conds = t_conds + self.c_clip_seq = c_clip_seq + if not isinstance(dropout, list): + dropout = [dropout] * len(c_hidden) + if not isinstance(self_attn, list): + self_attn = [self_attn] * len(c_hidden) + + # CONDITIONING + self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device) + self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device) + self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + + self.embedding = nn.Sequential( + nn.PixelUnshuffle(patch_size), + operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device), + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6) + ) + + def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == 'C': + return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'A': + return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'F': + return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations) + elif block_type == 'T': + return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations) + else: + raise Exception(f'Block type {block_type} not supported') + + # BLOCKS + # -- down blocks + self.down_blocks = nn.ModuleList() + self.down_downscalers = nn.ModuleList() + self.down_repeat_mappers = nn.ModuleList() + for i in range(len(c_hidden)): + if i > 0: + self.down_downscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations) + )) + else: + self.down_downscalers.append(nn.Identity()) + down_block = nn.ModuleList() + for _ in range(blocks[0][i]): + for block_type in level_config[i]: + block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i]) + down_block.append(block) + self.down_blocks.append(down_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[0][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.down_repeat_mappers.append(block_repeat_mappers) + + # -- up blocks + self.up_blocks = nn.ModuleList() + self.up_upscalers = nn.ModuleList() + self.up_repeat_mappers = nn.ModuleList() + for i in reversed(range(len(c_hidden))): + if i > 0: + self.up_upscalers.append(nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations) + )) + else: + self.up_upscalers.append(nn.Identity()) + up_block = nn.ModuleList() + for j in range(blocks[1][::-1][i]): + for k, block_type in enumerate(level_config[i]): + c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0 + block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], + self_attn=self_attn[i]) + up_block.append(block) + self.up_blocks.append(up_block) + if block_repeat is not None: + block_repeat_mappers = nn.ModuleList() + for _ in range(block_repeat[1][::-1][i] - 1): + block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device)) + self.up_repeat_mappers.append(block_repeat_mappers) + + # OUTPUT + self.clf = nn.Sequential( + LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device), + operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device), + nn.PixelShuffle(patch_size), + ) + + # --- WEIGHT INIT --- + # self.apply(self._init_weights) # General init + # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings + # nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings + # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs + # nn.init.constant_(self.clf[1].weight, 0) # outputs + # + # # blocks + # for level_block in self.down_blocks + self.up_blocks: + # for block in level_block: + # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock): + # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0])) + # elif isinstance(block, TimestepBlock): + # for layer in block.modules(): + # if isinstance(layer, nn.Linear): + # nn.init.constant_(layer.weight, 0) + # + # def _init_weights(self, m): + # if isinstance(m, (nn.Conv2d, nn.Linear)): + # torch.nn.init.xavier_uniform_(m.weight) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def gen_r_embedding(self, r, max_positions=10000): + r = r * max_positions + half_dim = self.c_r // 2 + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = torch.cat([emb.sin(), emb.cos()], dim=1) + if self.c_r % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), mode='constant') + return emb + + def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img): + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_txt_pooled.shape) == 2: + clip_txt_pooled = clip_txt_pooled.unsqueeze(1) + if len(clip_img.shape) == 2: + clip_img = clip_img.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1) + clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1) + clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1) + clip = self.clip_norm(clip) + return clip + + def _down_encode(self, x, r_embed, clip, cnet=None): + level_outputs = [] + block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers) + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + if cnet is not None: + next_cnet = cnet() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip, cnet=None): + x = level_outputs[0] + block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers) + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, ResBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + ResBlock)): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)): + x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear', + align_corners=True) + if cnet is not None: + next_cnet = cnet() + if next_cnet is not None: + x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear', + align_corners=True) + x = block(x, skip) + elif isinstance(block, AttnBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + AttnBlock)): + x = block(x, clip) + elif isinstance(block, TimestepBlock) or ( + hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module, + TimestepBlock)): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs): + # Process the conditioning embeddings + r_embed = self.gen_r_embedding(r).to(dtype=x.dtype) + for c in self.t_conds: + t_cond = kwargs.get(c, torch.zeros_like(r)) + r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1) + clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img) + + # Model Blocks + x = self.embedding(x) + if cnet is not None: + cnet = ControlNetDeliverer(cnet) + level_outputs = self._down_encode(x, r_embed, clip, cnet) + x = self._up_decode(level_outputs, r_embed, clip, cnet) + return self.clf(x) + + def update_weights_ema(self, src_model, beta=0.999): + for self_params, src_params in zip(self.parameters(), src_model.parameters()): + self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta) + for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()): + self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta) diff --git a/comfy/ldm/cascade/stage_c_coder.py b/comfy/ldm/cascade/stage_c_coder.py new file mode 100644 index 00000000000..98c9a0b6147 --- /dev/null +++ b/comfy/ldm/cascade/stage_c_coder.py @@ -0,0 +1,96 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" +import torch +import torchvision +from torch import nn + + +# EfficientNet +class EfficientNetEncoder(nn.Module): + def __init__(self, c_latent=16): + super().__init__() + self.backbone = torchvision.models.efficientnet_v2_s().features.eval() + self.mapper = nn.Sequential( + nn.Conv2d(1280, c_latent, kernel_size=1, bias=False), + nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1 + ) + self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406])) + self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225])) + + def forward(self, x): + x = x * 0.5 + 0.5 + x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1]) + o = self.mapper(self.backbone(x)) + print(o.shape) + return o + + +# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192 +class Previewer(nn.Module): + def __init__(self, c_in=16, c_hidden=512, c_out=3): + super().__init__() + self.blocks = nn.Sequential( + nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden), + + nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 2), + + nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128 + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1), + nn.GELU(), + nn.BatchNorm2d(c_hidden // 4), + + nn.Conv2d(c_hidden // 4, c_out, kernel_size=1), + ) + + def forward(self, x): + return (self.blocks(x) - 0.5) * 2.0 + +class StageC_coder(nn.Module): + def __init__(self): + super().__init__() + self.previewer = Previewer() + self.encoder = EfficientNetEncoder() + + def encode(self, x): + return self.encoder(x) + + def decode(self, x): + return self.previewer(x) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 3e12886b07f..48399bc07e3 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -1,12 +1,9 @@ -from inspect import isfunction import math import torch import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat from typing import Optional, Any -from functools import partial - from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention @@ -117,7 +114,12 @@ def attention_basic(q, k, v, heads, mask=None): mask = repeat(mask, 'b j -> (b h) () j', h=h) sim.masked_fill_(~mask, max_neg_value) else: - sim += mask + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + sim.add_(mask) # attention, what we cannot get enough of sim = sim.softmax(dim=-1) @@ -168,6 +170,13 @@ def attention_sub_quad(query, key, value, heads, mask=None): if query_chunk_size is None: query_chunk_size = 512 + if mask is not None: + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + hidden_states = efficient_dot_product_attention( query, key, @@ -177,6 +186,7 @@ def attention_sub_quad(query, key, value, heads, mask=None): kv_chunk_size_min=kv_chunk_size_min, use_checkpoint=False, upcast_attention=upcast_attention, + mask=mask, ) hidden_states = hidden_states.to(dtype) @@ -225,6 +235,13 @@ def attention_split(q, k, v, heads, mask=None): raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + if mask is not None: + if len(mask.shape) == 2: + bs = 1 + else: + bs = mask.shape[0] + mask = mask.reshape(bs, -1, mask.shape[-2], mask.shape[-1]).expand(b, heads, -1, -1).reshape(-1, mask.shape[-2], mask.shape[-1]) + # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size) first_op_done = False cleared_cache = False @@ -239,6 +256,12 @@ def attention_split(q, k, v, heads, mask=None): else: s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale + if mask is not None: + if len(mask.shape) == 2: + s1 += mask[i:end] + else: + s1 += mask[:, i:end] + s2 = s1.softmax(dim=-1).to(v.dtype) del s1 first_op_done = True @@ -294,11 +317,14 @@ def attention_xformers(q, k, v, heads, mask=None): (q, k, v), ) - # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + if mask is not None: + pad = 8 - q.shape[1] % 8 + mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device) + mask_out[:, :, :mask.shape[-1]] = mask + mask = mask_out[:, :, :mask.shape[-1]] + + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) - if exists(mask): - raise NotImplementedError out = ( out.unsqueeze(0) .reshape(b, heads, -1, dim_head) @@ -323,7 +349,6 @@ def attention_pytorch(q, k, v, heads, mask=None): optimized_attention = attention_basic -optimized_attention_masked = attention_basic if model_management.xformers_enabled(): print("Using xformers cross attention") @@ -339,15 +364,18 @@ def attention_pytorch(q, k, v, heads, mask=None): print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad -if model_management.pytorch_attention_enabled(): - optimized_attention_masked = attention_pytorch +optimized_attention_masked = optimized_attention -def optimized_attention_for_device(device, mask=False): - if device == torch.device("cpu"): #TODO +def optimized_attention_for_device(device, mask=False, small_input=False): + if small_input: if model_management.pytorch_attention_enabled(): - return attention_pytorch + return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases else: return attention_basic + + if device == torch.device("cpu"): + return attention_sub_quad + if mask: return optimized_attention_masked diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 057dd16b250..998afd977ca 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -1,12 +1,9 @@ from abc import abstractmethod -import math -import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from functools import partial from .util import ( checkpoint, @@ -437,9 +434,6 @@ def __init__( operations=ops, ): super().__init__() - assert use_spatial_transformer == True, "use_spatial_transformer has to be true" - if use_spatial_transformer: - assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' if context_dim is not None: assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' @@ -456,7 +450,6 @@ def __init__( if num_head_channels == -1: assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' - self.image_size = image_size self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels @@ -502,7 +495,7 @@ def __init__( if self.num_classes is not None: if isinstance(self.num_classes, int): - self.label_emb = nn.Embedding(num_classes, time_embed_dim) + self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device) elif self.num_classes == "continuous": print("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) diff --git a/comfy/ldm/modules/diffusionmodules/upscaling.py b/comfy/ldm/modules/diffusionmodules/upscaling.py index 709a7f52e06..f5ac7c2f913 100644 --- a/comfy/ldm/modules/diffusionmodules/upscaling.py +++ b/comfy/ldm/modules/diffusionmodules/upscaling.py @@ -41,10 +41,14 @@ def register_schedule(self, beta_schedule="linear", timesteps=1000, self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + def q_sample(self, x_start, t, noise=None, seed=None): + if noise is None: + if seed is None: + noise = torch.randn_like(x_start) + else: + noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device) + return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise) def forward(self, x): return x, None @@ -69,12 +73,12 @@ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): super().__init__(noise_schedule_config=noise_schedule_config) self.max_noise_level = max_noise_level - def forward(self, x, noise_level=None): + def forward(self, x, noise_level=None, seed=None): if noise_level is None: noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() else: assert isinstance(noise_level, torch.Tensor) - z = self.q_sample(x, noise_level) + z = self.q_sample(x, noise_level, seed=seed) return z, noise_level diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index ac7e27173bd..5a6aa7d77d1 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -98,7 +98,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, alphas = torch.cos(alphas).pow(2) alphas = alphas / alphas[0] betas = 1 - alphas[1:] / alphas[:-1] - betas = np.clip(betas, a_min=0, a_max=0.999) + betas = torch.clamp(betas, min=0, max=0.999) elif schedule == "squaredcos_cap_v2": # used for karlo prior # return early @@ -113,7 +113,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 else: raise ValueError(f"schedule '{schedule}' unknown.") - return betas.numpy() + return betas def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): diff --git a/comfy/ldm/modules/encoders/noise_aug_modules.py b/comfy/ldm/modules/encoders/noise_aug_modules.py index b59bf204bc9..a5d86603016 100644 --- a/comfy/ldm/modules/encoders/noise_aug_modules.py +++ b/comfy/ldm/modules/encoders/noise_aug_modules.py @@ -15,21 +15,21 @@ def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs): def scale(self, x): # re-normalize to centered mean and unit variance - x = (x - self.data_mean) * 1. / self.data_std + x = (x - self.data_mean.to(x.device)) * 1. / self.data_std.to(x.device) return x def unscale(self, x): # back to original data stats - x = (x * self.data_std) + self.data_mean + x = (x * self.data_std.to(x.device)) + self.data_mean.to(x.device) return x - def forward(self, x, noise_level=None): + def forward(self, x, noise_level=None, seed=None): if noise_level is None: noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() else: assert isinstance(noise_level, torch.Tensor) x = self.scale(x) - z = self.q_sample(x, noise_level) + z = self.q_sample(x, noise_level, seed=seed) z = self.unscale(z) noise_level = self.time_embed(noise_level) return z, noise_level diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index 8e8e8054dfd..cb0896b0df5 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -61,6 +61,7 @@ def _summarize_chunk( value: Tensor, scale: float, upcast_attention: bool, + mask, ) -> AttnChunk: if upcast_attention: with torch.autocast(enabled=False, device_type = 'cuda'): @@ -84,6 +85,8 @@ def _summarize_chunk( max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() attn_weights -= max_score + if mask is not None: + attn_weights += mask torch.exp(attn_weights, out=attn_weights) exp_weights = attn_weights.to(value.dtype) exp_values = torch.bmm(exp_weights, value) @@ -96,11 +99,12 @@ def _query_chunk_attention( value: Tensor, summarize_chunk: SummarizeChunk, kv_chunk_size: int, + mask, ) -> Tensor: batch_x_heads, k_channels_per_head, k_tokens = key_t.shape _, _, v_channels_per_head = value.shape - def chunk_scanner(chunk_idx: int) -> AttnChunk: + def chunk_scanner(chunk_idx: int, mask) -> AttnChunk: key_chunk = dynamic_slice( key_t, (0, 0, chunk_idx), @@ -111,10 +115,13 @@ def chunk_scanner(chunk_idx: int) -> AttnChunk: (0, chunk_idx, 0), (batch_x_heads, kv_chunk_size, v_channels_per_head) ) - return summarize_chunk(query, key_chunk, value_chunk) + if mask is not None: + mask = mask[:,:,chunk_idx:chunk_idx + kv_chunk_size] + + return summarize_chunk(query, key_chunk, value_chunk, mask=mask) chunks: List[AttnChunk] = [ - chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) + chunk_scanner(chunk, mask) for chunk in torch.arange(0, k_tokens, kv_chunk_size) ] acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) chunk_values, chunk_weights, chunk_max = acc_chunk @@ -135,6 +142,7 @@ def _get_attention_scores_no_kv_chunking( value: Tensor, scale: float, upcast_attention: bool, + mask, ) -> Tensor: if upcast_attention: with torch.autocast(enabled=False, device_type = 'cuda'): @@ -156,6 +164,8 @@ def _get_attention_scores_no_kv_chunking( beta=0, ) + if mask is not None: + attn_scores += mask try: attn_probs = attn_scores.softmax(dim=-1) del attn_scores @@ -183,6 +193,7 @@ def efficient_dot_product_attention( kv_chunk_size_min: Optional[int] = None, use_checkpoint=True, upcast_attention=False, + mask = None, ): """Computes efficient dot-product attention given query, transposed key, and value. This is efficient version of attention presented in @@ -209,13 +220,22 @@ def efficient_dot_product_attention( if kv_chunk_size_min is not None: kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) + if mask is not None and len(mask.shape) == 2: + mask = mask.unsqueeze(0) + def get_query_chunk(chunk_idx: int) -> Tensor: return dynamic_slice( query, (0, chunk_idx, 0), (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) ) - + + def get_mask_chunk(chunk_idx: int) -> Tensor: + if mask is None: + return None + chunk = min(query_chunk_size, q_tokens) + return mask[:,chunk_idx:chunk_idx + chunk] + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale, upcast_attention=upcast_attention) summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk compute_query_chunk_attn: ComputeQueryChunkAttn = partial( @@ -237,6 +257,7 @@ def get_query_chunk(chunk_idx: int) -> Tensor: query=query, key_t=key_t, value=value, + mask=mask, ) # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, @@ -246,6 +267,7 @@ def get_query_chunk(chunk_idx: int) -> Tensor: query=get_query_chunk(i * query_chunk_size), key_t=key_t, value=value, + mask=get_mask_chunk(i * query_chunk_size) ) for i in range(math.ceil(q_tokens / query_chunk_size)) ], dim=1) return res diff --git a/comfy/model_base.py b/comfy/model_base.py index b3a1fcd51f0..fefce76378c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1,21 +1,23 @@ import torch -from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel +from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep +from comfy.ldm.cascade.stage_c import StageC +from comfy.ldm.cascade.stage_b import StageB from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation -from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep +from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation import comfy.model_management import comfy.conds import comfy.ops from enum import Enum -import contextlib from . import utils class ModelType(Enum): EPS = 1 V_PREDICTION = 2 V_PREDICTION_EDM = 3 + STABLE_CASCADE = 4 -from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM +from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling def model_sampling(model_config, model_type): @@ -28,6 +30,9 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.V_PREDICTION_EDM: c = V_PREDICTION s = ModelSamplingContinuousEDM + elif model_type == ModelType.STABLE_CASCADE: + c = EPS + s = StableCascadeSampling class ModelSampling(s, c): pass @@ -36,7 +41,7 @@ class ModelSampling(s, c): class BaseModel(torch.nn.Module): - def __init__(self, model_config, model_type=ModelType.EPS, device=None): + def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel): super().__init__() unet_config = model_config.unet_config @@ -49,7 +54,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None): operations = comfy.ops.manual_cast else: operations = comfy.ops.disable_weight_init - self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations) + self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) @@ -78,8 +83,9 @@ def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, trans extra_conds = {} for o in kwargs: extra = kwargs[o] - if hasattr(extra, "to"): - extra = extra.to(dtype) + if hasattr(extra, "dtype"): + if extra.dtype != torch.int and extra.dtype != torch.long: + extra = extra.to(dtype) extra_conds[o] = extra model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() @@ -99,11 +105,29 @@ def extra_conds(self, **kwargs): if self.inpaint_model: concat_keys = ("mask", "masked_image") cond_concat = [] - denoise_mask = kwargs.get("denoise_mask", None) - latent_image = kwargs.get("latent_image", None) + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + concat_latent_image = kwargs.get("concat_latent_image", None) + if concat_latent_image is None: + concat_latent_image = kwargs.get("latent_image", None) + else: + concat_latent_image = self.process_latent_in(concat_latent_image) + noise = kwargs.get("noise", None) device = kwargs["device"] + if concat_latent_image.shape[1:] != noise.shape[1:]: + concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + + concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0]) + + if len(denoise_mask.shape) == len(noise.shape): + denoise_mask = denoise_mask[:,:1] + + denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1])) + if denoise_mask.shape[-2:] != noise.shape[-2:]: + denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center") + denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0]) + def blank_inpaint_image_like(latent_image): blank_image = torch.ones_like(latent_image) # these are the values for "zero" in pixel space translated to latent space @@ -116,9 +140,9 @@ def blank_inpaint_image_like(latent_image): for ck in concat_keys: if denoise_mask is not None: if ck == "mask": - cond_concat.append(denoise_mask[:,:1].to(device)) + cond_concat.append(denoise_mask.to(device)) elif ck == "masked_image": - cond_concat.append(latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space + cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space else: if ck == "mask": cond_concat.append(torch.ones_like(noise)[:,:1]) @@ -135,6 +159,10 @@ def blank_inpaint_image_like(latent_image): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) + cross_attn_cnet = kwargs.get("cross_attn_controlnet", None) + if cross_attn_cnet is not None: + out['crossattn_controlnet'] = comfy.conds.CONDCrossAttn(cross_attn_cnet) + return out def load_model_weights(self, sd, unet_prefix=""): @@ -160,19 +188,28 @@ def process_latent_in(self, latent): def process_latent_out(self, latent): return self.latent_format.process_out(latent) - def state_dict_for_saving(self, clip_state_dict, vae_state_dict): - clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict) + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + extra_sds = [] + if clip_state_dict is not None: + extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict)) + if vae_state_dict is not None: + extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict)) + if clip_vision_state_dict is not None: + extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) + unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) - vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict) + if self.get_dtype() == torch.float16: - clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16) - vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16) + extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds) if self.model_type == ModelType.V_PREDICTION: unet_state_dict["v_pred"] = torch.tensor([]) - return {**unet_state_dict, **vae_state_dict, **clip_state_dict} + for sd in extra_sds: + unet_state_dict.update(sd) + + return unet_state_dict def set_inpaint(self): self.inpaint_model = True @@ -191,7 +228,7 @@ def memory_required(self, input_shape): return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) -def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0): +def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None): adm_inputs = [] weights = [] noise_aug = [] @@ -200,7 +237,7 @@ def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge weight = unclip_cond["strength"] noise_augment = unclip_cond["noise_augmentation"] noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) - c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) + c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device), seed=seed) adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight weights.append(weight) noise_aug.append(noise_augment) @@ -226,11 +263,11 @@ def encode_adm(self, **kwargs): if unclip_conditioning is None: return torch.zeros((1, self.adm_channels)) else: - return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05)) + return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10) def sdxl_pooled(args, noise_augmentor): if "unclip_conditioning" in args: - return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor)[:,:1280] + return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280] else: return args["pooled_output"] @@ -364,3 +401,88 @@ def extra_conds(self, **kwargs): cross_attn = self.cc_projection(cross_attn) out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn) return out + +class SD_X4Upscaler(BaseModel): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None): + super().__init__(model_config, model_type, device=device) + self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350) + + def extra_conds(self, **kwargs): + out = {} + + image = kwargs.get("concat_image", None) + noise = kwargs.get("noise", None) + noise_augment = kwargs.get("noise_augmentation", 0.0) + device = kwargs["device"] + seed = kwargs["seed"] - 10 + + noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment) + + if image is None: + image = torch.zeros_like(noise)[:,:3] + + if image.shape[1:] != noise.shape[1:]: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + + noise_level = torch.tensor([noise_level], device=device) + if noise_augment > 0: + image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed) + + image = utils.resize_to_batch_size(image, noise.shape[0]) + + out['c_concat'] = comfy.conds.CONDNoiseShape(image) + out['y'] = comfy.conds.CONDRegular(noise_level) + return out + +class StableCascade_C(BaseModel): + def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): + super().__init__(model_config, model_type, device=device, unet_model=StageC) + self.diffusion_model.eval().requires_grad_(False) + + def extra_conds(self, **kwargs): + out = {} + clip_text_pooled = kwargs["pooled_output"] + if clip_text_pooled is not None: + out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + + if "unclip_conditioning" in kwargs: + embeds = [] + for unclip_cond in kwargs["unclip_conditioning"]: + weight = unclip_cond["strength"] + embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight) + clip_img = torch.cat(embeds, dim=1) + else: + clip_img = torch.zeros((1, 1, 768)) + out["clip_img"] = comfy.conds.CONDRegular(clip_img) + out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) + out["crp"] = comfy.conds.CONDRegular(torch.zeros((1,))) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn) + return out + + +class StableCascade_B(BaseModel): + def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None): + super().__init__(model_config, model_type, device=device, unet_model=StageB) + self.diffusion_model.eval().requires_grad_(False) + + def extra_conds(self, **kwargs): + out = {} + noise = kwargs.get("noise", None) + + clip_text_pooled = kwargs["pooled_output"] + if clip_text_pooled is not None: + out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + + #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched + prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device)) + + out["effnet"] = comfy.conds.CONDRegular(prior) + out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['clip'] = comfy.conds.CONDCrossAttn(cross_attn) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index e3af422a310..8fca6d8c8e4 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -28,13 +28,41 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict): return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack return None -def detect_unet_config(state_dict, key_prefix, dtype): +def detect_unet_config(state_dict, key_prefix): state_dict_keys = list(state_dict.keys()) + if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade + unet_config = {} + text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix) + if text_mapper_name in state_dict_keys: + unet_config['stable_cascade_stage'] = 'c' + w = state_dict[text_mapper_name] + if w.shape[0] == 1536: #stage c lite + unet_config['c_cond'] = 1536 + unet_config['c_hidden'] = [1536, 1536] + unet_config['nhead'] = [24, 24] + unet_config['blocks'] = [[4, 12], [12, 4]] + elif w.shape[0] == 2048: #stage c full + unet_config['c_cond'] = 2048 + elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys: + unet_config['stable_cascade_stage'] = 'b' + w = state_dict['{}down_blocks.1.0.channelwise.0.weight'.format(key_prefix)] + if w.shape[-1] == 640: + unet_config['c_hidden'] = [320, 640, 1280, 1280] + unet_config['nhead'] = [-1, -1, 20, 20] + unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]] + unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]] + elif w.shape[-1] == 576: #stage b lite + unet_config['c_hidden'] = [320, 576, 1152, 1152] + unet_config['nhead'] = [-1, 9, 18, 18] + unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]] + unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] + + return unet_config + unet_config = { "use_checkpoint": False, "image_size": 32, - "out_channels": 4, "use_spatial_transformer": True, "legacy": False } @@ -46,10 +74,15 @@ def detect_unet_config(state_dict, key_prefix, dtype): else: unet_config["adm_in_channels"] = None - unet_config["dtype"] = dtype model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0] in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1] + out_key = '{}out.2.weight'.format(key_prefix) + if out_key in state_dict: + out_channels = state_dict[out_key].shape[0] + else: + out_channels = 4 + num_res_blocks = [] channel_mult = [] attention_resolutions = [] @@ -122,6 +155,7 @@ def detect_unet_config(state_dict, key_prefix, dtype): transformer_depth_middle = -1 unet_config["in_channels"] = in_channels + unet_config["out_channels"] = out_channels unet_config["model_channels"] = model_channels unet_config["num_res_blocks"] = num_res_blocks unet_config["transformer_depth"] = transformer_depth @@ -153,8 +187,8 @@ def model_config_from_unet_config(unet_config): print("no match", unet_config) return None -def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False): - unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype) +def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False): + unet_config = detect_unet_config(state_dict, unet_key_prefix) model_config = model_config_from_unet_config(unet_config) if model_config is None and use_base_if_no_match: return comfy.supported_models_base.BASE(unet_config) @@ -200,7 +234,7 @@ def convert_config(unet_config): return new_config -def unet_config_from_diffusers_unet(state_dict, dtype): +def unet_config_from_diffusers_unet(state_dict, dtype=None): match = {} transformer_depth = [] @@ -307,8 +341,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype): return convert_config(unet_config) return None -def model_config_from_diffusers_unet(state_dict, dtype): - unet_config = unet_config_from_diffusers_unet(state_dict, dtype) +def model_config_from_diffusers_unet(state_dict): + unet_config = unet_config_from_diffusers_unet(state_dict) if unet_config is not None: return model_config_from_unet_config(unet_config) return None diff --git a/comfy/model_management.py b/comfy/model_management.py index 3adc42702c8..681208ea091 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -175,7 +175,7 @@ def is_nvidia(): if int(torch_version[0]) >= 2: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True - if torch.cuda.is_bf16_supported(): + if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8: VAE_DTYPE = torch.bfloat16 if is_intel_xpu(): if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: @@ -186,6 +186,9 @@ def is_nvidia(): if is_intel_xpu(): VAE_DTYPE = torch.bfloat16 +if args.cpu_vae: + VAE_DTYPE = torch.float32 + if args.fp16_vae: VAE_DTYPE = torch.float16 elif args.bf16_vae: @@ -259,6 +262,14 @@ def get_torch_device_name(device): current_loaded_models = [] +def module_size(module): + module_mem = 0 + sd = module.state_dict() + for k in sd: + t = sd[k] + module_mem += t.nelement() * t.element_size() + return module_mem + class LoadedModel: def __init__(self, model): self.model = model @@ -296,14 +307,14 @@ def model_load(self, lowvram_model_memory=0): if hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True - module_mem = 0 - sd = m.state_dict() - for k in sd: - t = sd[k] - module_mem += t.nelement() * t.element_size() + module_mem = module_size(m) if mem_counter + module_mem < lowvram_model_memory: m.to(self.device) mem_counter += module_mem + elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode + m.to(self.device) + mem_counter += module_size(m) + print("lowvram: loaded module regularly", m) self.model_accelerated = True @@ -476,7 +487,7 @@ def unet_inital_load_device(parameters, dtype): else: return cpu_dev -def unet_dtype(device=None, model_params=0): +def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): if args.bf16_unet: return torch.bfloat16 if args.fp16_unet: @@ -485,21 +496,32 @@ def unet_dtype(device=None, model_params=0): return torch.float8_e4m3fn if args.fp8_e5m2_unet: return torch.float8_e5m2 - if should_use_fp16(device=device, model_params=model_params): - return torch.float16 + if should_use_fp16(device=device, model_params=model_params, manual_cast=True): + if torch.float16 in supported_dtypes: + return torch.float16 + if should_use_bf16(device, model_params=model_params, manual_cast=True): + if torch.bfloat16 in supported_dtypes: + return torch.bfloat16 return torch.float32 # None means no manual cast -def unet_manual_cast(weight_dtype, inference_device): +def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): if weight_dtype == torch.float32: return None - fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False) + fp16_supported = should_use_fp16(inference_device, prioritize_performance=False) if fp16_supported and weight_dtype == torch.float16: return None - if fp16_supported: + bf16_supported = should_use_bf16(inference_device) + if bf16_supported and weight_dtype == torch.bfloat16: + return None + + if fp16_supported and torch.float16 in supported_dtypes: return torch.float16 + + elif bf16_supported and torch.bfloat16 in supported_dtypes: + return torch.bfloat16 else: return torch.float32 @@ -535,10 +557,8 @@ def text_encoder_dtype(device=None): if is_device_cpu(device): return torch.float16 - if should_use_fp16(device, prioritize_performance=False): - return torch.float16 - else: - return torch.float32 + return torch.float16 + def intermediate_device(): if args.gpu_only: @@ -547,6 +567,8 @@ def intermediate_device(): return torch.device("cpu") def vae_device(): + if args.cpu_vae: + return torch.device("cpu") return get_torch_device() def vae_offload_device(): @@ -673,19 +695,22 @@ def mps_mode(): global cpu_state return cpu_state == CPUState.MPS -def is_device_cpu(device): +def is_device_type(device, type): if hasattr(device, 'type'): - if (device.type == 'cpu'): + if (device.type == type): return True return False +def is_device_cpu(device): + return is_device_type(device, 'cpu') + def is_device_mps(device): - if hasattr(device, 'type'): - if (device.type == 'mps'): - return True - return False + return is_device_type(device, 'mps') -def should_use_fp16(device=None, model_params=0, prioritize_performance=True): +def is_device_cuda(device): + return is_device_type(device, 'cuda') + +def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): global directml_enabled if device is not None: @@ -711,10 +736,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): if is_intel_xpu(): return True - if torch.cuda.is_bf16_supported(): + if torch.version.hip: return True props = torch.cuda.get_device_properties("cuda") + if props.major >= 8: + return True + if props.major < 6: return False @@ -727,7 +755,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): if x in props.name.lower(): fp16_works = True - if fp16_works: + if fp16_works or manual_cast: free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) if (not prioritize_performance) or model_params * 4 > free_model_memory: return True @@ -743,6 +771,43 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True): return True +def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False): + if device is not None: + if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow + return False + + if device is not None: #TODO not sure about mps bf16 support + if is_device_mps(device): + return False + + if FORCE_FP32: + return False + + if directml_enabled: + return False + + if cpu_mode() or mps_mode(): + return False + + if is_intel_xpu(): + return True + + if device is None: + device = torch.device("cuda") + + props = torch.cuda.get_device_properties(device) + if props.major >= 8: + return True + + bf16_works = torch.cuda.is_bf16_supported() + + if bf16_works or manual_cast: + free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory()) + if (not prioritize_performance) or model_params * 4 > free_model_memory: + return True + + return False + def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 6acb2d647c0..a88b737cca3 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -28,13 +28,9 @@ def model_size(self): if self.size > 0: return self.size model_sd = self.model.state_dict() - size = 0 - for k in model_sd: - t = model_sd[k] - size += t.nelement() * t.element_size() - self.size = size + self.size = comfy.model_management.module_size(self.model) self.model_keys = set(model_sd.keys()) - return size + return self.size def clone(self): n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) @@ -178,40 +174,41 @@ def model_state_dict(self, filter_prefix=None): sd.pop(k) return sd - def patch_model(self, device_to=None): + def patch_model(self, device_to=None, patch_weights=True): for k in self.object_patches: old = getattr(self.model, k) if k not in self.object_patches_backup: self.object_patches_backup[k] = old setattr(self.model, k, self.object_patches[k]) - model_sd = self.model_state_dict() - for key in self.patches: - if key not in model_sd: - print("could not patch. key doesn't exist in model:", key) - continue + if patch_weights: + model_sd = self.model_state_dict() + for key in self.patches: + if key not in model_sd: + print("could not patch. key doesn't exist in model:", key) + continue - weight = model_sd[key] + weight = model_sd[key] - inplace_update = self.weight_inplace_update + inplace_update = self.weight_inplace_update - if key not in self.backup: - self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) + if key not in self.backup: + self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update) - if device_to is not None: - temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) - else: - temp_weight = weight.to(torch.float32, copy=True) - out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) - if inplace_update: - comfy.utils.copy_to_param(self.model, key, out_weight) - else: - comfy.utils.set_attr(self.model, key, out_weight) - del temp_weight + if device_to is not None: + temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True) + else: + temp_weight = weight.to(torch.float32, copy=True) + out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype) + if inplace_update: + comfy.utils.copy_to_param(self.model, key, out_weight) + else: + comfy.utils.set_attr(self.model, key, out_weight) + del temp_weight - if device_to is not None: - self.model.to(device_to) - self.current_device = device_to + if device_to is not None: + self.model.to(device_to) + self.current_device = device_to return self.model diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index cc8745c1064..ae42d81f200 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -1,5 +1,4 @@ import torch -import numpy as np from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule import math @@ -42,8 +41,7 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps else: betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) alphas = 1. - betas - alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32) - # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + alphas_cumprod = torch.cumprod(alphas, dim=0) timesteps, = betas.shape self.num_timesteps = int(timesteps) @@ -58,8 +56,8 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps self.set_sigmas(sigmas) def set_sigmas(self, sigmas): - self.register_buffer('sigmas', sigmas) - self.register_buffer('log_sigmas', sigmas.log()) + self.register_buffer('sigmas', sigmas.float()) + self.register_buffer('log_sigmas', sigmas.log().float()) @property def sigma_min(self): @@ -134,3 +132,56 @@ def percent_to_sigma(self, percent): log_sigma_min = math.log(self.sigma_min) return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min) + +class StableCascadeSampling(ModelSamplingDiscrete): + def __init__(self, model_config=None): + super().__init__() + + if model_config is not None: + sampling_settings = model_config.sampling_settings + else: + sampling_settings = {} + + self.set_parameters(sampling_settings.get("shift", 1.0)) + + def set_parameters(self, shift=1.0, cosine_s=8e-3): + self.shift = shift + self.cosine_s = torch.tensor(cosine_s) + self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 + + #This part is just for compatibility with some schedulers in the codebase + self.num_timesteps = 1000 + sigmas = torch.empty((self.num_timesteps), dtype=torch.float32) + for x in range(self.num_timesteps): + t = x / self.num_timesteps + sigmas[x] = self.sigma(t) + + self.set_sigmas(sigmas) + + def sigma(self, timestep): + alpha_cumprod = (torch.cos((timestep + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 / self._init_alpha_cumprod) + + if self.shift != 1.0: + var = alpha_cumprod + logSNR = (var/(1-var)).log() + logSNR += 2 * torch.log(1.0 / torch.tensor(self.shift)) + alpha_cumprod = logSNR.sigmoid() + + alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999) + return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5 + + def timestep(self, sigma): + var = 1 / ((sigma * sigma) + 1) + var = var.clamp(0, 1.0) + s, min_var = self.cosine_s.to(var.device), self._init_alpha_cumprod.to(var.device) + t = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s + return t + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 999999999.9 + if percent >= 1.0: + return 0.0 + + percent = 1.0 - percent + return self.sigma(torch.tensor(percent)) diff --git a/comfy/ops.py b/comfy/ops.py index f6f85de60a1..517688e8b92 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1,5 +1,22 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + import torch -from contextlib import contextmanager import comfy.model_management def cast_bias_weight(s, input): @@ -79,7 +96,11 @@ def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) + if self.weight is not None: + weight, bias = cast_bias_weight(self, input) + else: + weight = None + bias = None return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) def forward(self, *args, **kwargs): @@ -88,6 +109,28 @@ def forward(self, *args, **kwargs): else: return super().forward(*args, **kwargs) + class ConvTranspose2d(torch.nn.ConvTranspose2d): + comfy_cast_weights = False + def reset_parameters(self): + return None + + def forward_comfy_cast_weights(self, input, output_size=None): + num_spatial_dims = 2 + output_padding = self._output_padding( + input, output_size, self.stride, self.padding, self.kernel_size, + num_spatial_dims, self.dilation) + + weight, bias = cast_bias_weight(self, input) + return torch.nn.functional.conv_transpose2d( + input, weight, bias, self.stride, self.padding, + output_padding, self.groups, self.dilation) + + def forward(self, *args, **kwargs): + if self.comfy_cast_weights: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) + @classmethod def conv_nd(s, dims, *args, **kwargs): if dims == 2: @@ -113,3 +156,6 @@ class GroupNorm(disable_weight_init.GroupNorm): class LayerNorm(disable_weight_init.LayerNorm): comfy_cast_weights = True + + class ConvTranspose2d(disable_weight_init.ConvTranspose2d): + comfy_cast_weights = True diff --git a/comfy/sample.py b/comfy/sample.py index 4b0d15c49d1..5c8a7d13039 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -28,7 +28,6 @@ def prepare_noise(latent_image, seed, noise_inds=None): def prepare_mask(noise_mask, shape, device): """ensures noise mask is of proper dimensions""" noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear") - noise_mask = noise_mask.round() noise_mask = torch.cat([noise_mask] * shape[1], dim=1) noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0]) noise_mask = noise_mask.to(device) diff --git a/comfy/samplers.py b/comfy/samplers.py index 0453c1f6fda..c795f208d80 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,13 +1,9 @@ from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc import torch -import enum import collections from comfy import model_management import math -from comfy import model_base -import comfy.utils -import comfy.conds def get_area_and_mult(conds, x_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) @@ -299,7 +295,7 @@ def simple_scheduler(model, steps): def ddim_scheduler(model, steps): s = model.model_sampling sigs = [] - ss = len(s.sigmas) // steps + ss = max(len(s.sigmas) // steps, 1) x = 1 while x < len(s.sigmas): sigs += [float(s.sigmas[x])] @@ -603,8 +599,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model latent_image = model.process_latent_in(latent_image) if hasattr(model, 'extra_conds'): - positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask) - negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask) + positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) #make sure each cond area has an opposite one with the same area for c in positive: @@ -639,7 +635,7 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps): elif scheduler_name == "sgm_uniform": sigmas = normal_scheduler(model, steps, sgm=True) else: - print("error invalid scheduler", self.scheduler) + print("error invalid scheduler", scheduler_name) return sigmas def sampler_object(name): @@ -656,6 +652,7 @@ def sampler_object(name): class KSampler: SCHEDULERS = SCHEDULER_NAMES SAMPLERS = SAMPLER_NAMES + DISCARD_PENULTIMATE_SIGMA_SAMPLERS = set(('dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2')) def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): self.model = model @@ -674,7 +671,7 @@ def calculate_sigmas(self, steps): sigmas = None discard_penultimate_sigma = False - if self.sampler in ['dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2']: + if self.sampler in self.DISCARD_PENULTIMATE_SIGMA_SAMPLERS: steps += 1 discard_penultimate_sigma = True diff --git a/comfy/sd.py b/comfy/sd.py index 220637a05d7..00633e10768 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,10 +1,11 @@ import torch -import contextlib -import math +from enum import Enum from comfy import model_management -from .ldm.util import instantiate_from_config from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine +from .ldm.cascade.stage_a import StageA +from .ldm.cascade.stage_c_coder import StageC_coder + import yaml import comfy.utils @@ -157,6 +158,11 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) + self.downscale_ratio = 8 + self.upscale_ratio = 8 + self.latent_channels = 4 + self.process_input = lambda image: image * 2.0 - 1.0 + self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) if config is None: if "decoder.mid.block_1.mix_factor" in sd: @@ -169,9 +175,43 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) elif "taesd_decoder.1.weight" in sd: self.first_stage_model = comfy.taesd.taesd.TAESD() + elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade + self.first_stage_model = StageA() + self.downscale_ratio = 4 + self.upscale_ratio = 4 + #TODO + #self.memory_used_encode + #self.memory_used_decode + self.process_input = lambda image: image + self.process_output = lambda image: image + elif "backbone.1.0.block.0.1.num_batches_tracked" in sd: #effnet: encoder for stage c latent of stable cascade + self.first_stage_model = StageC_coder() + self.downscale_ratio = 32 + self.latent_channels = 16 + new_sd = {} + for k in sd: + new_sd["encoder.{}".format(k)] = sd[k] + sd = new_sd + elif "blocks.11.num_batches_tracked" in sd: #previewer: decoder for stage c latent of stable cascade + self.first_stage_model = StageC_coder() + self.latent_channels = 16 + new_sd = {} + for k in sd: + new_sd["previewer.{}".format(k)] = sd[k] + sd = new_sd + elif "encoder.backbone.1.0.block.0.1.num_batches_tracked" in sd: #combined effnet and previewer for stable cascade + self.first_stage_model = StageC_coder() + self.downscale_ratio = 32 + self.latent_channels = 16 else: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} + + if 'encoder.down.2.downsample.conv.weight' not in sd: #Stable diffusion x4 upscaler VAE + ddconfig['ch_mult'] = [1, 2, 4] + self.downscale_ratio = 4 + self.upscale_ratio = 4 + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) else: self.first_stage_model = AutoencoderKL(**(config['params'])) @@ -196,18 +236,27 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + def vae_encode_crop_pixels(self, pixels): + x = (pixels.shape[1] // self.downscale_ratio) * self.downscale_ratio + y = (pixels.shape[2] // self.downscale_ratio) * self.downscale_ratio + if pixels.shape[1] != x or pixels.shape[2] != y: + x_offset = (pixels.shape[1] % self.downscale_ratio) // 2 + y_offset = (pixels.shape[2] % self.downscale_ratio) // 2 + pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] + return pixels + def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float() - output = torch.clamp(( - (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) + - comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) + - comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar)) - / 3.0) / 2.0, min=0.0, max=1.0) + decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() + output = self.process_output( + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + + comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar)) + / 3.0) return output def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): @@ -216,10 +265,10 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float() - samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) - samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) - samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar) + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() + samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) + samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) + samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples /= 3.0 return samples @@ -231,10 +280,10 @@ def decode(self, samples_in): batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device=self.output_device) + pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.upscale_ratio), round(samples_in.shape[3] * self.upscale_ratio)), device=self.output_device) for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0) + pixel_samples[x:x+batch_number] = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float()) except model_management.OOM_EXCEPTION as e: print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) @@ -248,6 +297,7 @@ def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): return output.movedim(1,-1) def encode(self, pixel_samples): + pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = pixel_samples.movedim(-1,1) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) @@ -255,9 +305,9 @@ def encode(self, pixel_samples): free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) - samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device=self.output_device) + samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device) for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device) + pixels_in = self.process_input(pixel_samples[x:x+batch_number]).to(self.vae_dtype).to(self.device) samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float() except model_management.OOM_EXCEPTION as e: @@ -267,6 +317,7 @@ def encode(self, pixel_samples): return samples def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): + pixel_samples = self.vae_encode_crop_pixels(pixel_samples) model_management.load_model_gpu(self.patcher) pixel_samples = pixel_samples.movedim(-1,1) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) @@ -293,8 +344,11 @@ def load_style_model(ckpt_path): model.load_state_dict(model_data) return StyleModel(model) +class CLIPType(Enum): + STABLE_DIFFUSION = 1 + STABLE_CASCADE = 2 -def load_clip(ckpt_paths, embedding_directory=None): +def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION): clip_data = [] for p in ckpt_paths: clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) @@ -310,8 +364,12 @@ class EmptyClass: clip_target.params = {} if len(clip_data) == 1: if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]: - clip_target.clip = sdxl_clip.SDXLRefinerClipModel - clip_target.tokenizer = sdxl_clip.SDXLTokenizer + if clip_type == CLIPType.STABLE_CASCADE: + clip_target.clip = sdxl_clip.StableCascadeClipModel + clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer + else: + clip_target.clip = sdxl_clip.SDXLRefinerClipModel + clip_target.tokenizer = sdxl_clip.SDXLTokenizer elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: clip_target.clip = sd2_clip.SD2ClipModel clip_target.tokenizer = sd2_clip.SD2Tokenizer @@ -434,15 +492,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o clip_target = None parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.") - unet_dtype = model_management.unet_dtype(model_params=parameters) load_device = model_management.get_torch_device() - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) class WeightsLoader(torch.nn.Module): pass - model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype) - model_config.set_manual_cast(manual_cast_dtype) + model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.") + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if model_config is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) @@ -458,7 +516,7 @@ class WeightsLoader(torch.nn.Module): model.load_model_weights(sd, "model.diffusion_model.") if output_vae: - vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) + vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) vae_sd = model_config.process_vae_state_dict(vae_sd) vae = VAE(sd=vae_sd) @@ -466,10 +524,13 @@ class WeightsLoader(torch.nn.Module): w = WeightsLoader() clip_target = model_config.clip_target() if clip_target is not None: - clip = CLIP(clip_target, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model sd = model_config.process_clip_state_dict(sd) - load_model_weights(w, sd) + if any(k.startswith('cond_stage_model.') for k in sd): + clip = CLIP(clip_target, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model + load_model_weights(w, sd) + else: + print("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") left_over = sd.keys() if len(left_over) > 0: @@ -488,16 +549,15 @@ def load_unet_state_dict(sd): #load unet in diffusers format parameters = comfy.utils.calculate_parameters(sd) unet_dtype = model_management.unet_dtype(model_params=parameters) load_device = model_management.get_torch_device() - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) - if "input_blocks.0.0.weight" in sd: #ldm - model_config = model_detection.model_config_from_unet(sd, "", unet_dtype) + if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade + model_config = model_detection.model_config_from_unet(sd, "") if model_config is None: return None new_sd = sd else: #diffusers - model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype) + model_config = model_detection.model_config_from_diffusers_unet(sd) if model_config is None: return None @@ -509,8 +569,11 @@ def load_unet_state_dict(sd): #load unet in diffusers format new_sd[diffusers_keys[k]] = sd.pop(k) else: print(diffusers_keys[k], k) + offload_device = model_management.unet_offload_device() - model_config.set_manual_cast(manual_cast_dtype) + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes) + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model = model_config.get_model(new_sd, "") model = model.to(offload_device) model.load_model_weights(new_sd, "") @@ -527,7 +590,14 @@ def load_unet(unet_path): raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) return model -def save_checkpoint(output_path, model, clip, vae, metadata=None): - model_management.load_models_gpu([model, clip.load_model()]) - sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) +def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None): + clip_sd = None + load_models = [model] + if clip is not None: + load_models.append(clip.load_model()) + clip_sd = clip.get_sd() + + model_management.load_models_gpu(load_models) + clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None + sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd) comfy.utils.save_torch_file(sd, output_path, metadata=metadata) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 6ffef515ede..8287ad2e8b8 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -6,7 +6,6 @@ import traceback import zipfile from . import model_management -import contextlib import comfy.clip_model import json @@ -68,7 +67,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): ] def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel, - special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32 + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False): # clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS @@ -89,7 +88,7 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_le self.special_tokens = special_tokens self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) - self.enable_attention_masks = False + self.enable_attention_masks = enable_attention_masks self.layer_norm_hidden_state = layer_norm_hidden_state if layer == "hidden": diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py index b35056bb9d6..3ce5c7e05e6 100644 --- a/comfy/sdxl_clip.py +++ b/comfy/sdxl_clip.py @@ -64,3 +64,25 @@ def load_sd(self, sd): class SDXLRefinerClipModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None): super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG) + + +class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer): + def __init__(self, tokenizer_path=None, embedding_directory=None): + super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g') + +class StableCascadeTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None): + super().__init__(embedding_directory=embedding_directory, clip_name="g", tokenizer=StableCascadeClipGTokenizer) + +class StableCascadeClipG(sd1_clip.SDClipModel): + def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") + super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True) + + def load_sd(self, sd): + return super().load_sd(sd) + +class StableCascadeClipModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None): + super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 251bf6ace86..1a673646ee5 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -278,6 +278,88 @@ def get_model(self, state_dict, prefix="", device=None): def clip_target(self): return None +class SD_X4Upscaler(SD20): + unet_config = { + "context_dim": 1024, + "model_channels": 256, + 'in_channels': 7, + "use_linear_in_transformer": True, + "adm_in_channels": None, + "use_temporal_attention": False, + } + + unet_extra_config = { + "disable_self_attentions": [True, True, True, False], + "num_classes": 1000, + "num_heads": 8, + "num_head_channels": -1, + } + + latent_format = latent_formats.SD_X4 + + sampling_settings = { + "linear_start": 0.0001, + "linear_end": 0.02, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SD_X4Upscaler(self, device=device) + return out + +class Stable_Cascade_C(supported_models_base.BASE): + unet_config = { + "stable_cascade_stage": 'c', + } + + unet_extra_config = {} + + latent_format = latent_formats.SC_Prior + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + sampling_settings = { + "shift": 2.0, + } + + def process_unet_state_dict(self, state_dict): + key_list = list(state_dict.keys()) + for y in ["weight", "bias"]: + suffix = "in_proj_{}".format(y) + keys = filter(lambda a: a.endswith(suffix), key_list) + for k_from in keys: + weights = state_dict.pop(k_from) + prefix = k_from[:-(len(suffix) + 1)] + shape_from = weights.shape[0] // 3 + for x in range(3): + p = ["to_q", "to_k", "to_v"] + k_to = "{}.{}.{}".format(prefix, p[x], y) + state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)] + return state_dict + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.StableCascade_C(self, device=device) + return out + + def clip_target(self): + return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel) + +class Stable_Cascade_B(Stable_Cascade_C): + unet_config = { + "stable_cascade_stage": 'b', + } + + unet_extra_config = {} + + latent_format = latent_formats.SC_B + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + sampling_settings = { + "shift": 1.0, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.StableCascade_B(self, device=device) + return out + -models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega] +models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B] models += [SVD_img2vid] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 49087d23e5d..3bd4f9c6523 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -21,13 +21,15 @@ class BASE: noise_aug_config = None sampling_settings = {} latent_format = latent_formats.LatentFormat + vae_key_prefix = ["first_stage_model."] + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] manual_cast_dtype = None @classmethod def matches(s, unet_config): for k in s.unet_config: - if s.unet_config[k] != unet_config[k]: + if k not in unet_config or s.unet_config[k] != unet_config[k]: return False return True @@ -65,6 +67,12 @@ def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {"": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) + def process_clip_vision_state_dict_for_saving(self, state_dict): + replace_prefix = {} + if self.clip_vision_prefix is not None: + replace_prefix[""] = self.clip_vision_prefix + return utils.state_dict_prefix_replace(state_dict, replace_prefix) + def process_unet_state_dict_for_saving(self, state_dict): replace_prefix = {"": "model.diffusion_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) @@ -73,5 +81,6 @@ def process_vae_state_dict_for_saving(self, state_dict): replace_prefix = {"": "first_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) - def set_manual_cast(self, manual_cast_dtype): + def set_inference_dtype(self, dtype, manual_cast_dtype): + self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index 46f3097a2a1..8f96c54e56a 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -7,9 +7,10 @@ import torch.nn as nn import comfy.utils +import comfy.ops def conv(n_in, n_out, **kwargs): - return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + return comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 3, padding=1, **kwargs) class Clamp(nn.Module): def forward(self, x): @@ -19,7 +20,7 @@ class Block(nn.Module): def __init__(self, n_in, n_out): super().__init__() self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) - self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() self.fuse = nn.ReLU() def forward(self, x): return self.fuse(self.conv(x) + self.skip(x)) diff --git a/comfy/utils.py b/comfy/utils.py index f8026ddab9d..04cf76ed678 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -169,6 +169,8 @@ def transformers_convert(sd, prefix_from, prefix_to, number): } def unet_to_diffusers(unet_config): + if "num_res_blocks" not in unet_config: + return {} num_res_blocks = unet_config["num_res_blocks"] channel_mult = unet_config["channel_mult"] transformer_depth = unet_config["transformer_depth"][:] @@ -413,6 +415,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device) for y in range(0, s.shape[2], tile_y - overlap): for x in range(0, s.shape[3], tile_x - overlap): + x = max(0, min(s.shape[-1] - overlap, x)) + y = max(0, min(s.shape[-2] - overlap, y)) s_in = s[:,:,y:y+tile_y,x:x+tile_x] ps = function(s_in).to(output_device) diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py new file mode 100644 index 00000000000..646fefa1746 --- /dev/null +++ b/comfy_extras/nodes_cond.py @@ -0,0 +1,25 @@ + + +class CLIPTextEncodeControlnet: + @classmethod + def INPUT_TYPES(s): + return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True})}} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "_for_testing/conditioning" + + def encode(self, clip, conditioning, text): + tokens = clip.tokenize(text) + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['cross_attn_controlnet'] = cond + n[1]['pooled_output_controlnet'] = pooled + c.append(n) + return (c, ) + +NODE_CLASS_MAPPINGS = { + "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet +} diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 8791d8ae3c4..99f9ea7dcef 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -13,6 +13,7 @@ def INPUT_TYPES(s): {"model": ("MODEL",), "scheduler": (comfy.samplers.SCHEDULER_NAMES, ), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), } } RETURN_TYPES = ("SIGMAS",) @@ -20,8 +21,14 @@ def INPUT_TYPES(s): FUNCTION = "get_sigmas" - def get_sigmas(self, model, scheduler, steps): - sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, steps).cpu() + def get_sigmas(self, model, scheduler, steps, denoise): + total_steps = steps + if denoise < 1.0: + total_steps = int(steps/denoise) + + comfy.model_management.load_models_gpu([model]) + sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu() + sigmas = sigmas[-(steps + 1):] return (sigmas, ) @@ -98,6 +105,7 @@ def INPUT_TYPES(s): def get_sigmas(self, model, steps, denoise): start_step = 10 - int(10 * denoise) timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] + comfy.model_management.load_models_gpu([model]) sigmas = model.model.model_sampling.sigma(timesteps) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) return (sigmas, ) diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index 7512b841d74..7764aa0b013 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -34,7 +34,7 @@ def INPUT_TYPES(s): RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "_for_testing" + CATEGORY = "model_patches" def patch(self, model, b1, b2, s1, s2): model_channels = model.model.model_config.unet_config["model_channels"] @@ -73,7 +73,7 @@ def INPUT_TYPES(s): RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "_for_testing" + CATEGORY = "model_patches" def patch(self, model, b1, b2, s1, s2): model_channels = model.model.model_config.unet_config["model_channels"] diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py index e7446b2e540..ae55d23dd06 100644 --- a/comfy_extras/nodes_hypertile.py +++ b/comfy_extras/nodes_hypertile.py @@ -32,29 +32,29 @@ def INPUT_TYPES(s): RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "_for_testing" + CATEGORY = "model_patches" def patch(self, model, tile_size, swap_size, max_depth, scale_depth): model_channels = model.model.model_config.unet_config["model_channels"] - apply_to = set() - temp = model_channels - for x in range(max_depth + 1): - apply_to.add(temp) - temp *= 2 - latent_tile_size = max(32, tile_size) // 8 self.temp = None def hypertile_in(q, k, v, extra_options): - if q.shape[-1] in apply_to: + model_chans = q.shape[-2] + orig_shape = extra_options['original_shape'] + apply_to = [] + for i in range(max_depth + 1): + apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i))) + + if model_chans in apply_to: shape = extra_options["original_shape"] aspect_ratio = shape[-1] / shape[-2] hw = q.size(1) h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) - factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1 + factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1 nh = random_divisor(h, latent_tile_size * factor, swap_size) nw = random_divisor(w, latent_tile_size * factor, swap_size) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index aa80f5269a3..8f638bf8fc1 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -48,6 +48,25 @@ def repeat(self, image, amount): s = image.repeat((amount, 1,1,1)) return (s,) +class ImageFromBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}), + "length": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "frombatch" + + CATEGORY = "image/batch" + + def frombatch(self, image, batch_index, length): + s_in = image + batch_index = min(s_in.shape[0] - 1, batch_index) + length = min(s_in.shape[0] - batch_index, length) + s = s_in[batch_index:batch_index + length].clone() + return (s,) + class SaveAnimatedWEBP: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -170,6 +189,7 @@ def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", pr NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, "RepeatImageBatch": RepeatImageBatch, + "ImageFromBatch": ImageFromBatch, "SaveAnimatedWEBP": SaveAnimatedWEBP, "SaveAnimatedPNG": SaveAnimatedPNG, } diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 2eefc4c555d..eabae088516 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -122,10 +122,34 @@ def batch(self, samples1, samples2): samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) return (samples_out,) +class LatentBatchSeedBehavior: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "seed_behavior": (["random", "fixed"],{"default": "fixed"}),}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples, seed_behavior): + samples_out = samples.copy() + latent = samples["samples"] + if seed_behavior == "random": + if 'batch_index' in samples_out: + samples_out.pop('batch_index') + elif seed_behavior == "fixed": + batch_number = samples_out.get("batch_index", [0])[0] + samples_out["batch_index"] = [batch_number] * latent.shape[0] + + return (samples_out,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, "LatentInterpolate": LatentInterpolate, "LatentBatch": LatentBatch, + "LatentBatchSeedBehavior": LatentBatchSeedBehavior, } diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 541ce8fa5cc..ac7c1c17a16 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -99,6 +99,32 @@ class ModelSamplingAdvanced(sampling_base, sampling_type): m.add_object_patch("model_sampling", model_sampling) return (m, ) +class ModelSamplingStableCascade: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, shift): + m = model.clone() + + sampling_base = comfy.model_sampling.StableCascadeSampling + sampling_type = comfy.model_sampling.EPS + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(shift) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + class ModelSamplingContinuousEDM: @classmethod def INPUT_TYPES(s): @@ -171,5 +197,6 @@ def rescale_cfg(args): NODE_CLASS_MAPPINGS = { "ModelSamplingDiscrete": ModelSamplingDiscrete, "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, + "ModelSamplingStableCascade": ModelSamplingStableCascade, "RescaleCFG": RescaleCFG, } diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index dad1dd6378d..d594cf490b6 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -119,6 +119,48 @@ def merge(self, model1, model2, **kwargs): m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) return (m, ) +def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir) + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {} + + enable_modelspec = True + if isinstance(model.model, comfy.model_base.SDXL): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" + elif isinstance(model.model, comfy.model_base.SDXLRefiner): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" + else: + enable_modelspec = False + + if enable_modelspec: + metadata["modelspec.sai_model_spec"] = "1.0.0" + metadata["modelspec.implementation"] = "sgm" + metadata["modelspec.title"] = "{} {}".format(filename, counter) + + #TODO: + # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512", + # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", + # "v2-inpainting" + + if model.model.model_type == comfy.model_base.ModelType.EPS: + metadata["modelspec.predict_key"] = "epsilon" + elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: + metadata["modelspec.predict_key"] = "v" + + if not args.disable_metadata: + metadata["prompt"] = prompt_info + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata) + class CheckpointSave: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -137,46 +179,7 @@ def INPUT_TYPES(s): CATEGORY = "advanced/model_merging" def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None): - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - prompt_info = "" - if prompt is not None: - prompt_info = json.dumps(prompt) - - metadata = {} - - enable_modelspec = True - if isinstance(model.model, comfy.model_base.SDXL): - metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" - elif isinstance(model.model, comfy.model_base.SDXLRefiner): - metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" - else: - enable_modelspec = False - - if enable_modelspec: - metadata["modelspec.sai_model_spec"] = "1.0.0" - metadata["modelspec.implementation"] = "sgm" - metadata["modelspec.title"] = "{} {}".format(filename, counter) - - #TODO: - # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512", - # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", - # "v2-inpainting" - - if model.model.model_type == comfy.model_base.ModelType.EPS: - metadata["modelspec.predict_key"] = "epsilon" - elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: - metadata["modelspec.predict_key"] = "v" - - if not args.disable_metadata: - metadata["prompt"] = prompt_info - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) - - output_checkpoint = f"{filename}_{counter:05}_.safetensors" - output_checkpoint = os.path.join(full_output_folder, output_checkpoint) - - comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata) + save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) return {} class CLIPSave: diff --git a/comfy_extras/nodes_photomaker.py b/comfy_extras/nodes_photomaker.py new file mode 100644 index 00000000000..90130142b28 --- /dev/null +++ b/comfy_extras/nodes_photomaker.py @@ -0,0 +1,187 @@ +import torch +import torch.nn as nn +import folder_paths +import comfy.clip_model +import comfy.clip_vision +import comfy.ops + +# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0 +VISION_CONFIG_DICT = { + "hidden_size": 1024, + "image_size": 224, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "patch_size": 14, + "projection_dim": 768, + "hidden_act": "quick_gelu", +} + +class MLP(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=comfy.ops): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = operations.LayerNorm(in_dim) + self.fc1 = operations.Linear(in_dim, hidden_dim) + self.fc2 = operations.Linear(hidden_dim, out_dim) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + if self.use_residual: + x = x + residual + return x + + +class FuseModule(nn.Module): + def __init__(self, embed_dim, operations): + super().__init__() + self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations) + self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations) + self.layer_norm = operations.LayerNorm(embed_dim) + + def fuse_fn(self, prompt_embeds, id_embeds): + stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) + stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds + stacked_id_embeds = self.mlp2(stacked_id_embeds) + stacked_id_embeds = self.layer_norm(stacked_id_embeds) + return stacked_id_embeds + + def forward( + self, + prompt_embeds, + id_embeds, + class_tokens_mask, + ) -> torch.Tensor: + # id_embeds shape: [b, max_num_inputs, 1, 2048] + id_embeds = id_embeds.to(prompt_embeds.dtype) + num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case + batch_size, max_num_inputs = id_embeds.shape[:2] + # seq_length: 77 + seq_length = prompt_embeds.shape[1] + # flat_id_embeds shape: [b*max_num_inputs, 1, 2048] + flat_id_embeds = id_embeds.view( + -1, id_embeds.shape[-2], id_embeds.shape[-1] + ) + # valid_id_mask [b*max_num_inputs] + valid_id_mask = ( + torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :] + < num_inputs[:, None] + ) + valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()] + + prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) + class_tokens_mask = class_tokens_mask.view(-1) + valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) + # slice out the image token embeddings + image_token_embeds = prompt_embeds[class_tokens_mask] + stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) + assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" + prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) + updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) + return updated_prompt_embeds + +class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection): + def __init__(self): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + dtype = comfy.model_management.text_encoder_dtype(self.load_device) + + super().__init__(VISION_CONFIG_DICT, dtype, offload_device, comfy.ops.manual_cast) + self.visual_projection_2 = comfy.ops.manual_cast.Linear(1024, 1280, bias=False) + self.fuse_module = FuseModule(2048, comfy.ops.manual_cast) + + def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask): + b, num_inputs, c, h, w = id_pixel_values.shape + id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) + + shared_id_embeds = self.vision_model(id_pixel_values)[2] + id_embeds = self.visual_projection(shared_id_embeds) + id_embeds_2 = self.visual_projection_2(shared_id_embeds) + + id_embeds = id_embeds.view(b, num_inputs, 1, -1) + id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) + + id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) + updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) + + return updated_prompt_embeds + + +class PhotoMakerLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}} + + RETURN_TYPES = ("PHOTOMAKER",) + FUNCTION = "load_photomaker_model" + + CATEGORY = "_for_testing/photomaker" + + def load_photomaker_model(self, photomaker_model_name): + photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name) + photomaker_model = PhotoMakerIDEncoder() + data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) + if "id_encoder" in data: + data = data["id_encoder"] + photomaker_model.load_state_dict(data) + return (photomaker_model,) + + +class PhotoMakerEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { "photomaker": ("PHOTOMAKER",), + "image": ("IMAGE",), + "clip": ("CLIP", ), + "text": ("STRING", {"multiline": True, "default": "photograph of photomaker"}), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "apply_photomaker" + + CATEGORY = "_for_testing/photomaker" + + def apply_photomaker(self, photomaker, image, clip, text): + special_token = "photomaker" + pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float() + try: + index = text.split(" ").index(special_token) + 1 + except ValueError: + index = -1 + tokens = clip.tokenize(text, return_word_ids=True) + out_tokens = {} + for k in tokens: + out_tokens[k] = [] + for t in tokens[k]: + f = list(filter(lambda x: x[2] != index, t)) + while len(f) < len(t): + f.append(t[-1]) + out_tokens[k].append(f) + + cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True) + + if index > 0: + token_index = index - 1 + num_id_images = 1 + class_tokens_mask = [True if token_index <= i < token_index+num_id_images else False for i in range(77)] + out = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device), + class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0)) + else: + out = cond + + return ([[out, {"pooled_output": pooled}]], ) + + +NODE_CLASS_MAPPINGS = { + "PhotoMakerLoader": PhotoMakerLoader, + "PhotoMakerEncode": PhotoMakerEncode, +} + diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 71660f8a525..cb5c7d22817 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -33,6 +33,7 @@ def INPUT_TYPES(s): CATEGORY = "image/postprocessing" def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + image2 = image2.to(image1.device) if image1.shape != image2.shape: image2 = image2.permute(0, 3, 1, 2) image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center') diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 450ac3eeacd..bbd3808078d 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -58,7 +58,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): attn = attn.reshape(b, -1, hw1, hw2) # Global Average Pool mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold - ratio = math.ceil(math.sqrt(lh * lw / hw1)) + ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length() mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] # Reshape @@ -143,6 +143,8 @@ def post_cfg_function(args): sigma = args["sigma"] model_options = args["model_options"] x = args["input"] + if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding + return cfg_result # create the adversarially blurred image degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) diff --git a/comfy_extras/nodes_sdupscale.py b/comfy_extras/nodes_sdupscale.py new file mode 100644 index 00000000000..28c1cb0f171 --- /dev/null +++ b/comfy_extras/nodes_sdupscale.py @@ -0,0 +1,47 @@ +import torch +import nodes +import comfy.utils + +class SD_4XUpscale_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "images": ("IMAGE",), + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/upscale_diffusion" + + def encode(self, images, positive, negative, scale_ratio, noise_augmentation): + width = max(1, round(images.shape[-2] * scale_ratio)) + height = max(1, round(images.shape[-3] * scale_ratio)) + + pixels = comfy.utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center") + + out_cp = [] + out_cn = [] + + for t in positive: + n = [t[0], t[1].copy()] + n[1]['concat_image'] = pixels + n[1]['noise_augmentation'] = noise_augmentation + out_cp.append(n) + + for t in negative: + n = [t[0], t[1].copy()] + n[1]['concat_image'] = pixels + n[1]['noise_augmentation'] = noise_augmentation + out_cn.append(n) + + latent = torch.zeros([images.shape[0], 4, height // 4, width // 4]) + return (out_cp, out_cn, {"samples":latent}) + +NODE_CLASS_MAPPINGS = { + "SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning, +} diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py index c6791d8de2a..4375d8f960e 100644 --- a/comfy_extras/nodes_stable3d.py +++ b/comfy_extras/nodes_stable3d.py @@ -46,13 +46,57 @@ def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevat encode_pixels = pixels[:,:,:,:3] t = vae.encode(encode_pixels) cam_embeds = camera_embeddings(elevation, azimuth) - cond = torch.cat([pooled, cam_embeds.repeat((pooled.shape[0], 1, 1))], dim=-1) + cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1) positive = [[cond, {"concat_latent_image": t}]] negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] latent = torch.zeros([batch_size, 4, height // 8, width // 8]) return (positive, negative, {"samples":latent}) +class StableZero123_Conditioning_Batched: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + + cam_embeds = [] + for i in range(batch_size): + cam_embeds.append(camera_embeddings(elevation, azimuth)) + elevation += elevation_batch_increment + azimuth += azimuth_batch_increment + + cam_embeds = torch.cat(cam_embeds, dim=0) + cond = torch.cat([comfy.utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1) + + positive = [[cond, {"concat_latent_image": t}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] + latent = torch.zeros([batch_size, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) + + NODE_CLASS_MAPPINGS = { "StableZero123_Conditioning": StableZero123_Conditioning, + "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched, } diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py new file mode 100644 index 00000000000..5d31c1e59c7 --- /dev/null +++ b/comfy_extras/nodes_stable_cascade.py @@ -0,0 +1,74 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + +import torch +import nodes + + +class StableCascade_EmptyLatentImage: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), + "compression": ("INT", {"default": 42, "min": 32, "max": 64, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}) + }} + RETURN_TYPES = ("LATENT", "LATENT") + RETURN_NAMES = ("stage_c", "stage_b") + FUNCTION = "generate" + + CATEGORY = "_for_testing/stable_cascade" + + def generate(self, width, height, compression, batch_size=1): + c_latent = torch.zeros([batch_size, 16, height // compression, width // compression]) + b_latent = torch.zeros([batch_size, 4, height // 4, width // 4]) + return ({ + "samples": c_latent, + }, { + "samples": b_latent, + }) + +class StableCascade_StageB_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "conditioning": ("CONDITIONING",), + "stage_c": ("LATENT",), + }} + RETURN_TYPES = ("CONDITIONING",) + + FUNCTION = "set_prior" + + CATEGORY = "_for_testing/stable_cascade" + + def set_prior(self, conditioning, stage_c): + c = [] + for t in conditioning: + d = t[1].copy() + d['stable_cascade_prior'] = stage_c['samples'] + n = [t[0], d] + c.append(n) + return (c, ) + +NODE_CLASS_MAPPINGS = { + "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, + "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, +} diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index 26a717a3836..a5262565282 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -3,6 +3,7 @@ import comfy.utils import comfy.sd import folder_paths +import comfy_extras.nodes_model_merging class ImageOnlyCheckpointLoader: @@ -78,10 +79,26 @@ def linear_cfg(args): m.set_model_sampler_cfg_function(linear_cfg) return (m, ) +class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave): + CATEGORY = "_for_testing" + + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "clip_vision": ("CLIP_VISION",), + "vae": ("VAE",), + "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + + def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None): + comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) + return {} + NODE_CLASS_MAPPINGS = { "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "VideoLinearCFGGuidance": VideoLinearCFGGuidance, + "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave, } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index 733014f3c7d..7ce271ec617 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -6,6 +6,8 @@ class Example: ------------- INPUT_TYPES (dict): Tell the main program input parameters of nodes. + IS_CHANGED: + optional method to control when the node is re executed. Attributes ---------- @@ -89,6 +91,17 @@ class Example: image = 1.0 - image return (image,) + """ + The node will always be re executed if any of the inputs change but + this method can be used to force the node to execute again even when the inputs don't change. + You can make this node return a number or a string. This value will be compared to the one returned the last time the node was + executed, if it is different the node will be executed again. + This method is used in the core repo for the LoadImage node where they return the image hash as a string, if the image hash + changes between executions the LoadImage node is executed again. + """ + #@classmethod + #def IS_CHANGED(s, image, string_field, int_field, float_field, print_to_screen): + # return "" # A dictionary that contains all nodes you want to export with their names # NOTE: names should be globally unique diff --git a/custom_nodes/websocket_image_save.py.disabled b/custom_nodes/websocket_image_save.py.disabled new file mode 100644 index 00000000000..b85a5de8be0 --- /dev/null +++ b/custom_nodes/websocket_image_save.py.disabled @@ -0,0 +1,49 @@ +from PIL import Image, ImageOps +from io import BytesIO +import numpy as np +import struct +import comfy.utils +import time + +#You can use this node to save full size images through the websocket, the +#images will be sent in exactly the same format as the image previews: as +#binary images on the websocket with a 8 byte header indicating the type +#of binary message (first 4 bytes) and the image format (next 4 bytes). + +#The reason this node is disabled by default is because there is a small +#issue when using it with the default ComfyUI web interface: When generating +#batches only the last image will be shown in the UI. + +#Note that no metadata will be put in the images saved with this node. + +class SaveImageWebsocket: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ),} + } + + RETURN_TYPES = () + FUNCTION = "save_images" + + OUTPUT_NODE = True + + CATEGORY = "image" + + def save_images(self, images): + pbar = comfy.utils.ProgressBar(images.shape[0]) + step = 0 + for image in images: + i = 255. * image.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + pbar.update_absolute(step, images.shape[0], ("PNG", img, None)) + step += 1 + + return {} + + def IS_CHANGED(s, images): + return time.time() + +NODE_CLASS_MAPPINGS = { + "SaveImageWebsocket": SaveImageWebsocket, +} diff --git a/execution.py b/execution.py index 7ad171313b0..00908eadd46 100644 --- a/execution.py +++ b/execution.py @@ -1,12 +1,11 @@ -import os import sys import copy -import json import logging import threading import heapq import traceback -import gc +import inspect +from typing import List, Literal, NamedTuple, Optional import torch import nodes @@ -267,11 +266,21 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item class PromptExecutor: def __init__(self, server): + self.server = server + self.reset() + + def reset(self): self.outputs = {} self.object_storage = {} self.outputs_ui = {} + self.status_messages = [] + self.success = True self.old_prompt = {} - self.server = server + + def add_message(self, event, data, broadcast: bool): + self.status_messages.append((event, data)) + if self.server.client_id is not None or broadcast: + self.server.send_sync(event, data, self.server.client_id) def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): node_id = error["node_id"] @@ -286,23 +295,22 @@ def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, e "node_type": class_type, "executed": list(executed), } - self.server.send_sync("execution_interrupted", mes, self.server.client_id) + self.add_message("execution_interrupted", mes, broadcast=True) else: - if self.server.client_id is not None: - mes = { - "prompt_id": prompt_id, - "node_id": node_id, - "node_type": class_type, - "executed": list(executed), - - "exception_message": error["exception_message"], - "exception_type": error["exception_type"], - "traceback": error["traceback"], - "current_inputs": error["current_inputs"], - "current_outputs": error["current_outputs"], - } - self.server.send_sync("execution_error", mes, self.server.client_id) + mes = { + "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, + "executed": list(executed), + "exception_message": error["exception_message"], + "exception_type": error["exception_type"], + "traceback": error["traceback"], + "current_inputs": error["current_inputs"], + "current_outputs": error["current_outputs"], + } + self.add_message("execution_error", mes, broadcast=False) + # Next, remove the subsequent outputs since they will not be executed to_delete = [] for o in self.outputs: @@ -323,8 +331,8 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): else: self.server.client_id = None - if self.server.client_id is not None: - self.server.send_sync("execution_start", { "prompt_id": prompt_id}, self.server.client_id) + self.status_messages = [] + self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): #delete cached outputs if nodes don't exist for them @@ -357,8 +365,9 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): del d comfy.model_management.cleanup_models() - if self.server.client_id is not None: - self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) + self.add_message("execution_cached", + { "nodes": list(current_outputs) , "prompt_id": prompt_id}, + broadcast=False) executed = set() output_node_id = None to_execute = [] @@ -374,8 +383,8 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): # This call shouldn't raise anything if there's an error deep in # the actual SD code, instead it will report the node where the # error was raised - success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) - if success is not True: + self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage) + if self.success is not True: self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) break @@ -402,6 +411,10 @@ def validate_inputs(prompt, item, validated): errors = [] valid = True + validate_function_inputs = [] + if hasattr(obj_class, "VALIDATE_INPUTS"): + validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args + for x in required_inputs: if x not in inputs: error = { @@ -531,29 +544,7 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if hasattr(obj_class, "VALIDATE_INPUTS"): - input_data_all = get_input_data(inputs, obj_class, unique_id) - #ret = obj_class.VALIDATE_INPUTS(**input_data_all) - ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for i, r in enumerate(ret): - if r is not True: - details = f"{x}" - if r is not False: - details += f" - {str(r)}" - - error = { - "type": "custom_validation_failed", - "message": "Custom validation failed for node", - "details": details, - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - else: + if x not in validate_function_inputs: if isinstance(type_input, list): if val not in type_input: input_config = info @@ -580,6 +571,35 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue + if len(validate_function_inputs) > 0: + input_data_all = get_input_data(inputs, obj_class, unique_id) + input_filtered = {} + for x in input_data_all: + if x in validate_function_inputs: + input_filtered[x] = input_data_all[x] + + #ret = obj_class.VALIDATE_INPUTS(**input_filtered) + ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") + for x in input_filtered: + for i, r in enumerate(ret): + if r is not True: + details = f"{x}" + if r is not False: + details += f" - {str(r)}" + + error = { + "type": "custom_validation_failed", + "message": "Custom validation failed for node", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue + if len(errors) > 0 or valid is not True: ret = (False, errors, unique_id) else: @@ -694,6 +714,7 @@ def __init__(self, server): self.queue = [] self.currently_running = {} self.history = {} + self.flags = {} server.prompt_queue = self def put(self, item): @@ -715,14 +736,27 @@ def get(self, timeout=None): self.server.queue_updated() return (item, i) - def task_done(self, item_id, outputs): + class ExecutionStatus(NamedTuple): + status_str: Literal['success', 'error'] + completed: bool + messages: List[str] + + def task_done(self, item_id, outputs, + status: Optional['PromptQueue.ExecutionStatus']): with self.mutex: prompt = self.currently_running.pop(item_id) if len(self.history) > MAXIMUM_HISTORY_SIZE: self.history.pop(next(iter(self.history))) - self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } - for o in outputs: - self.history[prompt[1]]["outputs"][o] = outputs[o] + + status_dict: Optional[dict] = None + if status is not None: + status_dict = copy.deepcopy(status._asdict()) + + self.history[prompt[1]] = { + "prompt": prompt, + "outputs": copy.deepcopy(outputs), + 'status': status_dict, + } self.server.queue_updated() def get_current_queue(self): @@ -780,3 +814,17 @@ def wipe_history(self): def delete_history_item(self, id_to_delete): with self.mutex: self.history.pop(id_to_delete, None) + + def set_flag(self, name, data): + with self.mutex: + self.flags[name] = data + self.not_empty.notify() + + def get_flags(self, reset=True): + with self.mutex: + if reset: + ret = self.flags + self.flags = {} + return ret + else: + return self.flags.copy() diff --git a/folder_paths.py b/folder_paths.py index 98704945e56..f1bf40f8c04 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -29,11 +29,14 @@ folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) +folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions) + folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") +user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user") filename_list_cache = {} @@ -137,15 +140,27 @@ def recursive_search(directory, excluded_dir_names=None): excluded_dir_names = [] result = [] - dirs = {directory: os.path.getmtime(directory)} + dirs = {} + + # Attempt to add the initial directory to dirs with error handling + try: + dirs[directory] = os.path.getmtime(directory) + except FileNotFoundError: + print(f"Warning: Unable to access {directory}. Skipping this path.") + for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True): subdirs[:] = [d for d in subdirs if d not in excluded_dir_names] for file_name in filenames: relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory) result.append(relative_path) + for d in subdirs: path = os.path.join(dirpath, d) - dirs[path] = os.path.getmtime(path) + try: + dirs[path] = os.path.getmtime(path) + except FileNotFoundError: + print(f"Warning: Unable to access {path}. Skipping this path.") + continue return result, dirs def filter_files_extensions(files, extensions): @@ -184,8 +199,7 @@ def cached_filename_list_(folder_name): if folder_name not in filename_list_cache: return None out = filename_list_cache[folder_name] - if time.perf_counter() < (out[2] + 0.5): - return out + for x in out[1]: time_modified = out[1][x] folder = x diff --git a/main.py b/main.py index f6aeceed2af..69d9bce6cb7 100644 --- a/main.py +++ b/main.py @@ -97,7 +97,7 @@ def prompt_worker(q, server): gc_collect_interval = 10.0 while True: - timeout = None + timeout = 1000.0 if need_gc: timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0) @@ -106,9 +106,16 @@ def prompt_worker(q, server): item, item_id = queue_item execution_start_time = time.perf_counter() prompt_id = item[1] + server.last_prompt_id = prompt_id + e.execute(item[2], prompt_id, item[3], item[4]) need_gc = True - q.task_done(item_id, e.outputs_ui) + q.task_done(item_id, + e.outputs_ui, + status=execution.PromptQueue.ExecutionStatus( + status_str='success' if e.success else 'error', + completed=e.success, + messages=e.status_messages)) if server.client_id is not None: server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) @@ -116,6 +123,19 @@ def prompt_worker(q, server): execution_time = current_time - execution_start_time print("Prompt executed in {:.2f} seconds".format(execution_time)) + flags = q.get_flags() + free_memory = flags.get("free_memory", False) + + if flags.get("unload_models", free_memory): + comfy.model_management.unload_all_models() + need_gc = True + last_gc_collect = 0 + + if free_memory: + e.reset() + need_gc = True + last_gc_collect = 0 + if need_gc: current_time = time.perf_counter() if (current_time - last_gc_collect) > gc_collect_interval: @@ -131,7 +151,9 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): def hijack_progress(server): def hook(value, total, preview_image): comfy.model_management.throw_exception_if_processing_interrupted() - server.send_sync("progress", {"value": value, "max": total}, server.client_id) + progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id} + + server.send_sync("progress", progress, server.client_id) if preview_image is not None: server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id) comfy.utils.set_progress_bar_global_hook(hook) diff --git a/models/photomaker/put_photomaker_models_here b/models/photomaker/put_photomaker_models_here new file mode 100644 index 00000000000..e69de29bb2d diff --git a/nodes.py b/nodes.py index 027bf55d994..6666413d642 100644 --- a/nodes.py +++ b/nodes.py @@ -184,6 +184,26 @@ def append(self, conditioning, width, height, x, y, strength): c.append(n) return (c, ) +class ConditioningSetAreaStrength: + @classmethod + def INPUT_TYPES(s): + return {"required": {"conditioning": ("CONDITIONING", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "conditioning" + + def append(self, conditioning, strength): + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['strength'] = strength + c.append(n) + return (c, ) + + class ConditioningSetMask: @classmethod def INPUT_TYPES(s): @@ -289,18 +309,7 @@ def INPUT_TYPES(s): CATEGORY = "latent" - @staticmethod - def vae_encode_crop_pixels(pixels): - x = (pixels.shape[1] // 8) * 8 - y = (pixels.shape[2] // 8) * 8 - if pixels.shape[1] != x or pixels.shape[2] != y: - x_offset = (pixels.shape[1] % 8) // 2 - y_offset = (pixels.shape[2] % 8) // 2 - pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :] - return pixels - def encode(self, vae, pixels): - pixels = self.vae_encode_crop_pixels(pixels) t = vae.encode(pixels[:,:,:,:3]) return ({"samples":t}, ) @@ -316,7 +325,6 @@ def INPUT_TYPES(s): CATEGORY = "_for_testing" def encode(self, vae, pixels, tile_size): - pixels = VAEEncode.vae_encode_crop_pixels(pixels) t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, ) return ({"samples":t}, ) @@ -330,14 +338,14 @@ def INPUT_TYPES(s): CATEGORY = "latent/inpaint" def encode(self, vae, pixels, mask, grow_mask_by=6): - x = (pixels.shape[1] // 8) * 8 - y = (pixels.shape[2] // 8) * 8 + x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio + y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: - x_offset = (pixels.shape[1] % 8) // 2 - y_offset = (pixels.shape[2] % 8) // 2 + x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2 + y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2 pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] @@ -359,6 +367,62 @@ def encode(self, vae, pixels, mask, grow_mask_by=6): return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) + +class InpaintModelConditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "pixels": ("IMAGE", ), + "mask": ("MASK", ), + }} + + RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/inpaint" + + def encode(self, positive, negative, pixels, vae, mask): + x = (pixels.shape[1] // 8) * 8 + y = (pixels.shape[2] // 8) * 8 + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") + + orig_pixels = pixels + pixels = orig_pixels.clone() + if pixels.shape[1] != x or pixels.shape[2] != y: + x_offset = (pixels.shape[1] % 8) // 2 + y_offset = (pixels.shape[2] % 8) // 2 + pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] + mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] + + m = (1.0 - mask.round()).squeeze(1) + for i in range(3): + pixels[:,:,:,i] -= 0.5 + pixels[:,:,:,i] *= m + pixels[:,:,:,i] += 0.5 + concat_latent = vae.encode(pixels) + orig_latent = vae.encode(orig_pixels) + + out_latent = {} + + out_latent["samples"] = orig_latent + out_latent["noise_mask"] = mask + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + d["concat_latent_image"] = concat_latent + d["concat_mask"] = mask + n = [t[0], d] + c.append(n) + out.append(c) + return (out[0], out[1], out_latent) + + class SaveLatent: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -778,15 +842,20 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("clip"), ), + "type": (["stable_diffusion", "stable_cascade"], ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" CATEGORY = "advanced/loaders" - def load_clip(self, clip_name): + def load_clip(self, clip_name, type="stable_diffusion"): + clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION + if type == "stable_cascade": + clip_type = comfy.sd.CLIPType.STABLE_CASCADE + clip_path = folder_paths.get_full_path("clip", clip_name) - clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings")) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) return (clip,) class DualCLIPLoader: @@ -1358,7 +1427,7 @@ def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pngi filename_prefix += self.prefix_append full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() - for image in images: + for (batch_number, image) in enumerate(images): i = 255. * image.cpu().numpy() img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) metadata = None @@ -1370,7 +1439,8 @@ def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pngi for x in extra_pnginfo: metadata.add_text(x, json.dumps(extra_pnginfo[x])) - file = f"{filename}_{counter:05}_.png" + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.png" img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level) results.append({ "filename": file, @@ -1415,6 +1485,8 @@ def load_image(self, image): output_masks = [] for i in ImageSequence.Iterator(img): i = ImageOps.exif_transpose(i) + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 image = torch.from_numpy(image)[None,] @@ -1470,6 +1542,8 @@ def load_image(self, image, channel): i = Image.open(image_path) i = ImageOps.exif_transpose(i) if i.getbands() != ("R", "G", "B", "A"): + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) i = i.convert("RGBA") mask = None c = channel[0].upper() @@ -1491,13 +1565,10 @@ def IS_CHANGED(s, image, channel): return m.digest().hex() @classmethod - def VALIDATE_INPUTS(s, image, channel): + def VALIDATE_INPUTS(s, image): if not folder_paths.exists_annotated_filepath(image): return "Invalid image file: {}".format(image) - if channel not in s._color_channels: - return "Invalid color channel: {}".format(channel) - return True class ImageScale: @@ -1627,10 +1698,11 @@ def INPUT_TYPES(s): def expand_image(self, image, left, top, right, bottom, feathering): d1, d2, d3, d4 = image.size() - new_image = torch.zeros( + new_image = torch.ones( (d1, d2 + top + bottom, d3 + left + right, d4), dtype=torch.float32, - ) + ) * 0.5 + new_image[:, top:top + d2, left:left + d3, :] = image mask = torch.ones( @@ -1696,6 +1768,7 @@ def expand_image(self, image, left, top, right, bottom, feathering): "ConditioningConcat": ConditioningConcat, "ConditioningSetArea": ConditioningSetArea, "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage, + "ConditioningSetAreaStrength": ConditioningSetAreaStrength, "ConditioningSetMask": ConditioningSetMask, "KSamplerAdvanced": KSamplerAdvanced, "SetLatentNoiseMask": SetLatentNoiseMask, @@ -1722,6 +1795,7 @@ def expand_image(self, image, left, top, right, bottom, feathering): "unCLIPCheckpointLoader": unCLIPCheckpointLoader, "GLIGENLoader": GLIGENLoader, "GLIGENTextBoxApply": GLIGENTextBoxApply, + "InpaintModelConditioning": InpaintModelConditioning, "CheckpointLoader": CheckpointLoader, "DiffusersLoader": DiffusersLoader, @@ -1832,22 +1906,27 @@ def load_custom_node(module_path, ignore=set()): print(f"Cannot import {module_path} module for custom nodes:", e) return False -def load_custom_nodes(): +def load_custom_nodes(public_mark: str = '_public/'): base_node_names = set(NODE_CLASS_MAPPINGS.keys()) node_paths = folder_paths.get_folder_paths("custom_nodes") node_import_times = [] - for custom_node_path in node_paths: + to_install_possible_modules = {} + + for custom_node_path in sorted(node_paths, key=lambda x: public_mark in x, reverse=True): possible_modules = os.listdir(os.path.realpath(custom_node_path)) if "__pycache__" in possible_modules: possible_modules.remove("__pycache__") for possible_module in possible_modules: - module_path = os.path.join(custom_node_path, possible_module) - if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue - if module_path.endswith(".disabled"): continue - time_before = time.perf_counter() - success = load_custom_node(module_path, base_node_names) - node_import_times.append((time.perf_counter() - time_before, module_path, success)) + to_install_possible_modules[possible_module] = (custom_node_path, possible_module) + + for custom_node_path, possible_module in to_install_possible_modules.values(): + module_path = os.path.join(custom_node_path, possible_module) + if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue + if module_path.endswith(".disabled"): continue + time_before = time.perf_counter() + success = load_custom_node(module_path, base_node_names) + node_import_times.append((time.perf_counter() - time_before, module_path, success)) if len(node_import_times) > 0: print("\nImport times for custom nodes:") @@ -1883,6 +1962,10 @@ def init_custom_nodes(): "nodes_sag.py", "nodes_perpneg.py", "nodes_stable3d.py", + "nodes_sdupscale.py", + "nodes_photomaker.py", + "nodes_cond.py", + "nodes_stable_cascade.py", ] for node_file in extras_files: diff --git a/requirements.txt b/requirements.txt index da1fbb27e0c..e804618e715 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ torch torchsde +torchvision einops transformers>=4.25.1 safetensors>=0.3.0 diff --git a/server.py b/server.py index d9cf517f99b..d5c64128d99 100644 --- a/server.py +++ b/server.py @@ -30,6 +30,7 @@ import comfy.utils import comfy.model_management +from app.user_manager import UserManager class BinaryEventTypes: PREVIEW_IMAGE = 1 @@ -72,6 +73,7 @@ def __init__(self, loop): mimetypes.init() mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8' + self.user_manager = UserManager() self.supports = ["custom_nodes_from_web"] self.prompt_queue = None self.loop = loop @@ -509,6 +511,17 @@ async def post_interrupt(request): nodes.interrupt_processing() return web.Response(status=200) + @routes.post("/free") + async def post_free(request): + json_data = await request.json() + unload_models = json_data.get("unload_models", False) + free_memory = json_data.get("free_memory", False) + if unload_models: + self.prompt_queue.set_flag("unload_models", unload_models) + if free_memory: + self.prompt_queue.set_flag("free_memory", free_memory) + return web.Response(status=200) + @routes.post("/history") async def post_history(request): json_data = await request.json() @@ -523,6 +536,7 @@ async def post_history(request): return web.Response(status=200) def add_routes(self): + self.user_manager.add_routes(self.routes) self.app.add_routes(self.routes) for name, dir in nodes.EXTENSION_WEB_DIRS.items(): @@ -586,7 +600,8 @@ async def send_bytes(self, event, data, sid=None): message = self.encode_bytes(event, data) if sid is None: - for ws in self.sockets.values(): + sockets = list(self.sockets.values()) + for ws in sockets: await send_socket_catch_exception(ws.send_bytes, message) elif sid in self.sockets: await send_socket_catch_exception(self.sockets[sid].send_bytes, message) @@ -595,7 +610,8 @@ async def send_json(self, event, data, sid=None): message = {"type": event, "data": data} if sid is None: - for ws in self.sockets.values(): + sockets = list(self.sockets.values()) + for ws in sockets: await send_socket_catch_exception(ws.send_json, message) elif sid in self.sockets: await send_socket_catch_exception(self.sockets[sid].send_json, message) @@ -618,8 +634,6 @@ async def start(self, address, port, verbose=True, call_on_start=None): site = web.TCPSite(runner, address, port) await site.start() - if address == '': - address = '0.0.0.0' if verbose: print("Starting server\n") print("To see the GUI go to: http://{}:{}".format(address, port)) diff --git a/tests-ui/babel.config.json b/tests-ui/babel.config.json index 526ddfd8df1..f27d6c397e5 100644 --- a/tests-ui/babel.config.json +++ b/tests-ui/babel.config.json @@ -1,3 +1,4 @@ { - "presets": ["@babel/preset-env"] + "presets": ["@babel/preset-env"], + "plugins": ["babel-plugin-transform-import-meta"] } diff --git a/tests-ui/package-lock.json b/tests-ui/package-lock.json index 35911cd7ffd..0f409ca2484 100644 --- a/tests-ui/package-lock.json +++ b/tests-ui/package-lock.json @@ -11,6 +11,7 @@ "devDependencies": { "@babel/preset-env": "^7.22.20", "@types/jest": "^29.5.5", + "babel-plugin-transform-import-meta": "^2.2.1", "jest": "^29.7.0", "jest-environment-jsdom": "^29.7.0" } @@ -2591,6 +2592,19 @@ "@babel/core": "^7.4.0 || ^8.0.0-0 <8.0.0" } }, + "node_modules/babel-plugin-transform-import-meta": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/babel-plugin-transform-import-meta/-/babel-plugin-transform-import-meta-2.2.1.tgz", + "integrity": "sha512-AxNh27Pcg8Kt112RGa3Vod2QS2YXKKJ6+nSvRtv7qQTJAdx0MZa4UHZ4lnxHUWA2MNbLuZQv5FVab4P1CoLOWw==", + "dev": true, + "dependencies": { + "@babel/template": "^7.4.4", + "tslib": "^2.4.0" + }, + "peerDependencies": { + "@babel/core": "^7.10.0" + } + }, "node_modules/babel-preset-current-node-syntax": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/babel-preset-current-node-syntax/-/babel-preset-current-node-syntax-1.0.1.tgz", @@ -5233,6 +5247,12 @@ "node": ">=12" } }, + "node_modules/tslib": { + "version": "2.6.2", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", + "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==", + "dev": true + }, "node_modules/type-detect": { "version": "4.0.8", "resolved": "https://registry.npmjs.org/type-detect/-/type-detect-4.0.8.tgz", diff --git a/tests-ui/package.json b/tests-ui/package.json index e7b60ad8e75..ae7e490843a 100644 --- a/tests-ui/package.json +++ b/tests-ui/package.json @@ -24,6 +24,7 @@ "devDependencies": { "@babel/preset-env": "^7.22.20", "@types/jest": "^29.5.5", + "babel-plugin-transform-import-meta": "^2.2.1", "jest": "^29.7.0", "jest-environment-jsdom": "^29.7.0" } diff --git a/tests-ui/tests/users.test.js b/tests-ui/tests/users.test.js new file mode 100644 index 00000000000..5e07307306e --- /dev/null +++ b/tests-ui/tests/users.test.js @@ -0,0 +1,295 @@ +// @ts-check +/// +const { start } = require("../utils"); +const lg = require("../utils/litegraph"); + +describe("users", () => { + beforeEach(() => { + lg.setup(global); + }); + + afterEach(() => { + lg.teardown(global); + }); + + function expectNoUserScreen() { + // Ensure login isnt visible + const selection = document.querySelectorAll("#comfy-user-selection")?.[0]; + expect(selection["style"].display).toBe("none"); + const menu = document.querySelectorAll(".comfy-menu")?.[0]; + expect(window.getComputedStyle(menu)?.display).not.toBe("none"); + } + + describe("multi-user", () => { + function mockAddStylesheet() { + const utils = require("../../web/scripts/utils"); + utils.addStylesheet = jest.fn().mockReturnValue(Promise.resolve()); + } + + async function waitForUserScreenShow() { + mockAddStylesheet(); + + // Wait for "show" to be called + const { UserSelectionScreen } = require("../../web/scripts/ui/userSelection"); + let resolve, reject; + const fn = UserSelectionScreen.prototype.show; + const p = new Promise((res, rej) => { + resolve = res; + reject = rej; + }); + jest.spyOn(UserSelectionScreen.prototype, "show").mockImplementation(async (...args) => { + const res = fn(...args); + await new Promise(process.nextTick); // wait for promises to resolve + resolve(); + return res; + }); + // @ts-ignore + setTimeout(() => reject("timeout waiting for UserSelectionScreen to be shown."), 500); + await p; + await new Promise(process.nextTick); // wait for promises to resolve + } + + async function testUserScreen(onShown, users) { + if (!users) { + users = {}; + } + const starting = start({ + resetEnv: true, + userConfig: { storage: "server", users }, + }); + + // Ensure no current user + expect(localStorage["Comfy.userId"]).toBeFalsy(); + expect(localStorage["Comfy.userName"]).toBeFalsy(); + + await waitForUserScreenShow(); + + const selection = document.querySelectorAll("#comfy-user-selection")?.[0]; + expect(selection).toBeTruthy(); + + // Ensure login is visible + expect(window.getComputedStyle(selection)?.display).not.toBe("none"); + // Ensure menu is hidden + const menu = document.querySelectorAll(".comfy-menu")?.[0]; + expect(window.getComputedStyle(menu)?.display).toBe("none"); + + const isCreate = await onShown(selection); + + // Submit form + selection.querySelectorAll("form")[0].submit(); + await new Promise(process.nextTick); // wait for promises to resolve + + // Wait for start + const s = await starting; + + // Ensure login is removed + expect(document.querySelectorAll("#comfy-user-selection")).toHaveLength(0); + expect(window.getComputedStyle(menu)?.display).not.toBe("none"); + + // Ensure settings + templates are saved + const { api } = require("../../web/scripts/api"); + expect(api.createUser).toHaveBeenCalledTimes(+isCreate); + expect(api.storeSettings).toHaveBeenCalledTimes(+isCreate); + expect(api.storeUserData).toHaveBeenCalledTimes(+isCreate); + if (isCreate) { + expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false }); + expect(s.app.isNewUserSession).toBeTruthy(); + } else { + expect(s.app.isNewUserSession).toBeFalsy(); + } + + return { users, selection, ...s }; + } + + it("allows user creation if no users", async () => { + const { users } = await testUserScreen((selection) => { + // Ensure we have no users flag added + expect(selection.classList.contains("no-users")).toBeTruthy(); + + // Enter a username + const input = selection.getElementsByTagName("input")[0]; + input.focus(); + input.value = "Test User"; + + return true; + }); + + expect(users).toStrictEqual({ + "Test User!": "Test User", + }); + + expect(localStorage["Comfy.userId"]).toBe("Test User!"); + expect(localStorage["Comfy.userName"]).toBe("Test User"); + }); + it("allows user creation if no current user but other users", async () => { + const users = { + "Test User 2!": "Test User 2", + }; + + await testUserScreen((selection) => { + expect(selection.classList.contains("no-users")).toBeFalsy(); + + // Enter a username + const input = selection.getElementsByTagName("input")[0]; + input.focus(); + input.value = "Test User 3"; + return true; + }, users); + + expect(users).toStrictEqual({ + "Test User 2!": "Test User 2", + "Test User 3!": "Test User 3", + }); + + expect(localStorage["Comfy.userId"]).toBe("Test User 3!"); + expect(localStorage["Comfy.userName"]).toBe("Test User 3"); + }); + it("allows user selection if no current user but other users", async () => { + const users = { + "A!": "A", + "B!": "B", + "C!": "C", + }; + + await testUserScreen((selection) => { + expect(selection.classList.contains("no-users")).toBeFalsy(); + + // Check user list + const select = selection.getElementsByTagName("select")[0]; + const options = select.getElementsByTagName("option"); + expect( + [...options] + .filter((o) => !o.disabled) + .reduce((p, n) => { + p[n.getAttribute("value")] = n.textContent; + return p; + }, {}) + ).toStrictEqual(users); + + // Select an option + select.focus(); + select.value = options[2].value; + + return false; + }, users); + + expect(users).toStrictEqual(users); + + expect(localStorage["Comfy.userId"]).toBe("B!"); + expect(localStorage["Comfy.userName"]).toBe("B"); + }); + it("doesnt show user screen if current user", async () => { + const starting = start({ + resetEnv: true, + userConfig: { + storage: "server", + users: { + "User!": "User", + }, + }, + localStorage: { + "Comfy.userId": "User!", + "Comfy.userName": "User", + }, + }); + await new Promise(process.nextTick); // wait for promises to resolve + + expectNoUserScreen(); + + await starting; + }); + it("allows user switching", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { + storage: "server", + users: { + "User!": "User", + }, + }, + localStorage: { + "Comfy.userId": "User!", + "Comfy.userName": "User", + }, + }); + + // cant actually test switching user easily but can check the setting is present + expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeTruthy(); + }); + }); + describe("single-user", () => { + it("doesnt show user creation if no default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: false, storage: "server" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(1); + expect(api.storeUserData).toHaveBeenCalledTimes(1); + expect(api.storeUserData).toHaveBeenCalledWith("comfy.templates.json", null, { stringify: false }); + expect(app.isNewUserSession).toBeTruthy(); + }); + it("doesnt show user creation if default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "server" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(0); + expect(api.storeUserData).toHaveBeenCalledTimes(0); + expect(app.isNewUserSession).toBeFalsy(); + }); + it("doesnt allow user switching", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "server" }, + }); + expectNoUserScreen(); + + expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy(); + }); + }); + describe("browser-user", () => { + it("doesnt show user creation if no default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: false, storage: "browser" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(0); + expect(api.storeUserData).toHaveBeenCalledTimes(0); + expect(app.isNewUserSession).toBeFalsy(); + }); + it("doesnt show user creation if default user", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "server" }, + }); + expectNoUserScreen(); + + // It should store the settings + const { api } = require("../../web/scripts/api"); + expect(api.storeSettings).toHaveBeenCalledTimes(0); + expect(api.storeUserData).toHaveBeenCalledTimes(0); + expect(app.isNewUserSession).toBeFalsy(); + }); + it("doesnt allow user switching", async () => { + const { app } = await start({ + resetEnv: true, + userConfig: { migrated: true, storage: "browser" }, + }); + expectNoUserScreen(); + + expect(app.ui.settings.settingsLookup["Comfy.SwitchUser"]).toBeFalsy(); + }); + }); +}); diff --git a/tests-ui/utils/index.js b/tests-ui/utils/index.js index 6a08e8594e9..74b6cf93dbc 100644 --- a/tests-ui/utils/index.js +++ b/tests-ui/utils/index.js @@ -1,10 +1,18 @@ const { mockApi } = require("./setup"); const { Ez } = require("./ezgraph"); const lg = require("./litegraph"); +const fs = require("fs"); +const path = require("path"); + +const html = fs.readFileSync(path.resolve(__dirname, "../../web/index.html")) /** * - * @param { Parameters[0] & { resetEnv?: boolean, preSetup?(app): Promise } } config + * @param { Parameters[0] & { + * resetEnv?: boolean, + * preSetup?(app): Promise, + * localStorage?: Record + * } } config * @returns */ export async function start(config = {}) { @@ -12,12 +20,18 @@ export async function start(config = {}) { jest.resetModules(); jest.resetAllMocks(); lg.setup(global); + localStorage.clear(); + sessionStorage.clear(); } + Object.assign(localStorage, config.localStorage ?? {}); + document.body.innerHTML = html; + mockApi(config); const { app } = require("../../web/scripts/app"); config.preSetup?.(app); await app.setup(); + return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app }; } diff --git a/tests-ui/utils/setup.js b/tests-ui/utils/setup.js index dd150214a34..e46258943ed 100644 --- a/tests-ui/utils/setup.js +++ b/tests-ui/utils/setup.js @@ -18,9 +18,21 @@ function* walkSync(dir) { */ /** - * @param { { mockExtensions?: string[], mockNodeDefs?: Record } } config + * @param {{ + * mockExtensions?: string[], + * mockNodeDefs?: Record, +* settings?: Record +* userConfig?: {storage: "server" | "browser", users?: Record, migrated?: boolean }, +* userData?: Record + * }} config */ -export function mockApi({ mockExtensions, mockNodeDefs } = {}) { +export function mockApi(config = {}) { + let { mockExtensions, mockNodeDefs, userConfig, settings, userData } = { + userConfig, + settings: {}, + userData: {}, + ...config, + }; if (!mockExtensions) { mockExtensions = Array.from(walkSync(path.resolve("../web/extensions/core"))) .filter((x) => x.endsWith(".js")) @@ -40,6 +52,26 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) { getNodeDefs: jest.fn(() => mockNodeDefs), init: jest.fn(), apiURL: jest.fn((x) => "../../web/" + x), + createUser: jest.fn((username) => { + if(username in userConfig.users) { + return { status: 400, json: () => "Duplicate" } + } + userConfig.users[username + "!"] = username; + return { status: 200, json: () => username + "!" } + }), + getUserConfig: jest.fn(() => userConfig ?? { storage: "browser", migrated: false }), + getSettings: jest.fn(() => settings), + storeSettings: jest.fn((v) => Object.assign(settings, v)), + getUserData: jest.fn((f) => { + if (f in userData) { + return { status: 200, json: () => userData[f] }; + } else { + return { status: 404 }; + } + }), + storeUserData: jest.fn((file, data) => { + userData[file] = data; + }), }; jest.mock("../../web/scripts/api", () => ({ get api() { diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 4cf1f7621b9..0b0763d1d49 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -1,6 +1,7 @@ import { app } from "../../scripts/app.js"; import { api } from "../../scripts/api.js"; import { mergeIfValid } from "./widgetInputs.js"; +import { ManageGroupDialog } from "./groupNodeManage.js"; const GROUP = Symbol(); @@ -61,11 +62,7 @@ class GroupNodeBuilder { ); return; case Workflow.InUse.Registered: - if ( - !confirm( - "An group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?" - ) - ) { + if (!confirm("A group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?")) { return; } break; @@ -151,6 +148,8 @@ export class GroupNodeConfig { this.primitiveDefs = {}; this.widgetToPrimitive = {}; this.primitiveToWidget = {}; + this.nodeInputs = {}; + this.outputVisibility = []; } async registerType(source = "workflow") { @@ -158,6 +157,7 @@ export class GroupNodeConfig { output: [], output_name: [], output_is_list: [], + output_is_hidden: [], name: source + "/" + this.name, display_name: this.name, category: "group nodes" + ("/" + source), @@ -277,8 +277,7 @@ export class GroupNodeConfig { } if (input.widget) { const targetDef = globalDefs[node.type]; - const targetWidget = - targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name]; + const targetWidget = targetDef.input.required[input.widget.name] ?? targetDef.input.optional[input.widget.name]; const widget = [targetWidget[0], config]; const res = mergeIfValid( @@ -330,7 +329,8 @@ export class GroupNodeConfig { } getInputConfig(node, inputName, seenInputs, config, extra) { - let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName; + const customConfig = this.nodeData.config?.[node.index]?.input?.[inputName]; + let name = customConfig?.name ?? node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName; let key = name; let prefix = ""; // Special handling for primitive to include the title if it is set rather than just "value" @@ -349,14 +349,14 @@ export class GroupNodeConfig { } if (config[0] === "IMAGEUPLOAD") { if (!extra) extra = {}; - extra.widget = `${prefix}${config[1]?.widget ?? "image"}`; + extra.widget = this.oldToNewWidgetMap[node.index]?.[config[1]?.widget ?? "image"] ?? "image"; } if (extra) { config = [config[0], { ...config[1], ...extra }]; } - return { name, config }; + return { name, config, customConfig }; } processWidgetInputs(inputs, node, inputNames, seenInputs) { @@ -366,9 +366,7 @@ export class GroupNodeConfig { for (const inputName of inputNames) { let widgetType = app.getWidgetType(inputs[inputName], inputName); if (widgetType) { - const convertedIndex = node.inputs?.findIndex( - (inp) => inp.name === inputName && inp.widget?.name === inputName - ); + const convertedIndex = node.inputs?.findIndex((inp) => inp.name === inputName && inp.widget?.name === inputName); if (convertedIndex > -1) { // This widget has been converted to a widget // We need to store this in the correct position so link ids line up @@ -424,6 +422,7 @@ export class GroupNodeConfig { } processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs) { + this.nodeInputs[node.index] = {}; for (let i = 0; i < slots.length; i++) { const inputName = slots[i]; if (linksTo[i]) { @@ -432,7 +431,11 @@ export class GroupNodeConfig { continue; } - const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]); + const { name, config, customConfig } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]); + + this.nodeInputs[node.index][inputName] = name; + if(customConfig?.visible === false) continue; + this.nodeDef.input.required[name] = config; inputMap[i] = this.inputCount++; } @@ -452,6 +455,7 @@ export class GroupNodeConfig { const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName], { defaultInput: true, }); + this.nodeDef.input.required[name] = config; this.newToOldWidgetMap[name] = { node, inputName }; @@ -477,9 +481,7 @@ export class GroupNodeConfig { this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs); // Converted inputs have to be processed after all other nodes as they'll be at the end of the list - this.#convertedToProcess.push(() => - this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs) - ); + this.#convertedToProcess.push(() => this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs)); return inputMapping; } @@ -490,8 +492,12 @@ export class GroupNodeConfig { // Add outputs for (let outputId = 0; outputId < def.output.length; outputId++) { const linksFrom = this.linksFrom[node.index]; - if (linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId]) { - // This output is linked internally so we can skip it + // If this output is linked internally we flag it to hide + const hasLink = linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId]; + const customConfig = this.nodeData.config?.[node.index]?.output?.[outputId]; + const visible = customConfig?.visible ?? !hasLink; + this.outputVisibility.push(visible); + if (!visible) { continue; } @@ -500,11 +506,15 @@ export class GroupNodeConfig { this.nodeDef.output.push(def.output[outputId]); this.nodeDef.output_is_list.push(def.output_is_list[outputId]); - let label = def.output_name?.[outputId] ?? def.output[outputId]; - const output = node.outputs.find((o) => o.name === label); - if (output?.label) { - label = output.label; + let label = customConfig?.name; + if (!label) { + label = def.output_name?.[outputId] ?? def.output[outputId]; + const output = node.outputs.find((o) => o.name === label); + if (output?.label) { + label = output.label; + } } + let name = label; if (name in seenOutputs) { const prefix = `${node.title ?? node.type} `; @@ -677,6 +687,25 @@ export class GroupNodeHandler { return this.innerNodes; }; + this.node.recreate = async () => { + const id = this.node.id; + const sz = this.node.size; + const nodes = this.node.convertToNodes(); + + const groupNode = LiteGraph.createNode(this.node.type); + groupNode.id = id; + + // Reuse the existing nodes for this instance + groupNode.setInnerNodes(nodes); + groupNode[GROUP].populateWidgets(); + app.graph.add(groupNode); + groupNode.size = [Math.max(groupNode.size[0], sz[0]), Math.max(groupNode.size[1], sz[1])]; + + // Remove all converted nodes and relink them + groupNode[GROUP].replaceNodes(nodes); + return groupNode; + }; + this.node.convertToNodes = () => { const addInnerNodes = () => { const backup = localStorage.getItem("litegrapheditor_clipboard"); @@ -769,6 +798,7 @@ export class GroupNodeHandler { const slot = node.inputs[groupSlotId]; if (slot.link == null) continue; const link = app.graph.links[slot.link]; + if (!link) continue; // connect this node output to the input of another node const originNode = app.graph.getNodeById(link.origin_id); originNode.connect(link.origin_slot, newNode, +innerInputId); @@ -806,12 +836,23 @@ export class GroupNodeHandler { let optionIndex = options.findIndex((o) => o.content === "Outputs"); if (optionIndex === -1) optionIndex = options.length; else optionIndex++; - options.splice(optionIndex, 0, null, { - content: "Convert to nodes", - callback: () => { - return this.convertToNodes(); + options.splice( + optionIndex, + 0, + null, + { + content: "Convert to nodes", + callback: () => { + return this.convertToNodes(); + }, }, - }); + { + content: "Manage Group Node", + callback: () => { + new ManageGroupDialog(app).show(this.type); + }, + } + ); }; // Draw custom collapse icon to identity this as a group @@ -843,6 +884,7 @@ export class GroupNodeHandler { const r = onDrawForeground?.apply?.(this, arguments); if (+app.runningNodeId === this.id && this.runningInternalNodeId !== null) { const n = groupData.nodes[this.runningInternalNodeId]; + if(!n) return; const message = `Running ${n.title || n.type} (${this.runningInternalNodeId}/${groupData.nodes.length})`; ctx.save(); ctx.font = "12px sans-serif"; @@ -865,6 +907,31 @@ export class GroupNodeHandler { return onExecutionStart?.apply(this, arguments); }; + const self = this; + const onNodeCreated = this.node.onNodeCreated; + this.node.onNodeCreated = function () { + if (!this.widgets) { + return; + } + const config = self.groupData.nodeData.config; + if (config) { + for (const n in config) { + const inputs = config[n]?.input; + for (const w in inputs) { + if (inputs[w].visible !== false) continue; + const widgetName = self.groupData.oldToNewWidgetMap[n][w]; + const widget = this.widgets.find((w) => w.name === widgetName); + if (widget) { + widget.type = "hidden"; + widget.computeSize = () => [0, -4]; + } + } + } + } + + return onNodeCreated?.apply(this, arguments); + }; + function handleEvent(type, getId, getEvent) { const handler = ({ detail }) => { const id = getId(detail); @@ -902,6 +969,26 @@ export class GroupNodeHandler { api.removeEventListener("executing", executing); api.removeEventListener("executed", executed); }; + + this.node.refreshComboInNode = (defs) => { + // Update combo widget options + for (const widgetName in this.groupData.newToOldWidgetMap) { + const widget = this.node.widgets.find((w) => w.name === widgetName); + if (widget?.type === "combo") { + const old = this.groupData.newToOldWidgetMap[widgetName]; + const def = defs[old.node.type]; + const input = def?.input?.required?.[old.inputName] ?? def?.input?.optional?.[old.inputName]; + if (!input) continue; + + widget.options.values = input[0]; + + if (old.inputName !== "image" && !widget.options.values.includes(widget.value)) { + widget.value = widget.options.values[0]; + widget.callback(widget.value); + } + } + } + }; } updateInnerWidgets() { @@ -927,13 +1014,15 @@ export class GroupNodeHandler { continue; } else if (innerNode.type === "Reroute") { const rerouteLinks = this.groupData.linksFrom[old.node.index]; - for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) { - const node = this.innerNodes[targetNodeId]; - const input = node.inputs[targetSlot]; - if (input.widget) { - const widget = node.widgets?.find((w) => w.name === input.widget.name); - if (widget) { - widget.value = newValue; + if (rerouteLinks) { + for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) { + const node = this.innerNodes[targetNodeId]; + const input = node.inputs[targetSlot]; + if (input.widget) { + const widget = node.widgets?.find((w) => w.name === input.widget.name); + if (widget) { + widget.value = newValue; + } } } } @@ -975,7 +1064,7 @@ export class GroupNodeHandler { const [, , targetNodeId, targetNodeSlot] = link; const targetNode = this.groupData.nodeData.nodes[targetNodeId]; const inputs = targetNode.inputs; - const targetWidget = inputs?.[targetNodeSlot].widget; + const targetWidget = inputs?.[targetNodeSlot]?.widget; if (!targetWidget) return; const offset = inputs.length - (targetNode.widgets_values?.length ?? 0); @@ -983,13 +1072,12 @@ export class GroupNodeHandler { if (v == null) return; const widgetName = Object.values(map)[0]; - const widget = this.node.widgets.find(w => w.name === widgetName); - if(widget) { + const widget = this.node.widgets.find((w) => w.name === widgetName); + if (widget) { widget.value = v; } } - populateWidgets() { if (!this.node.widgets) return; @@ -1080,7 +1168,7 @@ export class GroupNodeHandler { } static getGroupData(node) { - return node.constructor?.nodeData?.[GROUP]; + return (node.nodeData ?? node.constructor?.nodeData)?.[GROUP]; } static isGroupNode(node) { @@ -1112,7 +1200,7 @@ export class GroupNodeHandler { } function addConvertToGroupOptions() { - function addOption(options, index) { + function addConvertOption(options, index) { const selected = Object.values(app.canvas.selected_nodes ?? {}); const disabled = selected.length < 2 || selected.find((n) => GroupNodeHandler.isGroupNode(n)); options.splice(index + 1, null, { @@ -1124,12 +1212,25 @@ function addConvertToGroupOptions() { }); } + function addManageOption(options, index) { + const groups = app.graph.extra?.groupNodes; + const disabled = !groups || !Object.keys(groups).length; + options.splice(index + 1, null, { + content: `Manage Group Nodes`, + disabled, + callback: () => { + new ManageGroupDialog(app).show(); + }, + }); + } + // Add to canvas const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions; LGraphCanvas.prototype.getCanvasMenuOptions = function () { const options = getCanvasMenuOptions.apply(this, arguments); const index = options.findIndex((o) => o?.content === "Add Group") + 1 || options.length; - addOption(options, index); + addConvertOption(options, index); + addManageOption(options, index + 1); return options; }; @@ -1139,7 +1240,7 @@ function addConvertToGroupOptions() { const options = getNodeMenuOptions.apply(this, arguments); if (!GroupNodeHandler.isGroupNode(node)) { const index = options.findIndex((o) => o?.content === "Outputs") + 1 || options.length - 1; - addOption(options, index); + addConvertOption(options, index); } return options; }; @@ -1167,6 +1268,14 @@ const ext = { node[GROUP] = new GroupNodeHandler(node); } }, + async refreshComboInNodes(defs) { + // Re-register group nodes so new ones are created with the correct options + Object.assign(globalDefs, defs); + const nodes = app.graph.extra?.groupNodes; + if (nodes) { + await GroupNodeConfig.registerFromWorkflow(nodes, {}); + } + } }; app.registerExtension(ext); diff --git a/web/extensions/core/groupNodeManage.css b/web/extensions/core/groupNodeManage.css new file mode 100644 index 00000000000..5470ecb5e67 --- /dev/null +++ b/web/extensions/core/groupNodeManage.css @@ -0,0 +1,149 @@ +.comfy-group-manage { + background: var(--bg-color); + color: var(--fg-color); + padding: 0; + font-family: Arial, Helvetica, sans-serif; + border-color: black; + margin: 20vh auto; + max-height: 60vh; +} +.comfy-group-manage-outer { + max-height: 60vh; + min-width: 500px; + display: flex; + flex-direction: column; +} +.comfy-group-manage-outer > header { + display: flex; + align-items: center; + gap: 10px; + justify-content: space-between; + background: var(--comfy-menu-bg); + padding: 15px 20px; +} +.comfy-group-manage-outer > header select { + background: var(--comfy-input-bg); + border: 1px solid var(--border-color); + color: var(--input-text); + padding: 5px 10px; + border-radius: 5px; +} +.comfy-group-manage h2 { + margin: 0; + font-weight: normal; +} +.comfy-group-manage main { + display: flex; + overflow: hidden; +} +.comfy-group-manage .drag-handle { + font-weight: bold; +} +.comfy-group-manage-list { + border-right: 1px solid var(--comfy-menu-bg); +} +.comfy-group-manage-list ul { + margin: 40px 0 0; + padding: 0; + list-style: none; +} +.comfy-group-manage-list-items { + max-height: calc(100% - 40px); + overflow-y: scroll; + overflow-x: hidden; +} +.comfy-group-manage-list li { + display: flex; + padding: 10px 20px 10px 10px; + cursor: pointer; + align-items: center; + gap: 5px; +} +.comfy-group-manage-list div { + display: flex; + flex-direction: column; +} +.comfy-group-manage-list li:not(.selected):hover div { + text-decoration: underline; +} +.comfy-group-manage-list li.selected { + background: var(--border-color); +} +.comfy-group-manage-list li span { + opacity: 0.7; + font-size: smaller; +} +.comfy-group-manage-node { + flex: auto; + background: var(--border-color); + display: flex; + flex-direction: column; +} +.comfy-group-manage-node > div { + overflow: auto; +} +.comfy-group-manage-node header { + display: flex; + background: var(--bg-color); + height: 40px; +} +.comfy-group-manage-node header a { + text-align: center; + flex: auto; + border-right: 1px solid var(--comfy-menu-bg); + border-bottom: 1px solid var(--comfy-menu-bg); + padding: 10px; + cursor: pointer; + font-size: 15px; +} +.comfy-group-manage-node header a:last-child { + border-right: none; +} +.comfy-group-manage-node header a:not(.active):hover { + text-decoration: underline; +} +.comfy-group-manage-node header a.active { + background: var(--border-color); + border-bottom: none; +} +.comfy-group-manage-node-page { + display: none; + overflow: auto; +} +.comfy-group-manage-node-page.active { + display: block; +} +.comfy-group-manage-node-page div { + padding: 10px; + display: flex; + align-items: center; + gap: 10px; +} +.comfy-group-manage-node-page input { + border: none; + color: var(--input-text); + background: var(--comfy-input-bg); + padding: 5px 10px; +} +.comfy-group-manage-node-page input[type="text"] { + flex: auto; +} +.comfy-group-manage-node-page label { + display: flex; + gap: 5px; + align-items: center; +} +.comfy-group-manage footer { + border-top: 1px solid var(--comfy-menu-bg); + padding: 10px; + display: flex; + gap: 10px; +} +.comfy-group-manage footer button { + font-size: 14px; + padding: 5px 10px; + border-radius: 0; +} +.comfy-group-manage footer button:first-child { + margin-right: auto; +} diff --git a/web/extensions/core/groupNodeManage.js b/web/extensions/core/groupNodeManage.js new file mode 100644 index 00000000000..1ab33838688 --- /dev/null +++ b/web/extensions/core/groupNodeManage.js @@ -0,0 +1,422 @@ +import { $el, ComfyDialog } from "../../scripts/ui.js"; +import { DraggableList } from "../../scripts/ui/draggableList.js"; +import { addStylesheet } from "../../scripts/utils.js"; +import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; + +addStylesheet(import.meta.url); + +const ORDER = Symbol(); + +function merge(target, source) { + if (typeof target === "object" && typeof source === "object") { + for (const key in source) { + const sv = source[key]; + if (typeof sv === "object") { + let tv = target[key]; + if (!tv) tv = target[key] = {}; + merge(tv, source[key]); + } else { + target[key] = sv; + } + } + } + + return target; +} + +export class ManageGroupDialog extends ComfyDialog { + /** @type { Record<"Inputs" | "Outputs" | "Widgets", {tab: HTMLAnchorElement, page: HTMLElement}> } */ + tabs = {}; + /** @type { number | null | undefined } */ + selectedNodeIndex; + /** @type { keyof ManageGroupDialog["tabs"] } */ + selectedTab = "Inputs"; + /** @type { string | undefined } */ + selectedGroup; + + /** @type { Record>> } */ + modifications = {}; + + get selectedNodeInnerIndex() { + return +this.nodeItems[this.selectedNodeIndex].dataset.nodeindex; + } + + constructor(app) { + super(); + this.app = app; + this.element = $el("dialog.comfy-group-manage", { + parent: document.body, + }); + } + + changeTab(tab) { + this.tabs[this.selectedTab].tab.classList.remove("active"); + this.tabs[this.selectedTab].page.classList.remove("active"); + this.tabs[tab].tab.classList.add("active"); + this.tabs[tab].page.classList.add("active"); + this.selectedTab = tab; + } + + changeNode(index, force) { + if (!force && this.selectedNodeIndex === index) return; + + if (this.selectedNodeIndex != null) { + this.nodeItems[this.selectedNodeIndex].classList.remove("selected"); + } + this.nodeItems[index].classList.add("selected"); + this.selectedNodeIndex = index; + + if (!this.buildInputsPage() && this.selectedTab === "Inputs") { + this.changeTab("Widgets"); + } + if (!this.buildWidgetsPage() && this.selectedTab === "Widgets") { + this.changeTab("Outputs"); + } + if (!this.buildOutputsPage() && this.selectedTab === "Outputs") { + this.changeTab("Inputs"); + } + + this.changeTab(this.selectedTab); + } + + getGroupData() { + this.groupNodeType = LiteGraph.registered_node_types["workflow/" + this.selectedGroup]; + this.groupNodeDef = this.groupNodeType.nodeData; + this.groupData = GroupNodeHandler.getGroupData(this.groupNodeType); + } + + changeGroup(group, reset = true) { + this.selectedGroup = group; + this.getGroupData(); + + const nodes = this.groupData.nodeData.nodes; + this.nodeItems = nodes.map((n, i) => + $el( + "li.draggable-item", + { + dataset: { + nodeindex: n.index + "", + }, + onclick: () => { + this.changeNode(i); + }, + }, + [ + $el("span.drag-handle"), + $el( + "div", + { + textContent: n.title ?? n.type, + }, + n.title + ? $el("span", { + textContent: n.type, + }) + : [] + ), + ] + ) + ); + + this.innerNodesList.replaceChildren(...this.nodeItems); + + if (reset) { + this.selectedNodeIndex = null; + this.changeNode(0); + } else { + const items = this.draggable.getAllItems(); + let index = items.findIndex(item => item.classList.contains("selected")); + if(index === -1) index = this.selectedNodeIndex; + this.changeNode(index, true); + } + + const ordered = [...nodes]; + this.draggable?.dispose(); + this.draggable = new DraggableList(this.innerNodesList, "li"); + this.draggable.addEventListener("dragend", ({ detail: { oldPosition, newPosition } }) => { + if (oldPosition === newPosition) return; + ordered.splice(newPosition, 0, ordered.splice(oldPosition, 1)[0]); + for (let i = 0; i < ordered.length; i++) { + this.storeModification({ nodeIndex: ordered[i].index, section: ORDER, prop: "order", value: i }); + } + }); + } + + storeModification({ nodeIndex, section, prop, value }) { + const groupMod = (this.modifications[this.selectedGroup] ??= {}); + const nodesMod = (groupMod.nodes ??= {}); + const nodeMod = (nodesMod[nodeIndex ?? this.selectedNodeInnerIndex] ??= {}); + const typeMod = (nodeMod[section] ??= {}); + if (typeof value === "object") { + const objMod = (typeMod[prop] ??= {}); + Object.assign(objMod, value); + } else { + typeMod[prop] = value; + } + } + + getEditElement(section, prop, value, placeholder, checked, checkable = true) { + if (value === placeholder) value = ""; + + const mods = this.modifications[this.selectedGroup]?.nodes?.[this.selectedNodeInnerIndex]?.[section]?.[prop]; + if (mods) { + if (mods.name != null) { + value = mods.name; + } + if (mods.visible != null) { + checked = mods.visible; + } + } + + return $el("div", [ + $el("input", { + value, + placeholder, + type: "text", + onchange: (e) => { + this.storeModification({ section, prop, value: { name: e.target.value } }); + }, + }), + $el("label", { textContent: "Visible" }, [ + $el("input", { + type: "checkbox", + checked, + disabled: !checkable, + onchange: (e) => { + this.storeModification({ section, prop, value: { visible: !!e.target.checked } }); + }, + }), + ]), + ]); + } + + buildWidgetsPage() { + const widgets = this.groupData.oldToNewWidgetMap[this.selectedNodeInnerIndex]; + const items = Object.keys(widgets ?? {}); + const type = app.graph.extra.groupNodes[this.selectedGroup]; + const config = type.config?.[this.selectedNodeInnerIndex]?.input; + this.widgetsPage.replaceChildren( + ...items.map((oldName) => { + return this.getEditElement("input", oldName, widgets[oldName], oldName, config?.[oldName]?.visible !== false); + }) + ); + return !!items.length; + } + + buildInputsPage() { + const inputs = this.groupData.nodeInputs[this.selectedNodeInnerIndex]; + const items = Object.keys(inputs ?? {}); + const type = app.graph.extra.groupNodes[this.selectedGroup]; + const config = type.config?.[this.selectedNodeInnerIndex]?.input; + this.inputsPage.replaceChildren( + ...items + .map((oldName) => { + let value = inputs[oldName]; + if (!value) { + return; + } + + return this.getEditElement("input", oldName, value, oldName, config?.[oldName]?.visible !== false); + }) + .filter(Boolean) + ); + return !!items.length; + } + + buildOutputsPage() { + const nodes = this.groupData.nodeData.nodes; + const innerNodeDef = this.groupData.getNodeDef(nodes[this.selectedNodeInnerIndex]); + const outputs = innerNodeDef?.output ?? []; + const groupOutputs = this.groupData.oldToNewOutputMap[this.selectedNodeInnerIndex]; + + const type = app.graph.extra.groupNodes[this.selectedGroup]; + const config = type.config?.[this.selectedNodeInnerIndex]?.output; + const node = this.groupData.nodeData.nodes[this.selectedNodeInnerIndex]; + const checkable = node.type !== "PrimitiveNode"; + this.outputsPage.replaceChildren( + ...outputs + .map((type, slot) => { + const groupOutputIndex = groupOutputs?.[slot]; + const oldName = innerNodeDef.output_name?.[slot] ?? type; + let value = config?.[slot]?.name; + const visible = config?.[slot]?.visible || groupOutputIndex != null; + if (!value || value === oldName) { + value = ""; + } + return this.getEditElement("output", slot, value, oldName, visible, checkable); + }) + .filter(Boolean) + ); + return !!outputs.length; + } + + show(type) { + const groupNodes = Object.keys(app.graph.extra?.groupNodes ?? {}).sort((a, b) => a.localeCompare(b)); + + this.innerNodesList = $el("ul.comfy-group-manage-list-items"); + this.widgetsPage = $el("section.comfy-group-manage-node-page"); + this.inputsPage = $el("section.comfy-group-manage-node-page"); + this.outputsPage = $el("section.comfy-group-manage-node-page"); + const pages = $el("div", [this.widgetsPage, this.inputsPage, this.outputsPage]); + + this.tabs = [ + ["Inputs", this.inputsPage], + ["Widgets", this.widgetsPage], + ["Outputs", this.outputsPage], + ].reduce((p, [name, page]) => { + p[name] = { + tab: $el("a", { + onclick: () => { + this.changeTab(name); + }, + textContent: name, + }), + page, + }; + return p; + }, {}); + + const outer = $el("div.comfy-group-manage-outer", [ + $el("header", [ + $el("h2", "Group Nodes"), + $el( + "select", + { + onchange: (e) => { + this.changeGroup(e.target.value); + }, + }, + groupNodes.map((g) => + $el("option", { + textContent: g, + selected: "workflow/" + g === type, + value: g, + }) + ) + ), + ]), + $el("main", [ + $el("section.comfy-group-manage-list", this.innerNodesList), + $el("section.comfy-group-manage-node", [ + $el( + "header", + Object.values(this.tabs).map((t) => t.tab) + ), + pages, + ]), + ]), + $el("footer", [ + $el( + "button.comfy-btn", + { + onclick: (e) => { + const node = app.graph._nodes.find((n) => n.type === "workflow/" + this.selectedGroup); + if (node) { + alert("This group node is in use in the current workflow, please first remove these."); + return; + } + if (confirm(`Are you sure you want to remove the node: "${this.selectedGroup}"`)) { + delete app.graph.extra.groupNodes[this.selectedGroup]; + LiteGraph.unregisterNodeType("workflow/" + this.selectedGroup); + } + this.show(); + }, + }, + "Delete Group Node" + ), + $el( + "button.comfy-btn", + { + onclick: async () => { + let nodesByType; + let recreateNodes = []; + const types = {}; + for (const g in this.modifications) { + const type = app.graph.extra.groupNodes[g]; + let config = (type.config ??= {}); + + let nodeMods = this.modifications[g]?.nodes; + if (nodeMods) { + const keys = Object.keys(nodeMods); + if (nodeMods[keys[0]][ORDER]) { + // If any node is reordered, they will all need sequencing + const orderedNodes = []; + const orderedMods = {}; + const orderedConfig = {}; + + for (const n of keys) { + const order = nodeMods[n][ORDER].order; + orderedNodes[order] = type.nodes[+n]; + orderedMods[order] = nodeMods[n]; + orderedNodes[order].index = order; + } + + // Rewrite links + for (const l of type.links) { + if (l[0] != null) l[0] = type.nodes[l[0]].index; + if (l[2] != null) l[2] = type.nodes[l[2]].index; + } + + // Rewrite externals + if (type.external) { + for (const ext of type.external) { + ext[0] = type.nodes[ext[0]]; + } + } + + // Rewrite modifications + for (const id of keys) { + if (config[id]) { + orderedConfig[type.nodes[id].index] = config[id]; + } + delete config[id]; + } + + type.nodes = orderedNodes; + nodeMods = orderedMods; + type.config = config = orderedConfig; + } + + merge(config, nodeMods); + } + + types[g] = type; + + if (!nodesByType) { + nodesByType = app.graph._nodes.reduce((p, n) => { + p[n.type] ??= []; + p[n.type].push(n); + return p; + }, {}); + } + + const nodes = nodesByType["workflow/" + g]; + if (nodes) recreateNodes.push(...nodes); + } + + await GroupNodeConfig.registerFromWorkflow(types, {}); + + for (const node of recreateNodes) { + node.recreate(); + } + + this.modifications = {}; + this.app.graph.setDirtyCanvas(true, true); + this.changeGroup(this.selectedGroup, false); + }, + }, + "Save" + ), + $el("button.comfy-btn", { onclick: () => this.element.close() }, "Close"), + ]), + ]); + + this.element.replaceChildren(outer); + this.changeGroup(type ? groupNodes.find((g) => "workflow/" + g === type) : groupNodes[0]); + this.element.showModal(); + + this.element.addEventListener("close", () => { + this.draggable?.dispose(); + }); + } +} \ No newline at end of file diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index bb2f16d42b5..4f69ac7607c 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -62,7 +62,7 @@ async function uploadMask(filepath, formData) { ClipspaceDialog.invalidatePreview(); } -function prepare_mask(image, maskCanvas, maskCtx) { +function prepare_mask(image, maskCanvas, maskCtx, maskColor) { // paste mask data into alpha channel maskCtx.drawImage(image, 0, 0, maskCanvas.width, maskCanvas.height); const maskData = maskCtx.getImageData(0, 0, maskCanvas.width, maskCanvas.height); @@ -74,9 +74,9 @@ function prepare_mask(image, maskCanvas, maskCtx) { else maskData.data[i+3] = 255; - maskData.data[i] = 0; - maskData.data[i+1] = 0; - maskData.data[i+2] = 0; + maskData.data[i] = maskColor.r; + maskData.data[i+1] = maskColor.g; + maskData.data[i+2] = maskColor.b; } maskCtx.globalCompositeOperation = 'source-over'; @@ -110,6 +110,7 @@ class MaskEditorDialog extends ComfyDialog { createButton(name, callback) { var button = document.createElement("button"); + button.style.pointerEvents = "auto"; button.innerText = name; button.addEventListener("click", callback); return button; @@ -146,6 +147,7 @@ class MaskEditorDialog extends ComfyDialog { divElement.style.display = "flex"; divElement.style.position = "relative"; divElement.style.top = "2px"; + divElement.style.pointerEvents = "auto"; self.brush_slider_input = document.createElement('input'); self.brush_slider_input.setAttribute('type', 'range'); self.brush_slider_input.setAttribute('min', '1'); @@ -173,6 +175,7 @@ class MaskEditorDialog extends ComfyDialog { bottom_panel.style.left = "20px"; bottom_panel.style.right = "20px"; bottom_panel.style.height = "50px"; + bottom_panel.style.pointerEvents = "none"; var brush = document.createElement("div"); brush.id = "brush"; @@ -191,14 +194,29 @@ class MaskEditorDialog extends ComfyDialog { this.element.appendChild(bottom_panel); document.body.appendChild(brush); + var clearButton = this.createLeftButton("Clear", () => { + self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height); + }); + this.brush_size_slider = this.createLeftSlider(self, "Thickness", (event) => { self.brush_size = event.target.value; self.updateBrushPreview(self, null, null); }); - var clearButton = this.createLeftButton("Clear", - () => { - self.maskCtx.clearRect(0, 0, self.maskCanvas.width, self.maskCanvas.height); - }); + + this.colorButton = this.createLeftButton(this.getColorButtonText(), () => { + if (self.brush_color_mode === "black") { + self.brush_color_mode = "white"; + } + else if (self.brush_color_mode === "white") { + self.brush_color_mode = "negative"; + } + else { + self.brush_color_mode = "black"; + } + + self.updateWhenBrushColorModeChanged(); + }); + var cancelButton = this.createRightButton("Cancel", () => { document.removeEventListener("mouseup", MaskEditorDialog.handleMouseUp); document.removeEventListener("keydown", MaskEditorDialog.handleKeyDown); @@ -219,6 +237,7 @@ class MaskEditorDialog extends ComfyDialog { bottom_panel.appendChild(this.saveButton); bottom_panel.appendChild(cancelButton); bottom_panel.appendChild(this.brush_size_slider); + bottom_panel.appendChild(this.colorButton); imgCanvas.style.position = "absolute"; maskCanvas.style.position = "absolute"; @@ -228,6 +247,10 @@ class MaskEditorDialog extends ComfyDialog { maskCanvas.style.top = imgCanvas.style.top; maskCanvas.style.left = imgCanvas.style.left; + + const maskCanvasStyle = this.getMaskCanvasStyle(); + maskCanvas.style.mixBlendMode = maskCanvasStyle.mixBlendMode; + maskCanvas.style.opacity = maskCanvasStyle.opacity; } async show() { @@ -313,7 +336,7 @@ class MaskEditorDialog extends ComfyDialog { let maskCtx = this.maskCanvas.getContext('2d', {willReadFrequently: true }); imgCtx.drawImage(orig_image, 0, 0, orig_image.width, orig_image.height); - prepare_mask(mask_image, this.maskCanvas, maskCtx); + prepare_mask(mask_image, this.maskCanvas, maskCtx, this.getMaskColor()); } async setImages(imgCanvas) { @@ -439,7 +462,84 @@ class MaskEditorDialog extends ComfyDialog { } } + getMaskCanvasStyle() { + if (this.brush_color_mode === "negative") { + return { + mixBlendMode: "difference", + opacity: "1", + }; + } + else { + return { + mixBlendMode: "initial", + opacity: "0.7", + }; + } + } + + getMaskColor() { + if (this.brush_color_mode === "black") { + return { r: 0, g: 0, b: 0 }; + } + if (this.brush_color_mode === "white") { + return { r: 255, g: 255, b: 255 }; + } + if (this.brush_color_mode === "negative") { + // negative effect only works with white color + return { r: 255, g: 255, b: 255 }; + } + + return { r: 0, g: 0, b: 0 }; + } + + getMaskFillStyle() { + const maskColor = this.getMaskColor(); + + return "rgb(" + maskColor.r + "," + maskColor.g + "," + maskColor.b + ")"; + } + + getColorButtonText() { + let colorCaption = "unknown"; + + if (this.brush_color_mode === "black") { + colorCaption = "black"; + } + else if (this.brush_color_mode === "white") { + colorCaption = "white"; + } + else if (this.brush_color_mode === "negative") { + colorCaption = "negative"; + } + + return "Color: " + colorCaption; + } + + updateWhenBrushColorModeChanged() { + this.colorButton.innerText = this.getColorButtonText(); + + // update mask canvas css styles + + const maskCanvasStyle = this.getMaskCanvasStyle(); + this.maskCanvas.style.mixBlendMode = maskCanvasStyle.mixBlendMode; + this.maskCanvas.style.opacity = maskCanvasStyle.opacity; + + // update mask canvas rgb colors + + const maskColor = this.getMaskColor(); + + const maskData = this.maskCtx.getImageData(0, 0, this.maskCanvas.width, this.maskCanvas.height); + + for (let i = 0; i < maskData.data.length; i += 4) { + maskData.data[i] = maskColor.r; + maskData.data[i+1] = maskColor.g; + maskData.data[i+2] = maskColor.b; + } + + this.maskCtx.putImageData(maskData, 0, 0); + } + brush_size = 10; + brush_color_mode = "black"; drawing_mode = false; lastx = -1; lasty = -1; @@ -518,6 +618,19 @@ class MaskEditorDialog extends ComfyDialog { event.preventDefault(); self.pan_move(self, event); } + + let left_button_down = window.TouchEvent && event instanceof TouchEvent || event.buttons == 1; + + if(event.shiftKey && left_button_down) { + self.drawing_mode = false; + + const y = event.clientY; + let delta = (self.zoom_lasty - y)*0.005; + self.zoom_ratio = Math.max(Math.min(10.0, self.last_zoom_ratio - delta), 0.2); + + this.invalidatePanZoom(); + return; + } } pan_move(self, event) { @@ -535,7 +648,7 @@ class MaskEditorDialog extends ComfyDialog { } draw_move(self, event) { - if(event.ctrlKey) { + if(event.ctrlKey || event.shiftKey) { return; } @@ -546,7 +659,10 @@ class MaskEditorDialog extends ComfyDialog { self.updateBrushPreview(self); - if (window.TouchEvent && event instanceof TouchEvent || event.buttons == 1) { + let left_button_down = window.TouchEvent && event instanceof TouchEvent || event.buttons == 1; + let right_button_down = [2, 5, 32].includes(event.buttons); + + if (!event.altKey && left_button_down) { var diff = performance.now() - self.lasttime; const maskRect = self.maskCanvas.getBoundingClientRect(); @@ -581,7 +697,7 @@ class MaskEditorDialog extends ComfyDialog { if(diff > 20 && !this.drawing_mode) requestAnimationFrame(() => { self.maskCtx.beginPath(); - self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.fillStyle = this.getMaskFillStyle(); self.maskCtx.globalCompositeOperation = "source-over"; self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false); self.maskCtx.fill(); @@ -591,7 +707,7 @@ class MaskEditorDialog extends ComfyDialog { else requestAnimationFrame(() => { self.maskCtx.beginPath(); - self.maskCtx.fillStyle = "rgb(0,0,0)"; + self.maskCtx.fillStyle = this.getMaskFillStyle(); self.maskCtx.globalCompositeOperation = "source-over"; var dx = x - self.lastx; @@ -613,7 +729,7 @@ class MaskEditorDialog extends ComfyDialog { self.lasttime = performance.now(); } - else if(event.buttons == 2 || event.buttons == 5 || event.buttons == 32) { + else if((event.altKey && left_button_down) || right_button_down) { const maskRect = self.maskCanvas.getBoundingClientRect(); const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio; const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio; @@ -687,13 +803,20 @@ class MaskEditorDialog extends ComfyDialog { self.drawing_mode = true; event.preventDefault(); + + if(event.shiftKey) { + self.zoom_lasty = event.clientY; + self.last_zoom_ratio = self.zoom_ratio; + return; + } + const maskRect = self.maskCanvas.getBoundingClientRect(); const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio; const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio; self.maskCtx.beginPath(); - if (event.button == 0) { - self.maskCtx.fillStyle = "rgb(0,0,0)"; + if (!event.altKey && event.button == 0) { + self.maskCtx.fillStyle = this.getMaskFillStyle(); self.maskCtx.globalCompositeOperation = "source-over"; } else { self.maskCtx.globalCompositeOperation = "destination-out"; diff --git a/web/extensions/core/nodeTemplates.js b/web/extensions/core/nodeTemplates.js index bc9a108644a..9350ba6549c 100644 --- a/web/extensions/core/nodeTemplates.js +++ b/web/extensions/core/nodeTemplates.js @@ -1,4 +1,5 @@ import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; import { ComfyDialog, $el } from "../../scripts/ui.js"; import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; @@ -20,16 +21,20 @@ import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; // Open the manage dialog and Drag and drop elements using the "Name:" label as handle const id = "Comfy.NodeTemplates"; +const file = "comfy.templates.json"; class ManageTemplates extends ComfyDialog { constructor() { super(); + this.load().then((v) => { + this.templates = v; + }); + this.element.classList.add("comfy-manage-templates"); - this.templates = this.load(); this.draggedEl = null; this.saveVisualCue = null; this.emptyImg = new Image(); - this.emptyImg.src = ''; + this.emptyImg.src = ""; this.importInput = $el("input", { type: "file", @@ -67,17 +72,50 @@ class ManageTemplates extends ComfyDialog { return btns; } - load() { - const templates = localStorage.getItem(id); - if (templates) { - return JSON.parse(templates); + async load() { + let templates = []; + if (app.storageLocation === "server") { + if (app.isNewUserSession) { + // New user so migrate existing templates + const json = localStorage.getItem(id); + if (json) { + templates = JSON.parse(json); + } + await api.storeUserData(file, json, { stringify: false }); + } else { + const res = await api.getUserData(file); + if (res.status === 200) { + try { + templates = await res.json(); + } catch (error) { + } + } else if (res.status !== 404) { + console.error(res.status + " " + res.statusText); + } + } } else { - return []; + const json = localStorage.getItem(id); + if (json) { + templates = JSON.parse(json); + } } + + return templates ?? []; } - store() { - localStorage.setItem(id, JSON.stringify(this.templates)); + async store() { + if(app.storageLocation === "server") { + const templates = JSON.stringify(this.templates, undefined, 4); + localStorage.setItem(id, templates); // Backwards compatibility + try { + await api.storeUserData(file, templates, { stringify: false }); + } catch (error) { + console.error(error); + alert(error.message); + } + } else { + localStorage.setItem(id, JSON.stringify(this.templates)); + } } async importAll() { @@ -85,14 +123,14 @@ class ManageTemplates extends ComfyDialog { if (file.type === "application/json" || file.name.endsWith(".json")) { const reader = new FileReader(); reader.onload = async () => { - var importFile = JSON.parse(reader.result); - if (importFile && importFile?.templates) { + const importFile = JSON.parse(reader.result); + if (importFile?.templates) { for (const template of importFile.templates) { if (template?.name && template?.data) { this.templates.push(template); } } - this.store(); + await this.store(); } }; await reader.readAsText(file); @@ -159,7 +197,7 @@ class ManageTemplates extends ComfyDialog { e.currentTarget.style.border = "1px dashed transparent"; e.currentTarget.removeAttribute("draggable"); - // rearrange the elements in the localStorage + // rearrange the elements this.element.querySelectorAll('.tempateManagerRow').forEach((el,i) => { var prev_i = el.dataset.id; diff --git a/web/extensions/core/simpleTouchSupport.js b/web/extensions/core/simpleTouchSupport.js new file mode 100644 index 00000000000..041fc2c4ca9 --- /dev/null +++ b/web/extensions/core/simpleTouchSupport.js @@ -0,0 +1,102 @@ +import { app } from "../../scripts/app.js"; + +let touchZooming; +let touchCount = 0; + +app.registerExtension({ + name: "Comfy.SimpleTouchSupport", + setup() { + let zoomPos; + let touchTime; + let lastTouch; + + function getMultiTouchPos(e) { + return Math.hypot(e.touches[0].clientX - e.touches[1].clientX, e.touches[0].clientY - e.touches[1].clientY); + } + + app.canvasEl.addEventListener( + "touchstart", + (e) => { + touchCount++; + lastTouch = null; + if (e.touches?.length === 1) { + // Store start time for press+hold for context menu + touchTime = new Date(); + lastTouch = e.touches[0]; + } else { + touchTime = null; + if (e.touches?.length === 2) { + // Store center pos for zoom + zoomPos = getMultiTouchPos(e); + app.canvas.pointer_is_down = false; + } + } + }, + true + ); + + app.canvasEl.addEventListener("touchend", (e) => { + touchZooming = false; + touchCount = e.touches?.length ?? touchCount - 1; + if (touchTime && !e.touches?.length) { + if (new Date() - touchTime > 600) { + try { + // hack to get litegraph to use this event + e.constructor = CustomEvent; + } catch (error) {} + e.clientX = lastTouch.clientX; + e.clientY = lastTouch.clientY; + + app.canvas.pointer_is_down = true; + app.canvas._mousedown_callback(e); + } + touchTime = null; + } + }); + + app.canvasEl.addEventListener( + "touchmove", + (e) => { + touchTime = null; + if (e.touches?.length === 2) { + app.canvas.pointer_is_down = false; + touchZooming = true; + LiteGraph.closeAllContextMenus(); + app.canvas.search_box?.close(); + const newZoomPos = getMultiTouchPos(e); + + const midX = (e.touches[0].clientX + e.touches[1].clientX) / 2; + const midY = (e.touches[0].clientY + e.touches[1].clientY) / 2; + + let scale = app.canvas.ds.scale; + const diff = zoomPos - newZoomPos; + if (diff > 0.5) { + scale *= 1 / 1.07; + } else if (diff < -0.5) { + scale *= 1.07; + } + app.canvas.ds.changeScale(scale, [midX, midY]); + app.canvas.setDirty(true, true); + zoomPos = newZoomPos; + } + }, + true + ); + }, +}); + +const processMouseDown = LGraphCanvas.prototype.processMouseDown; +LGraphCanvas.prototype.processMouseDown = function (e) { + if (touchZooming || touchCount) { + return; + } + return processMouseDown.apply(this, arguments); +}; + +const processMouseMove = LGraphCanvas.prototype.processMouseMove; +LGraphCanvas.prototype.processMouseMove = function (e) { + if (touchZooming || touchCount > 1) { + return; + } + return processMouseMove.apply(this, arguments); +}; diff --git a/web/extensions/core/undoRedo.js b/web/extensions/core/undoRedo.js index 3cb137520f4..900eed2a7cd 100644 --- a/web/extensions/core/undoRedo.js +++ b/web/extensions/core/undoRedo.js @@ -1,4 +1,5 @@ import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js" const MAX_HISTORY = 50; @@ -15,6 +16,7 @@ function checkState() { } activeState = clone(currentState); redo.length = 0; + api.dispatchEvent(new CustomEvent("graphChanged", { detail: activeState })); } } @@ -92,7 +94,7 @@ const undoRedo = async (e) => { }; const bindInput = (activeEl) => { - if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") { + if (activeEl && activeEl.tagName !== "CANVAS" && activeEl.tagName !== "BODY") { for (const evt of ["change", "input", "blur"]) { if (`on${evt}` in activeEl) { const listener = () => { @@ -106,15 +108,23 @@ const bindInput = (activeEl) => { } }; +let keyIgnored = false; window.addEventListener( "keydown", (e) => { requestAnimationFrame(async () => { - const activeEl = document.activeElement; - if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") { - // Ignore events on inputs, they have their native history - return; + let activeEl; + // If we are auto queue in change mode then we do want to trigger on inputs + if (!app.ui.autoQueueEnabled || app.ui.autoQueueMode === "instant") { + activeEl = document.activeElement; + if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") { + // Ignore events on inputs, they have their native history + return; + } } + + keyIgnored = e.key === "Control" || e.key === "Shift" || e.key === "Alt" || e.key === "Meta"; + if (keyIgnored) return; // Check if this is a ctrl+z ctrl+y if (await undoRedo(e)) return; @@ -127,11 +137,23 @@ window.addEventListener( true ); +window.addEventListener("keyup", (e) => { + if (keyIgnored) { + keyIgnored = false; + checkState(); + } +}); + // Handle clicking DOM elements (e.g. widgets) window.addEventListener("mouseup", () => { checkState(); }); +// Handle prompt queue event for dynamic widget changes +api.addEventListener("promptQueued", () => { + checkState(); +}); + // Handle litegraph clicks const processMouseUp = LGraphCanvas.prototype.processMouseUp; LGraphCanvas.prototype.processMouseUp = function (e) { @@ -145,3 +167,11 @@ LGraphCanvas.prototype.processMouseDown = function (e) { checkState(); return v; }; + +// Handle litegraph context menu for COMBO widgets +const close = LiteGraph.ContextMenu.prototype.close; +LiteGraph.ContextMenu.prototype.close = function(e) { + const v = close.apply(this, arguments); + checkState(); + return v; +} \ No newline at end of file diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 3f1c1f8c126..23f51d812b4 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -22,6 +22,7 @@ function isConvertableWidget(widget, config) { } function hideWidget(node, widget, suffix = "") { + if (widget.type?.startsWith(CONVERTED_TYPE)) return; widget.origType = widget.type; widget.origComputeSize = widget.computeSize; widget.origSerializeValue = widget.serializeValue; @@ -260,6 +261,12 @@ app.registerExtension({ async beforeRegisterNodeDef(nodeType, nodeData, app) { // Add menu options to conver to/from widgets const origGetExtraMenuOptions = nodeType.prototype.getExtraMenuOptions; + nodeType.prototype.convertWidgetToInput = function (widget) { + const config = getConfig.call(this, widget.name) ?? [widget.type, widget.options || {}]; + if (!isConvertableWidget(widget, config)) return false; + convertToInput(this, widget, config); + return true; + }; nodeType.prototype.getExtraMenuOptions = function (_, options) { const r = origGetExtraMenuOptions ? origGetExtraMenuOptions.apply(this, arguments) : undefined; diff --git a/web/index.html b/web/index.html index 41bc246c090..094db9d1529 100644 --- a/web/index.html +++ b/web/index.html @@ -16,5 +16,33 @@ window.graph = app.graph; - + + + diff --git a/web/jsconfig.json b/web/jsconfig.json index 57403d8cf2b..b65fa2746da 100644 --- a/web/jsconfig.json +++ b/web/jsconfig.json @@ -3,7 +3,8 @@ "baseUrl": ".", "paths": { "/*": ["./*"] - } + }, + "lib": ["DOM", "ES2022"] }, "include": ["."] } diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 434c4a83bf1..4ff05ae8130 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -11496,7 +11496,7 @@ LGraphNode.prototype.executeAction = function(action) } timeout_close = setTimeout(function() { dialog.close(); - }, 500); + }, typeof options.hide_on_mouse_leave === "number" ? options.hide_on_mouse_leave : 500); }); // if filtering, check focus changed to comboboxes and prevent closing if (options.do_type_filter){ @@ -11549,7 +11549,7 @@ LGraphNode.prototype.executeAction = function(action) dialog.close(); } else if (e.keyCode == 13) { if (selected) { - select(selected.innerHTML); + select(unescape(selected.dataset["type"])); } else if (first) { select(first); } else { @@ -11910,7 +11910,7 @@ LGraphNode.prototype.executeAction = function(action) var ctor = LiteGraph.registered_node_types[ type ]; if(filter && ctor.filter != filter ) return false; - if ((!options.show_all_if_empty || str) && type.toLowerCase().indexOf(str) === -1) + if ((!options.show_all_if_empty || str) && type.toLowerCase().indexOf(str) === -1 && (!ctor.title || ctor.title.toLowerCase().indexOf(str) === -1)) return false; // filter by slot IN, OUT types @@ -11964,7 +11964,18 @@ LGraphNode.prototype.executeAction = function(action) if (!first) { first = type; } - help.innerText = type; + + const nodeType = LiteGraph.registered_node_types[type]; + if (nodeType?.title) { + help.innerText = nodeType?.title; + const typeEl = document.createElement("span"); + typeEl.className = "litegraph lite-search-item-type"; + typeEl.textContent = type; + help.append(typeEl); + } else { + help.innerText = type; + } + help.dataset["type"] = escape(type); help.className = "litegraph lite-search-item"; if (className) { diff --git a/web/lib/litegraph.css b/web/lib/litegraph.css index 918858f415d..5524e24bacb 100644 --- a/web/lib/litegraph.css +++ b/web/lib/litegraph.css @@ -184,6 +184,7 @@ color: white; padding-left: 10px; margin-right: 5px; + max-width: 300px; } .litegraph.litesearchbox .name { @@ -227,6 +228,18 @@ color: black; } +.litegraph.lite-search-item-type { + display: inline-block; + background: rgba(0,0,0,0.2); + margin-left: 5px; + font-size: 14px; + padding: 2px 5px; + position: relative; + top: -2px; + opacity: 0.8; + border-radius: 4px; + } + /* DIALOGs ******/ .litegraph .dialog { diff --git a/web/scripts/api.js b/web/scripts/api.js index c63fdf2c03c..78824291ddc 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -5,6 +5,7 @@ class ComfyApi extends EventTarget { super(); this.api_host = location.host; this.api_base = location.pathname.split('/').slice(0, -1).join('/'); + this.initialClientId = sessionStorage.getItem("clientId"); } apiURL(route) { @@ -12,6 +13,13 @@ class ComfyApi extends EventTarget { } fetchApi(route, options) { + if (!options) { + options = {}; + } + if (!options.headers) { + options.headers = {}; + } + options.headers["Comfy-User"] = this.user; return fetch(this.apiURL(route), options); } @@ -111,7 +119,8 @@ class ComfyApi extends EventTarget { case "status": if (msg.data.sid) { this.clientId = msg.data.sid; - window.name = this.clientId; + window.name = this.clientId; // use window name so it isnt reused when duplicating tabs + sessionStorage.setItem("clientId", this.clientId); // store in session storage so duplicate tab can load correct workflow } this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); break; @@ -316,6 +325,99 @@ class ComfyApi extends EventTarget { async interrupt() { await this.#postItem("interrupt", null); } + + /** + * Gets user configuration data and where data should be stored + * @returns { Promise<{ storage: "server" | "browser", users?: Promise, migrated?: boolean }> } + */ + async getUserConfig() { + return (await this.fetchApi("/users")).json(); + } + + /** + * Creates a new user + * @param { string } username + * @returns The fetch response + */ + createUser(username) { + return this.fetchApi("/users", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ username }), + }); + } + + /** + * Gets all setting values for the current user + * @returns { Promise } A dictionary of id -> value + */ + async getSettings() { + return (await this.fetchApi("/settings")).json(); + } + + /** + * Gets a setting for the current user + * @param { string } id The id of the setting to fetch + * @returns { Promise } The setting value + */ + async getSetting(id) { + return (await this.fetchApi(`/settings/${encodeURIComponent(id)}`)).json(); + } + + /** + * Stores a dictionary of settings for the current user + * @param { Record } settings Dictionary of setting id -> value to save + * @returns { Promise } + */ + async storeSettings(settings) { + return this.fetchApi(`/settings`, { + method: "POST", + body: JSON.stringify(settings) + }); + } + + /** + * Stores a setting for the current user + * @param { string } id The id of the setting to update + * @param { unknown } value The value of the setting + * @returns { Promise } + */ + async storeSetting(id, value) { + return this.fetchApi(`/settings/${encodeURIComponent(id)}`, { + method: "POST", + body: JSON.stringify(value) + }); + } + + /** + * Gets a user data file for the current user + * @param { string } file The name of the userdata file to load + * @param { RequestInit } [options] + * @returns { Promise } The fetch response object + */ + async getUserData(file, options) { + return this.fetchApi(`/userdata/${encodeURIComponent(file)}`, options); + } + + /** + * Stores a user data file for the current user + * @param { string } file The name of the userdata file to save + * @param { unknown } data The data to save to the file + * @param { RequestInit & { stringify?: boolean, throwOnError?: boolean } } [options] + * @returns { Promise } + */ + async storeUserData(file, data, options = { stringify: true, throwOnError: true }) { + const resp = await this.fetchApi(`/userdata/${encodeURIComponent(file)}`, { + method: "POST", + body: options?.stringify ? JSON.stringify(data) : data, + ...options, + }); + if (resp.status !== 200) { + throw new Error(`Error storing user data file '${file}': ${resp.status} ${(await resp).statusText}`); + } + } } export const api = new ComfyApi(); diff --git a/web/scripts/app.js b/web/scripts/app.js index 8eee959869d..c5d8e01ebee 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1,5 +1,5 @@ import { ComfyLogging } from "./logging.js"; -import { ComfyWidgets } from "./widgets.js"; +import { ComfyWidgets, initWidgets } from "./widgets.js"; import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; @@ -8,21 +8,22 @@ import { addDomClippingSetting } from "./domWidget.js"; import { createImageHost, calculateImageGrid } from "./ui/imagePreview.js" import { getUserId } from "./utils.js"; import { getWorkflow } from "./utils.js"; + export const ANIM_PREVIEW_WIDGET = "$$comfy_animation_preview" function sanitizeNodeName(string) { let entityMap = { - '&': '', - '<': '', - '>': '', - '"': '', - "'": '', - '`': '', - '=': '' + '&': '', + '<': '', + '>': '', + '"': '', + "'": '', + '`': '', + '=': '' }; - return String(string).replace(/[&<>"'`=]/g, function fromEntityMap (s) { + return String(string).replace(/[&<>"'`=]/g, function fromEntityMap(s) { return entityMap[s]; }); } @@ -83,7 +84,7 @@ export class ComfyApp { getPreviewFormatParam() { let preview_format = this.ui.settings.getSettingValue("Comfy.PreviewFormat"); - if(preview_format) + if (preview_format) return `&preview=${preview_format}`; else return ""; @@ -98,7 +99,7 @@ export class ComfyApp { } static onClipspaceEditorSave() { - if(ComfyApp.clipspace_return_node) { + if (ComfyApp.clipspace_return_node) { ComfyApp.pasteFromClipspace(ComfyApp.clipspace_return_node); } } @@ -109,13 +110,13 @@ export class ComfyApp { static copyToClipspace(node) { var widgets = null; - if(node.widgets) { + if (node.widgets) { widgets = node.widgets.map(({ type, name, value }) => ({ type, name, value })); } var imgs = undefined; var orig_imgs = undefined; - if(node.imgs != undefined) { + if (node.imgs != undefined) { imgs = []; orig_imgs = []; @@ -127,7 +128,7 @@ export class ComfyApp { } var selectedIndex = 0; - if(node.imageIndex) { + if (node.imageIndex) { selectedIndex = node.imageIndex; } @@ -142,30 +143,30 @@ export class ComfyApp { ComfyApp.clipspace_return_node = null; - if(ComfyApp.clipspace_invalidate_handler) { + if (ComfyApp.clipspace_invalidate_handler) { ComfyApp.clipspace_invalidate_handler(); } } static pasteFromClipspace(node) { - if(ComfyApp.clipspace) { + if (ComfyApp.clipspace) { // image paste - if(ComfyApp.clipspace.imgs && node.imgs) { - if(node.images && ComfyApp.clipspace.images) { - if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + if (ComfyApp.clipspace.imgs && node.imgs) { + if (node.images && ComfyApp.clipspace.images) { + if (ComfyApp.clipspace['img_paste_mode'] == 'selected') { node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; } else { node.images = ComfyApp.clipspace.images; } - if(app.nodeOutputs[node.id + ""]) + if (app.nodeOutputs[node.id + ""]) app.nodeOutputs[node.id + ""].images = node.images; } - if(ComfyApp.clipspace.imgs) { + if (ComfyApp.clipspace.imgs) { // deep-copy to cut link with clipspace - if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { + if (ComfyApp.clipspace['img_paste_mode'] == 'selected') { const img = new Image(); img.src = ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src; node.imgs = [img]; @@ -173,7 +174,7 @@ export class ComfyApp { } else { const imgs = []; - for(let i=0; i obj.name === 'image'); - if(index >= 0) { - if(node.widgets[index].type != 'image' && typeof node.widgets[index].value == "string" && clip_image.filename) { - node.widgets[index].value = (clip_image.subfolder?clip_image.subfolder+'/':'') + clip_image.filename + (clip_image.type?` [${clip_image.type}]`:''); + if (index >= 0) { + if (node.widgets[index].type != 'image' && typeof node.widgets[index].value == "string" && clip_image.filename) { + node.widgets[index].value = (clip_image.subfolder ? clip_image.subfolder + '/' : '') + clip_image.filename + (clip_image.type ? ` [${clip_image.type}]` : ''); } else { node.widgets[index].value = clip_image; } } } - if(ComfyApp.clipspace.widgets) { + if (ComfyApp.clipspace.widgets) { ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name); if (prop && prop.type != 'button') { - if(prop.type != 'image' && typeof prop.value == "string" && value.filename) { - prop.value = (value.subfolder?value.subfolder+'/':'') + value.filename + (value.type?` [${value.type}]`:''); + if (prop.type != 'image' && typeof prop.value == "string" && value.filename) { + prop.value = (value.subfolder ? value.subfolder + '/' : '') + value.filename + (value.type ? ` [${value.type}]` : ''); } else { prop.value = value; @@ -272,6 +273,71 @@ export class ComfyApp { * @param {*} node The node to add the menu handler */ #addNodeContextMenuHandler(node) { + function getCopyImageOption(img) { + if (typeof window.ClipboardItem === "undefined") return []; + return [ + { + content: "Copy Image", + callback: async () => { + const url = new URL(img.src); + url.searchParams.delete("preview"); + + const writeImage = async (blob) => { + await navigator.clipboard.write([ + new ClipboardItem({ + [blob.type]: blob, + }), + ]); + }; + + try { + const data = await fetch(url); + const blob = await data.blob(); + try { + await writeImage(blob); + } catch (error) { + // Chrome seems to only support PNG on write, convert and try again + if (blob.type !== "image/png") { + const canvas = $el("canvas", { + width: img.naturalWidth, + height: img.naturalHeight, + }); + const ctx = canvas.getContext("2d"); + let image; + if (typeof window.createImageBitmap === "undefined") { + image = new Image(); + const p = new Promise((resolve, reject) => { + image.onload = resolve; + image.onerror = reject; + }).finally(() => { + URL.revokeObjectURL(image.src); + }); + image.src = URL.createObjectURL(blob); + await p; + } else { + image = await createImageBitmap(blob); + } + try { + ctx.drawImage(image, 0, 0); + canvas.toBlob(writeImage, "image/png"); + } finally { + if (typeof image.close === "function") { + image.close(); + } + } + + return; + } + throw error; + } + } catch (error) { + alert("Error copying image: " + (error.message ?? error)); + } + }, + }, + ]; + } + node.prototype.getExtraMenuOptions = function (_, options) { if (this.imgs) { // If this node has images then we add an open in new tab item @@ -289,16 +355,17 @@ export class ComfyApp { content: "Open Image", callback: () => { let url = new URL(img.src); - url.searchParams.delete('preview'); - window.open(url, "_blank") + url.searchParams.delete("preview"); + window.open(url, "_blank"); }, }, + ...getCopyImageOption(img), { content: "Save Image", callback: () => { const a = document.createElement("a"); let url = new URL(img.src); - url.searchParams.delete('preview'); + url.searchParams.delete("preview"); a.href = url; a.setAttribute("download", new URLSearchParams(url.search).get("filename")); document.body.append(a); @@ -311,33 +378,41 @@ export class ComfyApp { } options.push({ - content: "Bypass", - callback: (obj) => { if (this.mode === 4) this.mode = 0; else this.mode = 4; this.graph.change(); } - }); + content: "Bypass", + callback: (obj) => { + if (this.mode === 4) this.mode = 0; + else this.mode = 4; + this.graph.change(); + }, + }); // prevent conflict of clipspace content - if(!ComfyApp.clipspace_return_node) { + if (!ComfyApp.clipspace_return_node) { options.push({ - content: "Copy (Clipspace)", - callback: (obj) => { ComfyApp.copyToClipspace(this); } - }); + content: "Copy (Clipspace)", + callback: (obj) => { + ComfyApp.copyToClipspace(this); + }, + }); - if(ComfyApp.clipspace != null) { + if (ComfyApp.clipspace != null) { options.push({ - content: "Paste (Clipspace)", - callback: () => { ComfyApp.pasteFromClipspace(this); } - }); + content: "Paste (Clipspace)", + callback: () => { + ComfyApp.pasteFromClipspace(this); + }, + }); } - if(ComfyApp.isImageNode(this)) { + if (ComfyApp.isImageNode(this)) { options.push({ - content: "Open in MaskEditor", - callback: (obj) => { - ComfyApp.copyToClipspace(this); - ComfyApp.clipspace_return_node = this; - ComfyApp.open_maskeditor(); - } - }); + content: "Open in MaskEditor", + callback: (obj) => { + ComfyApp.copyToClipspace(this); + ComfyApp.clipspace_return_node = this; + ComfyApp.open_maskeditor(); + }, + }); } } }; @@ -347,7 +422,7 @@ export class ComfyApp { const app = this; const origNodeOnKeyDown = node.prototype.onKeyDown; - node.prototype.onKeyDown = function(e) { + node.prototype.onKeyDown = function (e) { if (origNodeOnKeyDown && origNodeOnKeyDown.apply(this, e) === false) { return false; } @@ -402,7 +477,7 @@ export class ComfyApp { if (w.computeSize) { shiftY += w.computeSize()[1] + 4; } - else if(w.computedHeight) { + else if (w.computedHeight) { shiftY += w.computedHeight; } else { @@ -416,7 +491,7 @@ export class ComfyApp { } node.prototype.setSizeForImage = function (force) { - if(!force && this.animatedImages) return; + if (!force && this.animatedImages) return; if (this.inputHeight || this.freeWidgetSpace > 210) { this.setSize(this.size); @@ -443,8 +518,8 @@ export class ComfyApp { output.images.map((params) => { return api.apiURL( "/view?" + - new URLSearchParams(params).toString() + - (this.animatedImages ? "" : app.getPreviewFormatParam()) + app.getRandParam() + new URLSearchParams(params).toString() + + (this.animatedImages ? "" : app.getPreviewFormatParam()) + app.getRandParam() ); }) ); @@ -509,17 +584,17 @@ export class ComfyApp { } } - const cell_size = Math.min(w/columns, h/rows); - return {cell_size, columns, rows}; + const cell_size = Math.min(w / columns, h / rows); + return { cell_size, columns, rows }; } function is_all_same_aspect_ratio(imgs) { // assume: imgs.length >= 2 - let ratio = imgs[0].naturalWidth/imgs[0].naturalHeight; + let ratio = imgs[0].naturalWidth / imgs[0].naturalHeight; - for(let i=1; i w.name === ANIM_PREVIEW_WIDGET); - - if(this.animatedImages) { + + if (this.animatedImages) { // Instead of using the canvas we'll use a IMG - if(widgetIdx > -1) { + if (widgetIdx > -1) { // Replace content const widget = this.widgets[widgetIdx]; widget.options.host.updateImages(this.imgs); @@ -581,7 +656,7 @@ export class ComfyApp { var cellWidth, cellHeight, shiftX, cell_padding, cols; const compact_mode = is_all_same_aspect_ratio(this.imgs); - if(!compact_mode) { + if (!compact_mode) { // use rectangle cell style and border line cell_padding = 2; const { cell_size, columns, rows } = calculateGrid(dw, dh, numImages); @@ -589,8 +664,8 @@ export class ComfyApp { cellWidth = cell_size; cellHeight = cell_size; - shiftX = (dw-cell_size*cols)/2; - shiftY = (dh-cell_size*rows)/2 + top; + shiftX = (dw - cell_size * cols) / 2; + shiftY = (dh - cell_size * rows) / 2 + top; } else { cell_padding = 0; @@ -629,21 +704,21 @@ export class ComfyApp { } this.imageRects.push([x, y, cellWidth, cellHeight]); - let wratio = cellWidth/img.width; - let hratio = cellHeight/img.height; + let wratio = cellWidth / img.width; + let hratio = cellHeight / img.height; var ratio = Math.min(wratio, hratio); let imgHeight = ratio * img.height; - let imgY = row * cellHeight + shiftY + (cellHeight - imgHeight)/2; + let imgY = row * cellHeight + shiftY + (cellHeight - imgHeight) / 2; let imgWidth = ratio * img.width; - let imgX = col * cellWidth + shiftX + (cellWidth - imgWidth)/2; + let imgX = col * cellWidth + shiftX + (cellWidth - imgWidth) / 2; - ctx.drawImage(img, imgX+cell_padding, imgY+cell_padding, imgWidth-cell_padding*2, imgHeight-cell_padding*2); - if(!compact_mode) { + ctx.drawImage(img, imgX + cell_padding, imgY + cell_padding, imgWidth - cell_padding * 2, imgHeight - cell_padding * 2); + if (!compact_mode) { // rectangle cell and border line style ctx.strokeStyle = "#8F8F8F"; ctx.lineWidth = 1; - ctx.strokeRect(x+cell_padding, y+cell_padding, cellWidth-cell_padding*2, cellHeight-cell_padding*2); + ctx.strokeRect(x + cell_padding, y + cell_padding, cellWidth - cell_padding * 2, cellHeight - cell_padding * 2); } ctx.filter = "none"; @@ -737,7 +812,7 @@ export class ComfyApp { } // Dragging from Chrome->Firefox there is a file but its a bmp, so ignore that if (event.dataTransfer.files.length && event.dataTransfer.files[0].type !== "image/bmp") { - await this.handleFile(event.dataTransfer.files[0]); + await this.handleFile(event.dataTransfer.files[0]); } else { // Try loading the first URI in the transfer list const validTypes = ["text/uri-list", "text/x-moz-url"]; @@ -789,7 +864,7 @@ export class ComfyApp { document.addEventListener("paste", async (e) => { // ctrl+shift+v is used to paste nodes with connections // this is handled by litegraph - if(this.shiftDown) return; + if (this.shiftDown) return; let data = (e.clipboardData || window.clipboardData); const items = data.items; @@ -830,7 +905,7 @@ export class ComfyApp { data = data.slice(data.indexOf("workflow\n")); data = data.slice(data.indexOf("{")); workflow = JSON.parse(data); - } catch (error) {} + } catch (error) { } } if (workflow && workflow.version && workflow.nodes && workflow.extra) { @@ -881,7 +956,7 @@ export class ComfyApp { const self = this; const origProcessMouseDown = LGraphCanvas.prototype.processMouseDown; - LGraphCanvas.prototype.processMouseDown = function(e) { + LGraphCanvas.prototype.processMouseDown = function (e) { const res = origProcessMouseDown.apply(this, arguments); this.selected_group_moving = false; @@ -901,7 +976,7 @@ export class ComfyApp { } const origProcessMouseMove = LGraphCanvas.prototype.processMouseMove; - LGraphCanvas.prototype.processMouseMove = function(e) { + LGraphCanvas.prototype.processMouseMove = function (e) { const orig_selected_group = this.selected_group; if (this.selected_group && !this.selected_group_resizing && !this.selected_group_moving) { @@ -926,7 +1001,7 @@ export class ComfyApp { #addProcessKeyHandler() { const self = this; const origProcessKey = LGraphCanvas.prototype.processKey; - LGraphCanvas.prototype.processKey = function(e) { + LGraphCanvas.prototype.processKey = function (e) { if (!this.graph) { return; } @@ -1010,7 +1085,7 @@ export class ComfyApp { const self = this; const origDrawGroups = LGraphCanvas.prototype.drawGroups; - LGraphCanvas.prototype.drawGroups = function(canvas, ctx) { + LGraphCanvas.prototype.drawGroups = function (canvas, ctx) { if (!this.graph) { return; } @@ -1097,7 +1172,7 @@ export class ComfyApp { 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT, [this.round_radius * 2, this.round_radius * 2, 2, 2] - ); + ); else if (shape == LiteGraph.CIRCLE_SHAPE) ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2); ctx.strokeStyle = color; @@ -1264,9 +1339,9 @@ export class ComfyApp { for (const node of app.graph._nodes) { node.onGraphConfigured?.(); } - + const r = onConfigure?.apply(this, arguments); - + // Fire after onConfigure, used by primitves to generate widget using input nodes config for (const node of app.graph._nodes) { node.onAfterGraphConfigured?.(); @@ -1280,24 +1355,106 @@ export class ComfyApp { * Loads all extensions from the API into the window in parallel */ async #loadExtensions() { - const extensions = await api.getExtensions(); - this.logging.addEntry("Comfy.App", "debug", { Extensions: extensions }); - - const extensionPromises = extensions.map(async ext => { - try { - await import(api.apiURL(ext)); - } catch (error) { - console.error("Error loading extension", ext, error); - } - }); - - await Promise.all(extensionPromises); + const extensions = await api.getExtensions(); + this.logging.addEntry("Comfy.App", "debug", { Extensions: extensions }); + + const extensionPromises = extensions.map(async ext => { + try { + await import(api.apiURL(ext)); + } catch (error) { + console.error("Error loading extension", ext, error); + } + }); + + await Promise.all(extensionPromises); + } + + async #migrateSettings() { + this.isNewUserSession = true; + // Store all current settings + const settings = Object.keys(this.ui.settings).reduce((p, n) => { + const v = localStorage[`Comfy.Settings.${n}`]; + if (v) { + try { + p[n] = JSON.parse(v); + } catch (error) { } + } + return p; + }, {}); + + await api.storeSettings(settings); + } + + async #setUser() { + const userConfig = await api.getUserConfig(); + this.storageLocation = userConfig.storage; + if (typeof userConfig.migrated == "boolean") { + // Single user mode migrated true/false for if the default user is created + if (!userConfig.migrated && this.storageLocation === "server") { + // Default user not created yet + await this.#migrateSettings(); + } + return; + } + + this.multiUserServer = true; + let user = localStorage["Comfy.userId"]; + const users = userConfig.users ?? {}; + if (!user || !users[user]) { + // This will rarely be hit so move the loading to on demand + const { UserSelectionScreen } = await import("./ui/userSelection.js"); + + this.ui.menuContainer.style.display = "none"; + const { userId, username, created } = await new UserSelectionScreen().show(users, user); + this.ui.menuContainer.style.display = ""; + + user = userId; + localStorage["Comfy.userName"] = username; + localStorage["Comfy.userId"] = user; + + if (created) { + api.user = user; + await this.#migrateSettings(); + } + } + + api.user = user; + + this.ui.settings.addSetting({ + id: "Comfy.SwitchUser", + name: "Switch User", + type: (name) => { + let currentUser = localStorage["Comfy.userName"]; + if (currentUser) { + currentUser = ` (${currentUser})`; + } + return $el("tr", [ + $el("td", [ + $el("label", { + textContent: name, + }), + ]), + $el("td", [ + $el("button", { + textContent: name + (currentUser ?? ""), + onclick: () => { + delete localStorage["Comfy.userId"]; + delete localStorage["Comfy.userName"]; + window.location.reload(); + }, + }), + ]), + ]); + }, + }); } /** * Set up the app on the page */ async setup() { + await this.#setUser(); + await this.ui.settings.load(); await this.#loadExtensions(); // Create and mount the LiteGraph in the DOM @@ -1341,34 +1498,49 @@ export class ComfyApp { await this.#invokeExtensionsAsync("init"); await this.registerNodes(); + initWidgets(this); - // Load previous workflow - let restored = false; - try { - const workflow = await getWorkflow(); - const json = localStorage.getItem("workflow"); - if (workflow) { - await this.loadGraphData(workflow); - restored = true; - } else { - if (json) { - const workflow = JSON.parse(json); - await this.loadGraphData(workflow); - restored = true; - } - } - } catch (err) { - console.error("Error loading previous workflow", err); + // Load prebuilt workflow + const workflow = await getWorkflow(); + + if (workflow) { + await this.loadGraphData(workflow); } + else { + // Load previous workflow + let restored = false; + try { + const loadWorkflow = async (json) => { + if (json) { + const workflow = JSON.parse(json); + await this.loadGraphData(workflow); + return true; + } + }; + const clientId = api.initialClientId ?? api.clientId; + restored = + (clientId && (await loadWorkflow(sessionStorage.getItem(`workflow:${clientId}`)))) || + (await loadWorkflow(localStorage.getItem("workflow"))); + } catch (err) { + console.error("Error loading previous workflow", err); + } - // We failed to restore a workflow so load the default - if (!restored) { - await this.loadGraphData(); + // We failed to restore a workflow so load the default + if (!restored) { + await this.loadGraphData(); + } } + // Save current workflow automatically - setInterval(() => localStorage.setItem("workflow", JSON.stringify(this.graph.serialize())), 1000); + setInterval(() => { + const workflow = JSON.stringify(this.graph.serialize()); + localStorage.setItem("workflow", workflow); + if (api.clientId) { + sessionStorage.setItem(`workflow:${api.clientId}`, workflow); + } + }, 1000); this.#addDrawNodeHandler(); this.#addDrawGroupsHandler(); @@ -1420,8 +1592,8 @@ export class ComfyApp { let widgetCreated = true; const widgetType = self.getWidgetType(inputData, inputName); - if(widgetType) { - if(widgetType === "COMBO") { + if (widgetType) { + if (widgetType === "COMBO") { Object.assign(config, self.widgets.COMBO(this, inputName, inputData, app) || {}); } else { Object.assign(config, self.widgets[widgetType](this, inputName, inputData, app) || {}); @@ -1432,11 +1604,11 @@ export class ComfyApp { widgetCreated = false; } - if(widgetCreated && inputData[1]?.forceInput && config?.widget) { + if (widgetCreated && inputData[1]?.forceInput && config?.widget) { if (!config.widget.options) config.widget.options = {}; config.widget.options.forceInput = inputData[1].forceInput; } - if(widgetCreated && inputData[1]?.defaultInput && config?.widget) { + if (widgetCreated && inputData[1]?.defaultInput && config?.widget) { if (!config.widget.options) config.widget.options = {}; config.widget.options.defaultInput = inputData[1].defaultInput; } @@ -1444,9 +1616,9 @@ export class ComfyApp { for (const o in nodeData["output"]) { let output = nodeData["output"][o]; - if(output instanceof Array) output = "COMBO"; + if (output instanceof Array) output = "COMBO"; const outputName = nodeData["output_name"][o] || output; - const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ; + const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE; this.addOutput(outputName, output, { shape: outputShape }); } @@ -1475,7 +1647,7 @@ export class ComfyApp { node.category = nodeData.category; } - async registerNodesFromDefs(defs) { + async registerNodesFromDefs(defs) { await this.#invokeExtensionsAsync("addCustomNodeDefs", defs); // Generate list of known widgets @@ -1539,7 +1711,7 @@ export class ComfyApp { Array.from(new Set(missingNodeTypes)).map((t) => { let children = []; if (typeof t === "object") { - if(seenTypes.has(t.type)) return null; + if (seenTypes.has(t.type)) return null; seenTypes.add(t.type); children.push($el("span", { textContent: t.type })); if (t.hint) { @@ -1549,7 +1721,7 @@ export class ComfyApp { children.push($el("button", { onclick: t.action.callback, textContent: t.action.text })); } } else { - if(seenTypes.has(t)) return null; + if (seenTypes.has(t)) return null; seenTypes.add(t); children.push($el("span", { textContent: t })); } @@ -1582,11 +1754,9 @@ export class ComfyApp { reset_invalid_values = true; } - if (typeof structuredClone === "undefined") - { + if (typeof structuredClone === "undefined") { graphData = JSON.parse(JSON.stringify(graphData)); - }else - { + } else { graphData = structuredClone(graphData); } @@ -1702,6 +1872,14 @@ export class ComfyApp { */ async graphToPrompt() { for (const outerNode of this.graph.computeExecutionOrder(false)) { + if (outerNode.widgets) { + for (const widget of outerNode.widgets) { + // Allow widgets to run callbacks before a prompt has been queued + // e.g. random seed before every gen + widget.beforeQueued?.(); + } + } + const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode]; for (const node of innerNodes) { if (node.isVirtualNode) { @@ -1791,10 +1969,19 @@ export class ComfyApp { } } - output[String(node.id)] = { + let node_data = { inputs, class_type: node.comfyClass, }; + + if (this.ui.settings.getSettingValue("Comfy.DevMode")) { + // Ignored by the backend. + node_data["_meta"] = { + title: node.title, + } + } + + output[String(node.id)] = node_data; } } @@ -1826,9 +2013,9 @@ export class ComfyApp { else if (error.response) { let message = error.response.error.message; if (error.response.error.details) - message += ": " + error.response.error.details; + message += ": " + error.response.error.details; for (const [nodeID, nodeError] of Object.entries(error.response.node_errors)) { - message += "\n" + nodeError.class_type + ":" + message += "\n" + nodeError.class_type + ":" for (const errorReason of nodeError.errors) { message += "\n - " + errorReason.message + ": " + errorReason.details } @@ -1903,6 +2090,7 @@ export class ComfyApp { } finally { this.#processingQueue = false; } + api.dispatchEvent(new CustomEvent("promptQueued", { detail: { number, batchCount } })); } /** @@ -1930,6 +2118,8 @@ export class ComfyApp { this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node. } else if (pngInfo.prompt) { this.loadApiJson(JSON.parse(pngInfo.prompt)); + } else if (pngInfo.Prompt) { + this.loadApiJson(JSON.parse(pngInfo.Prompt)); // Support loading prompts from that webp custom node. } } } else if (file.type === "application/json" || file.name?.endsWith(".json")) { @@ -1938,7 +2128,7 @@ export class ComfyApp { const jsonContent = JSON.parse(reader.result); if (jsonContent?.templates) { this.loadTemplateData(jsonContent); - } else if(this.isApiJson(jsonContent)) { + } else if (this.isApiJson(jsonContent)) { this.loadApiJson(jsonContent); } else { await this.loadGraphData(jsonContent); @@ -1983,8 +2173,17 @@ export class ComfyApp { if (value instanceof Array) { const [fromId, fromSlot] = value; const fromNode = app.graph.getNodeById(fromId); - const toSlot = node.inputs?.findIndex((inp) => inp.name === input); - if (toSlot !== -1) { + let toSlot = node.inputs?.findIndex((inp) => inp.name === input); + if (toSlot == null || toSlot === -1) { + try { + // Target has no matching input, most likely a converted widget + const widget = node.widgets?.find((w) => w.name === input); + if (widget && node.convertWidgetToInput?.(widget)) { + toSlot = node.inputs?.length - 1; + } + } catch (error) {} + } + if (toSlot != null || toSlot !== -1) { fromNode.connect(fromSlot, node, toSlot); } } else { @@ -2020,36 +2219,34 @@ export class ComfyApp { async refreshComboInNodes() { const defs = await api.getNodeDefs(); - for(const nodeId in LiteGraph.registered_node_types) { - const node = LiteGraph.registered_node_types[nodeId]; - const nodeDef = defs[nodeId]; - if(!nodeDef) continue; - - node.nodeData = nodeDef; + for (const nodeId in defs) { + this.registerNodeDef(nodeId, defs[nodeId]); } - for(let nodeNum in this.graph._nodes) { + for (let nodeNum in this.graph._nodes) { const node = this.graph._nodes[nodeNum]; const def = defs[node.type]; // Allow primitive nodes to handle refresh node.refreshComboInNode?.(defs); - if(!def) + if (!def) continue; - for(const widgetNum in node.widgets) { + for (const widgetNum in node.widgets) { const widget = node.widgets[widgetNum] - if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { + if (widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { widget.options.values = def["input"]["required"][widget.name][0]; - if(widget.name != 'image' && !widget.options.values.includes(widget.value)) { + if (widget.name != 'image' && !widget.options.values.includes(widget.value)) { widget.value = widget.options.values[0]; widget.callback(widget.value); } } } } + + await this.#invokeExtensionsAsync("refreshComboInNodes", defs); } /** diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js index bb4c892b541..d5eeebdbd39 100644 --- a/web/scripts/domWidget.js +++ b/web/scripts/domWidget.js @@ -177,7 +177,7 @@ LGraphCanvas.prototype.computeVisibleNodes = function () { for (const w of node.widgets) { if (w.element) { w.element.hidden = hidden; - w.element.style.display = hidden ? "none" : null; + w.element.style.display = hidden ? "none" : undefined; if (hidden) { w.options.onHide?.(w); } @@ -239,7 +239,8 @@ LGraphNode.prototype.addDOMWidget = function (name, type, element, options) { node.flags?.collapsed || (!!options.hideOnZoom && app.canvas.ds.scale < 0.5) || widget.computedHeight <= 0 || - widget.type === "converted-widget"; + widget.type === "converted-widget"|| + widget.type === "hidden"; element.hidden = hidden; element.style.display = hidden ? "none" : null; if (hidden) { diff --git a/web/scripts/logging.js b/web/scripts/logging.js index c73462e1ea3..875dd970bc8 100644 --- a/web/scripts/logging.js +++ b/web/scripts/logging.js @@ -269,6 +269,9 @@ export class ComfyLogging { id: settingId, name: settingId, defaultValue: true, + onChange: (value) => { + this.enabled = value; + }, type: (name, setter, value) => { return $el("tr", [ $el("td", [ @@ -283,7 +286,7 @@ export class ComfyLogging { type: "checkbox", checked: value, onchange: (event) => { - setter((this.enabled = event.target.checked)); + setter(event.target.checked); }, }), $el("button", { diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 83a4ebc86c4..1696092098f 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -24,7 +24,7 @@ export function getPngMetadata(file) { const length = dataView.getUint32(offset); // Get the chunk type const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8)); - if (type === "tEXt" || type == "comf") { + if (type === "tEXt" || type == "comf" || type === "iTXt") { // Get the keyword let keyword_end = offset + 8; while (pngData[keyword_end] !== 0) { @@ -33,7 +33,7 @@ export function getPngMetadata(file) { const keyword = String.fromCharCode(...pngData.slice(offset + 8, keyword_end)); // Get the text const contentArraySegment = pngData.slice(keyword_end + 1, offset + 8 + length); - const contentJson = Array.from(contentArraySegment).map(s=>String.fromCharCode(s)).join('') + const contentJson = new TextDecoder("utf-8").decode(contentArraySegment); txt_chunks[keyword] = contentJson; } diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 55be5b228cc..63e6d1b0fcc 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -1,6 +1,26 @@ -import {api} from "./api.js"; + +import { api } from "./api.js"; +import { ComfyDialog as _ComfyDialog } from "./ui/dialog.js"; +import { toggleSwitch } from "./ui/toggleSwitch.js"; +import { ComfySettingsDialog } from "./ui/settings.js"; import { getUserId } from "./utils.js"; +export const ComfyDialog = _ComfyDialog; + +/** + * + * @param { string } tag HTML Element Tag and optional classes e.g. div.class1.class2 + * @param { string | Element | Element[] | { + * parent?: Element, + * $?: (el: Element) => void, + * dataset?: DOMStringMap, + * style?: CSSStyleDeclaration, + * for?: string + * } | undefined } propsOrChildren + * @param { Element[] | undefined } [children] + * @returns + */ + export function $el(tag, propsOrChildren, children) { const split = tag.split("."); const element = document.createElement(split.shift()); @@ -9,6 +29,11 @@ export function $el(tag, propsOrChildren, children) { } if (propsOrChildren) { + if (typeof propsOrChildren === "string") { + propsOrChildren = { textContent: propsOrChildren }; + } else if (propsOrChildren instanceof Element) { + propsOrChildren = [propsOrChildren]; + } if (Array.isArray(propsOrChildren)) { element.append(...propsOrChildren); } else { @@ -32,7 +57,7 @@ export function $el(tag, propsOrChildren, children) { Object.assign(element, propsOrChildren); if (children) { - element.append(...children); + element.append(...(children instanceof Array ? children : [children])); } if (parent) { @@ -168,267 +193,6 @@ function dragElement(dragEl, settings) { } } -export class ComfyDialog { - constructor() { - this.element = $el("div.comfy-modal", {parent: document.body}, [ - $el("div.comfy-modal-content", [$el("p", {$: (p) => (this.textElement = p)}), ...this.createButtons()]), - ]); - } - - createButtons() { - return [ - $el("button", { - type: "button", - textContent: "Close", - onclick: () => this.close(), - }), - ]; - } - - close() { - this.element.style.display = "none"; - } - - show(html) { - if (typeof html === "string") { - this.textElement.innerHTML = html; - } else { - this.textElement.replaceChildren(html); - } - this.element.style.display = "flex"; - } -} - -class ComfySettingsDialog extends ComfyDialog { - constructor() { - super(); - this.element = $el("dialog", { - id: "comfy-settings-dialog", - parent: document.body, - }, [ - $el("table.comfy-modal-content.comfy-table", [ - $el("caption", {textContent: "Settings"}), - $el("tbody", {$: (tbody) => (this.textElement = tbody)}), - $el("button", { - type: "button", - textContent: "Close", - style: { - cursor: "pointer", - }, - onclick: () => { - this.element.close(); - }, - }), - ]), - ]); - this.settings = []; - } - - getSettingValue(id, defaultValue) { - const settingId = "Comfy.Settings." + id; - const v = localStorage[settingId]; - return v == null ? defaultValue : JSON.parse(v); - } - - setSettingValue(id, value) { - const settingId = "Comfy.Settings." + id; - localStorage[settingId] = JSON.stringify(value); - } - - addSetting({id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined}) { - if (!id) { - throw new Error("Settings must have an ID"); - } - - if (this.settings.find((s) => s.id === id)) { - throw new Error(`Setting ${id} of type ${type} must have a unique ID.`); - } - - const settingId = `Comfy.Settings.${id}`; - const v = localStorage[settingId]; - let value = v == null ? defaultValue : JSON.parse(v); - - // Trigger initial setting of value - if (onChange) { - onChange(value, undefined); - } - - this.settings.push({ - render: () => { - const setter = (v) => { - if (onChange) { - onChange(v, value); - } - localStorage[settingId] = JSON.stringify(v); - value = v; - }; - value = this.getSettingValue(id, defaultValue); - - let element; - const htmlID = id.replaceAll(".", "-"); - - const labelCell = $el("td", [ - $el("label", { - for: htmlID, - classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""], - textContent: name, - }) - ]); - - if (typeof type === "function") { - element = type(name, setter, value, attrs); - } else { - switch (type) { - case "boolean": - element = $el("tr", [ - labelCell, - $el("td", [ - $el("input", { - id: htmlID, - type: "checkbox", - checked: value, - onchange: (event) => { - const isChecked = event.target.checked; - if (onChange !== undefined) { - onChange(isChecked) - } - this.setSettingValue(id, isChecked); - }, - }), - ]), - ]) - break; - case "number": - element = $el("tr", [ - labelCell, - $el("td", [ - $el("input", { - type, - value, - id: htmlID, - oninput: (e) => { - setter(e.target.value); - }, - ...attrs - }), - ]), - ]); - break; - case "slider": - element = $el("tr", [ - labelCell, - $el("td", [ - $el("div", { - style: { - display: "grid", - gridAutoFlow: "column", - }, - }, [ - $el("input", { - ...attrs, - value, - type: "range", - oninput: (e) => { - setter(e.target.value); - e.target.nextElementSibling.value = e.target.value; - }, - }), - $el("input", { - ...attrs, - value, - id: htmlID, - type: "number", - style: {maxWidth: "4rem"}, - oninput: (e) => { - setter(e.target.value); - e.target.previousElementSibling.value = e.target.value; - }, - }), - ]), - ]), - ]); - break; - case "combo": - element = $el("tr", [ - labelCell, - $el("td", [ - $el( - "select", - { - oninput: (e) => { - setter(e.target.value); - }, - }, - (typeof options === "function" ? options(value) : options || []).map((opt) => { - if (typeof opt === "string") { - opt = { text: opt }; - } - const v = opt.value ?? opt.text; - return $el("option", { - value: v, - textContent: opt.text, - selected: value + "" === v + "", - }); - }) - ), - ]), - ]); - break; - case "text": - default: - if (type !== "text") { - console.warn(`Unsupported setting type '${type}, defaulting to text`); - } - - element = $el("tr", [ - labelCell, - $el("td", [ - $el("input", { - value, - id: htmlID, - oninput: (e) => { - setter(e.target.value); - }, - ...attrs, - }), - ]), - ]); - break; - } - } - if (tooltip) { - element.title = tooltip; - } - - return element; - }, - }); - - const self = this; - return { - get value() { - return self.getSettingValue(id, defaultValue); - }, - set value(v) { - self.setSettingValue(id, v); - }, - }; - } - - show() { - this.textElement.replaceChildren( - $el("tr", { - style: {display: "none"}, - }, [ - $el("th"), - $el("th", {style: {width: "33%"}}) - ]), - ...this.settings.map((s) => s.render()), - ) - this.element.showModal(); - } -} - class ComfyList { #type; #text; @@ -539,7 +303,7 @@ export class ComfyUI { constructor(app) { this.app = app; this.dialog = new ComfyDialog(); - this.settings = new ComfySettingsDialog(); + this.settings = new ComfySettingsDialog(app); this.batchCount = 1; this.lastQueueSize = 0; @@ -620,18 +384,68 @@ export class ComfyUI { }, }); - this.menuContainer = $el("div.comfy-menu", {parent: document.body}, [ - $el("div.drag-handle", { + const autoQueueModeEl = toggleSwitch( + "autoQueueMode", + [ + { text: "instant", tooltip: "A new prompt will be queued as soon as the queue reaches 0" }, + { text: "change", tooltip: "A new prompt will be queued when the queue is at 0 and the graph is/has changed" }, + ], + { + onChange: (value) => { + this.autoQueueMode = value.item.value; + }, + } + ); + autoQueueModeEl.style.display = "none"; + + api.addEventListener("graphChanged", () => { + if (this.autoQueueMode === "change" && this.autoQueueEnabled === true) { + if (this.lastQueueSize === 0) { + this.graphHasChanged = false; + app.queuePrompt(0, this.batchCount); + } else { + this.graphHasChanged = true; + } + } + }); + + this.menuHamburger = $el( + "div.comfy-menu-hamburger", + { + parent: document.body, + onclick: () => { + this.menuContainer.style.display = "block"; + this.menuHamburger.style.display = "none"; + }, + }, + [$el("div"), $el("div"), $el("div")] + ); + + this.menuContainer = $el("div.comfy-menu", { parent: document.body }, [ + $el("div.drag-handle.comfy-menu-header", { style: { overflow: "hidden", position: "relative", width: "100%", cursor: "default" } - }, [ + }, [ $el("span.drag-handle"), - $el("span", {$: (q) => (this.queueSize = q)}), - // $el("button.comfy-settings-btn", {textContent: "⚙️", onclick: () => this.settings.show()}), + $el("span.comfy-menu-queue-size", { $: (q) => (this.queueSize = q) }), + $el("div.comfy-menu-actions", [ +// $el("button.comfy-settings-btn", { +// textContent: "⚙️", +// onclick: () => this.settings.show(), +// }), + $el("button.comfy-close-menu-btn", { + textContent: "\u00d7", + onclick: () => { + this.menuContainer.style.display = "none"; + this.menuHamburger.style.display = "flex"; + }, + }), + ]), + ]), $el("button.comfy-queue-btn", { id: "queue-button", @@ -677,20 +491,22 @@ export class ComfyUI { }, }), ]), - $el("div",[ $el("label",{ for:"autoQueueCheckbox", innerHTML: "Auto Queue" - // textContent: "Auto Queue" }), $el("input", { id: "autoQueueCheckbox", type: "checkbox", checked: false, title: "Automatically queue prompt when the queue size hits 0", - + onchange: (e) => { + this.autoQueueEnabled = e.target.checked; + autoQueueModeEl.style.display = this.autoQueueEnabled ? "" : "none"; + } }), + autoQueueModeEl ]) ]), $el("div.comfy-menu-btns", [ @@ -829,10 +645,13 @@ export class ComfyUI { if ( this.lastQueueSize != 0 && status.exec_info.queue_remaining == 0 && - document.getElementById("autoQueueCheckbox").checked && - ! app.lastExecutionError + this.autoQueueEnabled && + (this.autoQueueMode === "instant" || this.graphHasChanged) && + !app.lastExecutionError ) { app.queuePrompt(0, this.batchCount); + status.exec_info.queue_remaining += this.batchCount; + this.graphHasChanged = false; } this.lastQueueSize = status.exec_info.queue_remaining; } diff --git a/web/scripts/ui/dialog.js b/web/scripts/ui/dialog.js new file mode 100644 index 00000000000..aee93b3c84f --- /dev/null +++ b/web/scripts/ui/dialog.js @@ -0,0 +1,32 @@ +import { $el } from "../ui.js"; + +export class ComfyDialog { + constructor() { + this.element = $el("div.comfy-modal", { parent: document.body }, [ + $el("div.comfy-modal-content", [$el("p", { $: (p) => (this.textElement = p) }), ...this.createButtons()]), + ]); + } + + createButtons() { + return [ + $el("button", { + type: "button", + textContent: "Close", + onclick: () => this.close(), + }), + ]; + } + + close() { + this.element.style.display = "none"; + } + + show(html) { + if (typeof html === "string") { + this.textElement.innerHTML = html; + } else { + this.textElement.replaceChildren(html); + } + this.element.style.display = "flex"; + } +} diff --git a/web/scripts/ui/draggableList.js b/web/scripts/ui/draggableList.js new file mode 100644 index 00000000000..d535948869f --- /dev/null +++ b/web/scripts/ui/draggableList.js @@ -0,0 +1,287 @@ +// @ts-check +/* + Original implementation: + https://github.com/TahaSh/drag-to-reorder + MIT License + + Copyright (c) 2023 Taha Shashtari + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. +*/ + +import { $el } from "../ui.js"; + +$el("style", { + parent: document.head, + textContent: ` + .draggable-item { + position: relative; + will-change: transform; + user-select: none; + } + .draggable-item.is-idle { + transition: 0.25s ease transform; + } + .draggable-item.is-draggable { + z-index: 10; + } + ` +}); + +export class DraggableList extends EventTarget { + listContainer; + draggableItem; + pointerStartX; + pointerStartY; + scrollYMax; + itemsGap = 0; + items = []; + itemSelector; + handleClass = "drag-handle"; + off = []; + offDrag = []; + + constructor(element, itemSelector) { + super(); + this.listContainer = element; + this.itemSelector = itemSelector; + + if (!this.listContainer) return; + + this.off.push(this.on(this.listContainer, "mousedown", this.dragStart)); + this.off.push(this.on(this.listContainer, "touchstart", this.dragStart)); + this.off.push(this.on(document, "mouseup", this.dragEnd)); + this.off.push(this.on(document, "touchend", this.dragEnd)); + } + + getAllItems() { + if (!this.items?.length) { + this.items = Array.from(this.listContainer.querySelectorAll(this.itemSelector)); + this.items.forEach((element) => { + element.classList.add("is-idle"); + }); + } + return this.items; + } + + getIdleItems() { + return this.getAllItems().filter((item) => item.classList.contains("is-idle")); + } + + isItemAbove(item) { + return item.hasAttribute("data-is-above"); + } + + isItemToggled(item) { + return item.hasAttribute("data-is-toggled"); + } + + on(source, event, listener, options) { + listener = listener.bind(this); + source.addEventListener(event, listener, options); + return () => source.removeEventListener(event, listener); + } + + dragStart(e) { + if (e.target.classList.contains(this.handleClass)) { + this.draggableItem = e.target.closest(this.itemSelector); + } + + if (!this.draggableItem) return; + + this.pointerStartX = e.clientX || e.touches[0].clientX; + this.pointerStartY = e.clientY || e.touches[0].clientY; + this.scrollYMax = this.listContainer.scrollHeight - this.listContainer.clientHeight; + + this.setItemsGap(); + this.initDraggableItem(); + this.initItemsState(); + + this.offDrag.push(this.on(document, "mousemove", this.drag)); + this.offDrag.push(this.on(document, "touchmove", this.drag, { passive: false })); + + this.dispatchEvent( + new CustomEvent("dragstart", { + detail: { element: this.draggableItem, position: this.getAllItems().indexOf(this.draggableItem) }, + }) + ); + } + + setItemsGap() { + if (this.getIdleItems().length <= 1) { + this.itemsGap = 0; + return; + } + + const item1 = this.getIdleItems()[0]; + const item2 = this.getIdleItems()[1]; + + const item1Rect = item1.getBoundingClientRect(); + const item2Rect = item2.getBoundingClientRect(); + + this.itemsGap = Math.abs(item1Rect.bottom - item2Rect.top); + } + + initItemsState() { + this.getIdleItems().forEach((item, i) => { + if (this.getAllItems().indexOf(this.draggableItem) > i) { + item.dataset.isAbove = ""; + } + }); + } + + initDraggableItem() { + this.draggableItem.classList.remove("is-idle"); + this.draggableItem.classList.add("is-draggable"); + } + + drag(e) { + if (!this.draggableItem) return; + + e.preventDefault(); + + const clientX = e.clientX || e.touches[0].clientX; + const clientY = e.clientY || e.touches[0].clientY; + + const listRect = this.listContainer.getBoundingClientRect(); + + if (clientY > listRect.bottom) { + if (this.listContainer.scrollTop < this.scrollYMax) { + this.listContainer.scrollBy(0, 10); + this.pointerStartY -= 10; + } + } else if (clientY < listRect.top && this.listContainer.scrollTop > 0) { + this.pointerStartY += 10; + this.listContainer.scrollBy(0, -10); + } + + const pointerOffsetX = clientX - this.pointerStartX; + const pointerOffsetY = clientY - this.pointerStartY; + + this.updateIdleItemsStateAndPosition(); + this.draggableItem.style.transform = `translate(${pointerOffsetX}px, ${pointerOffsetY}px)`; + } + + updateIdleItemsStateAndPosition() { + const draggableItemRect = this.draggableItem.getBoundingClientRect(); + const draggableItemY = draggableItemRect.top + draggableItemRect.height / 2; + + // Update state + this.getIdleItems().forEach((item) => { + const itemRect = item.getBoundingClientRect(); + const itemY = itemRect.top + itemRect.height / 2; + if (this.isItemAbove(item)) { + if (draggableItemY <= itemY) { + item.dataset.isToggled = ""; + } else { + delete item.dataset.isToggled; + } + } else { + if (draggableItemY >= itemY) { + item.dataset.isToggled = ""; + } else { + delete item.dataset.isToggled; + } + } + }); + + // Update position + this.getIdleItems().forEach((item) => { + if (this.isItemToggled(item)) { + const direction = this.isItemAbove(item) ? 1 : -1; + item.style.transform = `translateY(${direction * (draggableItemRect.height + this.itemsGap)}px)`; + } else { + item.style.transform = ""; + } + }); + } + + dragEnd() { + if (!this.draggableItem) return; + + this.applyNewItemsOrder(); + this.cleanup(); + } + + applyNewItemsOrder() { + const reorderedItems = []; + + let oldPosition = -1; + this.getAllItems().forEach((item, index) => { + if (item === this.draggableItem) { + oldPosition = index; + return; + } + if (!this.isItemToggled(item)) { + reorderedItems[index] = item; + return; + } + const newIndex = this.isItemAbove(item) ? index + 1 : index - 1; + reorderedItems[newIndex] = item; + }); + + for (let index = 0; index < this.getAllItems().length; index++) { + const item = reorderedItems[index]; + if (typeof item === "undefined") { + reorderedItems[index] = this.draggableItem; + } + } + + reorderedItems.forEach((item) => { + this.listContainer.appendChild(item); + }); + + this.items = reorderedItems; + + this.dispatchEvent( + new CustomEvent("dragend", { + detail: { element: this.draggableItem, oldPosition, newPosition: reorderedItems.indexOf(this.draggableItem) }, + }) + ); + } + + cleanup() { + this.itemsGap = 0; + this.items = []; + this.unsetDraggableItem(); + this.unsetItemState(); + + this.offDrag.forEach((f) => f()); + this.offDrag = []; + } + + unsetDraggableItem() { + this.draggableItem.style = null; + this.draggableItem.classList.remove("is-draggable"); + this.draggableItem.classList.add("is-idle"); + this.draggableItem = null; + } + + unsetItemState() { + this.getIdleItems().forEach((item, i) => { + delete item.dataset.isAbove; + delete item.dataset.isToggled; + item.style.transform = ""; + }); + } + + dispose() { + this.off.forEach((f) => f()); + } +} diff --git a/web/scripts/ui/settings.js b/web/scripts/ui/settings.js new file mode 100644 index 00000000000..9e9d13af00b --- /dev/null +++ b/web/scripts/ui/settings.js @@ -0,0 +1,317 @@ +import { $el } from "../ui.js"; +import { api } from "../api.js"; +import { ComfyDialog } from "./dialog.js"; + +export class ComfySettingsDialog extends ComfyDialog { + constructor(app) { + super(); + this.app = app; + this.settingsValues = {}; + this.settingsLookup = {}; + this.element = $el( + "dialog", + { + id: "comfy-settings-dialog", + parent: document.body, + }, + [ + $el("table.comfy-modal-content.comfy-table", [ + $el( + "caption", + { textContent: "Settings" }, + $el("button.comfy-btn", { + type: "button", + textContent: "\u00d7", + onclick: () => { + this.element.close(); + }, + }) + ), + $el("tbody", { $: (tbody) => (this.textElement = tbody) }), + $el("button", { + type: "button", + textContent: "Close", + style: { + cursor: "pointer", + }, + onclick: () => { + this.element.close(); + }, + }), + ]), + ] + ); + } + + get settings() { + return Object.values(this.settingsLookup); + } + + async load() { + if (this.app.storageLocation === "browser") { + this.settingsValues = localStorage; + } else { + this.settingsValues = await api.getSettings(); + } + + // Trigger onChange for any settings added before load + for (const id in this.settingsLookup) { + this.settingsLookup[id].onChange?.(this.settingsValues[this.getId(id)]); + } + } + + getId(id) { + if (this.app.storageLocation === "browser") { + id = "Comfy.Settings." + id; + } + return id; + } + + getSettingValue(id, defaultValue) { + let value = this.settingsValues[this.getId(id)]; + if(value != null) { + if(this.app.storageLocation === "browser") { + try { + value = JSON.parse(value); + } catch (error) { + } + } + } + return value ?? defaultValue; + } + + async setSettingValueAsync(id, value) { + const json = JSON.stringify(value); + localStorage["Comfy.Settings." + id] = json; // backwards compatibility for extensions keep setting in storage + + let oldValue = this.getSettingValue(id, undefined); + this.settingsValues[this.getId(id)] = value; + + if (id in this.settingsLookup) { + this.settingsLookup[id].onChange?.(value, oldValue); + } + + await api.storeSetting(id, value); + } + + setSettingValue(id, value) { + this.setSettingValueAsync(id, value).catch((err) => { + alert(`Error saving setting '${id}'`); + console.error(err); + }); + } + + addSetting({ id, name, type, defaultValue, onChange, attrs = {}, tooltip = "", options = undefined }) { + if (!id) { + throw new Error("Settings must have an ID"); + } + + if (id in this.settingsLookup) { + throw new Error(`Setting ${id} of type ${type} must have a unique ID.`); + } + + let skipOnChange = false; + let value = this.getSettingValue(id); + if (value == null) { + if (this.app.isNewUserSession) { + // Check if we have a localStorage value but not a setting value and we are a new user + const localValue = localStorage["Comfy.Settings." + id]; + if (localValue) { + value = JSON.parse(localValue); + this.setSettingValue(id, value); // Store on the server + } + } + if (value == null) { + value = defaultValue; + } + } + + // Trigger initial setting of value + if (!skipOnChange) { + onChange?.(value, undefined); + } + + this.settingsLookup[id] = { + id, + onChange, + name, + render: () => { + const setter = (v) => { + if (onChange) { + onChange(v, value); + } + + this.setSettingValue(id, v); + value = v; + }; + value = this.getSettingValue(id, defaultValue); + + let element; + const htmlID = id.replaceAll(".", "-"); + + const labelCell = $el("td", [ + $el("label", { + for: htmlID, + classList: [tooltip !== "" ? "comfy-tooltip-indicator" : ""], + textContent: name, + }), + ]); + + if (typeof type === "function") { + element = type(name, setter, value, attrs); + } else { + switch (type) { + case "boolean": + element = $el("tr", [ + labelCell, + $el("td", [ + $el("input", { + id: htmlID, + type: "checkbox", + checked: value, + onchange: (event) => { + const isChecked = event.target.checked; + if (onChange !== undefined) { + onChange(isChecked); + } + this.setSettingValue(id, isChecked); + }, + }), + ]), + ]); + break; + case "number": + element = $el("tr", [ + labelCell, + $el("td", [ + $el("input", { + type, + value, + id: htmlID, + oninput: (e) => { + setter(e.target.value); + }, + ...attrs, + }), + ]), + ]); + break; + case "slider": + element = $el("tr", [ + labelCell, + $el("td", [ + $el( + "div", + { + style: { + display: "grid", + gridAutoFlow: "column", + }, + }, + [ + $el("input", { + ...attrs, + value, + type: "range", + oninput: (e) => { + setter(e.target.value); + e.target.nextElementSibling.value = e.target.value; + }, + }), + $el("input", { + ...attrs, + value, + id: htmlID, + type: "number", + style: { maxWidth: "4rem" }, + oninput: (e) => { + setter(e.target.value); + e.target.previousElementSibling.value = e.target.value; + }, + }), + ] + ), + ]), + ]); + break; + case "combo": + element = $el("tr", [ + labelCell, + $el("td", [ + $el( + "select", + { + oninput: (e) => { + setter(e.target.value); + }, + }, + (typeof options === "function" ? options(value) : options || []).map((opt) => { + if (typeof opt === "string") { + opt = { text: opt }; + } + const v = opt.value ?? opt.text; + return $el("option", { + value: v, + textContent: opt.text, + selected: value + "" === v + "", + }); + }) + ), + ]), + ]); + break; + case "text": + default: + if (type !== "text") { + console.warn(`Unsupported setting type '${type}, defaulting to text`); + } + + element = $el("tr", [ + labelCell, + $el("td", [ + $el("input", { + value, + id: htmlID, + oninput: (e) => { + setter(e.target.value); + }, + ...attrs, + }), + ]), + ]); + break; + } + } + if (tooltip) { + element.title = tooltip; + } + + return element; + }, + }; + + const self = this; + return { + get value() { + return self.getSettingValue(id, defaultValue); + }, + set value(v) { + self.setSettingValue(id, v); + }, + }; + } + + show() { + this.textElement.replaceChildren( + $el( + "tr", + { + style: { display: "none" }, + }, + [$el("th"), $el("th", { style: { width: "33%" } })] + ), + ...this.settings.sort((a, b) => a.name.localeCompare(b.name)).map((s) => s.render()) + ); + this.element.showModal(); + } +} diff --git a/web/scripts/ui/spinner.css b/web/scripts/ui/spinner.css new file mode 100644 index 00000000000..56da6072ee3 --- /dev/null +++ b/web/scripts/ui/spinner.css @@ -0,0 +1,34 @@ +.lds-ring { + display: inline-block; + position: relative; + width: 1em; + height: 1em; +} +.lds-ring div { + box-sizing: border-box; + display: block; + position: absolute; + width: 100%; + height: 100%; + border: 0.15em solid #fff; + border-radius: 50%; + animation: lds-ring 1.2s cubic-bezier(0.5, 0, 0.5, 1) infinite; + border-color: #fff transparent transparent transparent; +} +.lds-ring div:nth-child(1) { + animation-delay: -0.45s; +} +.lds-ring div:nth-child(2) { + animation-delay: -0.3s; +} +.lds-ring div:nth-child(3) { + animation-delay: -0.15s; +} +@keyframes lds-ring { + 0% { + transform: rotate(0deg); + } + 100% { + transform: rotate(360deg); + } +} diff --git a/web/scripts/ui/spinner.js b/web/scripts/ui/spinner.js new file mode 100644 index 00000000000..d049786f6a5 --- /dev/null +++ b/web/scripts/ui/spinner.js @@ -0,0 +1,9 @@ +import { addStylesheet } from "../utils.js"; + +addStylesheet(import.meta.url); + +export function createSpinner() { + const div = document.createElement("div"); + div.innerHTML = `
`; + return div.firstElementChild; +} diff --git a/web/scripts/ui/toggleSwitch.js b/web/scripts/ui/toggleSwitch.js new file mode 100644 index 00000000000..59597ef90e5 --- /dev/null +++ b/web/scripts/ui/toggleSwitch.js @@ -0,0 +1,60 @@ +import { $el } from "../ui.js"; + +/** + * @typedef { { text: string, value?: string, tooltip?: string } } ToggleSwitchItem + */ +/** + * Creates a toggle switch element + * @param { string } name + * @param { Array void } [opts.onChange] + */ +export function toggleSwitch(name, items, { onChange } = {}) { + let selectedIndex; + let elements; + + function updateSelected(index) { + if (selectedIndex != null) { + elements[selectedIndex].classList.remove("comfy-toggle-selected"); + } + onChange?.({ item: items[index], prev: selectedIndex == null ? undefined : items[selectedIndex] }); + selectedIndex = index; + elements[selectedIndex].classList.add("comfy-toggle-selected"); + } + + elements = items.map((item, i) => { + if (typeof item === "string") item = { text: item }; + if (!item.value) item.value = item.text; + + const toggle = $el( + "label", + { + textContent: item.text, + title: item.tooltip ?? "", + }, + $el("input", { + name, + type: "radio", + value: item.value ?? item.text, + checked: item.selected, + onchange: () => { + updateSelected(i); + }, + }) + ); + if (item.selected) { + updateSelected(i); + } + return toggle; + }); + + const container = $el("div.comfy-toggle-switch", elements); + + if (selectedIndex == null) { + elements[0].children[0].checked = true; + updateSelected(0); + } + + return container; +} diff --git a/web/scripts/ui/userSelection.css b/web/scripts/ui/userSelection.css new file mode 100644 index 00000000000..35c9d66148d --- /dev/null +++ b/web/scripts/ui/userSelection.css @@ -0,0 +1,135 @@ +.comfy-user-selection { + width: 100vw; + height: 100vh; + position: absolute; + top: 0; + left: 0; + z-index: 999; + display: flex; + align-items: center; + justify-content: center; + font-family: sans-serif; + background: linear-gradient(var(--tr-even-bg-color), var(--tr-odd-bg-color)); +} + +.comfy-user-selection-inner { + background: var(--comfy-menu-bg); + margin-top: -30vh; + padding: 20px 40px; + border-radius: 10px; + min-width: 365px; + position: relative; + box-shadow: 0 0 20px rgba(0, 0, 0, 0.3); +} + +.comfy-user-selection-inner form { + width: 100%; + display: flex; + flex-direction: column; + align-items: center; +} + +.comfy-user-selection-inner h1 { + margin: 10px 0 30px 0; + font-weight: normal; +} + +.comfy-user-selection-inner label { + display: flex; + flex-direction: column; + width: 100%; +} + +.comfy-user-selection input, +.comfy-user-selection select { + background-color: var(--comfy-input-bg); + color: var(--input-text); + border: 0; + border-radius: 5px; + padding: 5px; + margin-top: 10px; +} + +.comfy-user-selection input::placeholder { + color: var(--descrip-text); + opacity: 1; +} + +.comfy-user-existing { + width: 100%; +} + +.no-users .comfy-user-existing { + display: none; +} + +.comfy-user-selection-inner .or-separator { + margin: 10px 0; + padding: 10px; + display: block; + text-align: center; + width: 100%; + color: var(--descrip-text); +} + +.comfy-user-selection-inner .or-separator { + overflow: hidden; + text-align: center; + margin-left: -10px; +} + +.comfy-user-selection-inner .or-separator::before, +.comfy-user-selection-inner .or-separator::after { + content: ""; + background-color: var(--border-color); + position: relative; + height: 1px; + vertical-align: middle; + display: inline-block; + width: calc(50% - 20px); + top: -1px; +} + +.comfy-user-selection-inner .or-separator::before { + right: 10px; + margin-left: -50%; +} + +.comfy-user-selection-inner .or-separator::after { + left: 10px; + margin-right: -50%; +} + +.comfy-user-selection-inner section { + width: 100%; + padding: 10px; + margin: -10px; + transition: background-color 0.2s; +} + +.comfy-user-selection-inner section.selected { + background: var(--border-color); + border-radius: 5px; +} + +.comfy-user-selection-inner footer { + display: flex; + flex-direction: column; + align-items: center; + margin-top: 20px; +} + +.comfy-user-selection-inner .comfy-user-error { + color: var(--error-text); + margin-bottom: 10px; +} + +.comfy-user-button-next { + font-size: 16px; + padding: 6px 10px; + width: 100px; + display: flex; + gap: 5px; + align-items: center; + justify-content: center; +} \ No newline at end of file diff --git a/web/scripts/ui/userSelection.js b/web/scripts/ui/userSelection.js new file mode 100644 index 00000000000..f9f1ca8071a --- /dev/null +++ b/web/scripts/ui/userSelection.js @@ -0,0 +1,114 @@ +import { api } from "../api.js"; +import { $el } from "../ui.js"; +import { addStylesheet } from "../utils.js"; +import { createSpinner } from "./spinner.js"; + +export class UserSelectionScreen { + async show(users, user) { + // This will rarely be hit so move the loading to on demand + await addStylesheet(import.meta.url); + const userSelection = document.getElementById("comfy-user-selection"); + userSelection.style.display = ""; + return new Promise((resolve) => { + const input = userSelection.getElementsByTagName("input")[0]; + const select = userSelection.getElementsByTagName("select")[0]; + const inputSection = input.closest("section"); + const selectSection = select.closest("section"); + const form = userSelection.getElementsByTagName("form")[0]; + const error = userSelection.getElementsByClassName("comfy-user-error")[0]; + const button = userSelection.getElementsByClassName("comfy-user-button-next")[0]; + + let inputActive = null; + input.addEventListener("focus", () => { + inputSection.classList.add("selected"); + selectSection.classList.remove("selected"); + inputActive = true; + }); + select.addEventListener("focus", () => { + inputSection.classList.remove("selected"); + selectSection.classList.add("selected"); + inputActive = false; + select.style.color = ""; + }); + select.addEventListener("blur", () => { + if (!select.value) { + select.style.color = "var(--descrip-text)"; + } + }); + + form.addEventListener("submit", async (e) => { + e.preventDefault(); + if (inputActive == null) { + error.textContent = "Please enter a username or select an existing user."; + } else if (inputActive) { + const username = input.value.trim(); + if (!username) { + error.textContent = "Please enter a username."; + return; + } + + // Create new user + input.disabled = select.disabled = input.readonly = select.readonly = true; + const spinner = createSpinner(); + button.prepend(spinner); + try { + const resp = await api.createUser(username); + if (resp.status >= 300) { + let message = "Error creating user: " + resp.status + " " + resp.statusText; + try { + const res = await resp.json(); + if(res.error) { + message = res.error; + } + } catch (error) { + } + throw new Error(message); + } + + resolve({ username, userId: await resp.json(), created: true }); + } catch (err) { + spinner.remove(); + error.textContent = err.message ?? err.statusText ?? err ?? "An unknown error occurred."; + input.disabled = select.disabled = input.readonly = select.readonly = false; + return; + } + } else if (!select.value) { + error.textContent = "Please select an existing user."; + return; + } else { + resolve({ username: users[select.value], userId: select.value, created: false }); + } + }); + + if (user) { + const name = localStorage["Comfy.userName"]; + if (name) { + input.value = name; + } + } + if (input.value) { + // Focus the input, do this separately as sometimes browsers like to fill in the value + input.focus(); + } + + const userIds = Object.keys(users ?? {}); + if (userIds.length) { + for (const u of userIds) { + $el("option", { textContent: users[u], value: u, parent: select }); + } + select.style.color = "var(--descrip-text)"; + + if (select.value) { + // Focus the select, do this separately as sometimes browsers like to fill in the value + select.focus(); + } + } else { + userSelection.classList.add("no-users"); + input.focus(); + } + }).then((r) => { + userSelection.remove(); + return r; + }); + } +} diff --git a/web/scripts/utils.js b/web/scripts/utils.js index b18817c3cd5..d94f7538a16 100644 --- a/web/scripts/utils.js +++ b/web/scripts/utils.js @@ -1,3 +1,30 @@ +import { $el } from "./ui.js"; + +export function needLoadPrebuiltWorkflow(workflowId) { + var loaded = localStorage.getItem('PrebuiltWorkflowId' + workflowId); + if (loaded) { + return false + } else { + localStorage.setItem('PrebuiltWorkflowId' + workflowId, true); + return true + } +} + +export async function getWorkflow() { + let flow_json = null; + const queryString = window.location.search; + const urlParams = new URLSearchParams(queryString); + const workflowId = urlParams.get('workflow'); + if (workflowId && needLoadPrebuiltWorkflow(workflowId)) { + await fetch('../workflows/' + workflowId + '/' + workflowId + '.json').then( + response => { + flow_json = response.json() + } + ) + } + return flow_json; +} + // Simple date formatter const parts = { d: (d) => d.getDate(), @@ -66,6 +93,24 @@ export function applyTextReplacements(app, value) { }); } +export async function addStylesheet(urlOrFile, relativeTo) { + return new Promise((res, rej) => { + let url; + if (urlOrFile.endsWith(".js")) { + url = urlOrFile.substr(0, urlOrFile.length - 2) + "css"; + } else { + url = new URL(urlOrFile, relativeTo ?? `${window.location.protocol}//${window.location.host}`).toString(); + } + $el("link", { + parent: document.head, + rel: "stylesheet", + type: "text/css", + href: url, + onload: res, + onerror: rej, + }); + }); +} function setCookie(name, value, days) { @@ -114,4 +159,4 @@ export function getUserId() { setCookie('uid', uid, 999); } return uid ? uid : "anonymous"; -} \ No newline at end of file +} diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index e2e21164db8..678b1b8ec7a 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -1,6 +1,19 @@ import { api } from "./api.js" import "./domWidget.js"; +let controlValueRunBefore = false; +export function updateControlWidgetLabel(widget) { + let replacement = "after"; + let find = "before"; + if (controlValueRunBefore) { + [find, replacement] = [replacement, find] + } + widget.label = (widget.label ?? widget.name).replace(find, replacement); +} + +const IS_CONTROL_WIDGET = Symbol(); +const HAS_EXECUTED = Symbol(); + function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { let defaultVal = inputData[1]["default"]; let { min, max, step, round} = inputData[1]; @@ -62,10 +75,15 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando serialize: false, // Don't include this in prompt. } ); + valueControl[IS_CONTROL_WIDGET] = true; + updateControlWidgetLabel(valueControl); widgets.push(valueControl); const isCombo = targetWidget.type === "combo"; let comboFilter; + if (isCombo) { + valueControl.options.values.push("increment-wrap"); + } if (isCombo && options.addFilterList !== false) { comboFilter = node.addWidget( "string", @@ -76,10 +94,12 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando serialize: false, // Don't include this in prompt. } ); + updateControlWidgetLabel(comboFilter); + widgets.push(comboFilter); } - valueControl.afterQueued = () => { + const applyWidgetControl = () => { var v = valueControl.value; if (isCombo && v !== "fixed") { @@ -111,6 +131,12 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando case "increment": current_index += 1; break; + case "increment-wrap": + current_index += 1; + if ( current_index >= current_length ) { + current_index = 0; + } + break; case "decrement": current_index -= 1; break; @@ -159,6 +185,23 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando targetWidget.callback(targetWidget.value); } }; + + valueControl.beforeQueued = () => { + if (controlValueRunBefore) { + // Don't run on first execution + if (valueControl[HAS_EXECUTED]) { + applyWidgetControl(); + } + } + valueControl[HAS_EXECUTED] = true; + }; + + valueControl.afterQueued = () => { + if (!controlValueRunBefore) { + applyWidgetControl(); + } + }; + return widgets; }; @@ -224,6 +267,34 @@ function isSlider(display, app) { return (display==="slider") ? "slider" : "number" } +export function initWidgets(app) { + app.ui.settings.addSetting({ + id: "Comfy.WidgetControlMode", + name: "Widget Value Control Mode", + type: "combo", + defaultValue: "after", + options: ["before", "after"], + tooltip: "Controls when widget values are updated (randomize/increment/decrement), either before the prompt is queued or after.", + onChange(value) { + controlValueRunBefore = value === "before"; + for (const n of app.graph._nodes) { + if (!n.widgets) continue; + for (const w of n.widgets) { + if (w[IS_CONTROL_WIDGET]) { + updateControlWidgetLabel(w); + if (w.linkedWidgets) { + for (const l of w.linkedWidgets) { + updateControlWidgetLabel(l); + } + } + } + } + } + app.graph.setDirtyCanvas(true); + }, + }); +} + export const ComfyWidgets = { "INT:seed": seedWidget, "INT:noise_seed": seedWidget, @@ -233,7 +304,7 @@ export const ComfyWidgets = { let disable_rounding = app.ui.settings.getSettingValue("Comfy.DisableFloatRounding") if (precision == 0) precision = undefined; const { val, config } = getNumberDefaults(inputData, 0.5, precision, !disable_rounding); - return { widget: node.addWidget(widgetType, inputName, val, + return { widget: node.addWidget(widgetType, inputName, val, function (v) { if (config.round) { this.value = Math.round(v/config.round)*config.round; diff --git a/web/style.css b/web/style.css index 4d6df03282f..7cb37585a5d 100644 --- a/web/style.css +++ b/web/style.css @@ -82,6 +82,24 @@ body { margin: 3px 3px 3px 4px; } +.comfy-menu-hamburger { + position: fixed; + top: 10px; + z-index: 9999; + right: 10px; + width: 30px; + display: none; + gap: 8px; + flex-direction: column; + cursor: pointer; +} +.comfy-menu-hamburger div { + height: 3px; + width: 100%; + border-radius: 20px; + background-color: white; +} + .comfy-menu { font-size: 15px; position: absolute; @@ -101,6 +119,44 @@ body { box-shadow: 3px 3px 8px rgba(0, 0, 0, 0.4); } +.comfy-menu-header { + display: flex; +} + +.comfy-menu-actions { + display: flex; + gap: 3px; + align-items: center; + height: 20px; + position: relative; + top: -1px; + font-size: 22px; +} + +.comfy-menu .comfy-menu-actions button { + background-color: rgba(0, 0, 0, 0); + padding: 0; + border: none; + cursor: pointer; + font-size: inherit; +} + +.comfy-menu .comfy-menu-actions .comfy-settings-btn { + font-size: 0.6em; +} + +button.comfy-close-menu-btn { + font-size: 1em; + line-height: 12px; + color: #ccc; + position: relative; + top: -1px; +} + +.comfy-menu-queue-size { + flex: auto; +} + .comfy-menu button, .comfy-modal button { font-size: 20px; @@ -121,6 +177,7 @@ body { width: 100%; } +.comfy-btn, .comfy-menu > button, .comfy-menu-btns button, .comfy-menu .comfy-list button, @@ -133,16 +190,18 @@ body { margin-top: 2px; } +.comfy-btn:hover:not(:disabled), .comfy-menu > button:hover, .comfy-menu-btns button:hover, .comfy-menu .comfy-list button:hover, .comfy-modal button:hover, -.comfy-settings-btn:hover { +.comfy-menu-actions button:hover { filter: brightness(1.2); + will-change: transform; cursor: pointer; } -.comfy-menu span.drag-handle { +span.drag-handle { width: 10px; height: 20px; display: inline-block; @@ -158,12 +217,9 @@ body { letter-spacing: 2px; color: var(--drag-text); text-shadow: 1px 0 1px black; - position: absolute; - top: 0; - left: 0; } -.comfy-menu span.drag-handle::after { +span.drag-handle::after { content: '.. .. ..'; } @@ -209,15 +265,6 @@ body { font-size: 12px; } -button.comfy-settings-btn { - background-color: rgba(0, 0, 0, 0); - font-size: 12px; - padding: 0; - position: absolute; - right: 0; - border: none; -} - button.comfy-queue-btn { margin: 6px 0 !important; } @@ -263,7 +310,19 @@ button.comfy-queue-btn { } .comfy-menu span.drag-handle { - visibility: hidden + display: none; + } + + .comfy-menu-queue-size { + flex: unset; + } + + .comfy-menu-header { + justify-content: space-between; + } + .comfy-menu-actions { + gap: 10px; + font-size: 28px; } } @@ -314,7 +373,7 @@ dialog::backdrop { text-align: right; } -#comfy-settings-dialog button { +#comfy-settings-dialog tbody button, #comfy-settings-dialog table > button { background-color: var(--bg-color); border: 1px var(--border-color) solid; border-radius: 0; @@ -337,12 +396,33 @@ dialog::backdrop { } .comfy-table caption { + position: sticky; + top: 0; background-color: var(--bg-color); color: var(--input-text); font-size: 1rem; font-weight: bold; padding: 8px; text-align: center; + border-bottom: 1px solid var(--border-color); +} + +.comfy-table caption .comfy-btn { + position: absolute; + top: -2px; + right: 0; + bottom: 0; + cursor: pointer; + border: none; + height: 100%; + border-radius: 0; + aspect-ratio: 1/1; + user-select: none; + font-size: 20px; +} + +.comfy-table caption .comfy-btn:focus { + outline: none; } .comfy-table tr:nth-child(even) { @@ -383,11 +463,13 @@ dialog::backdrop { z-index: 9999 !important; background-color: var(--comfy-menu-bg) !important; filter: brightness(95%); + will-change: transform; } .litegraph.litecontextmenu .litemenu-entry:hover:not(.disabled):not(.separator) { background-color: var(--comfy-menu-bg) !important; filter: brightness(155%); + will-change: transform; color: var(--input-text); } @@ -448,10 +530,30 @@ dialog::backdrop { color: var(--input-text); background-color: var(--comfy-input-bg); filter: brightness(80%); + will-change: transform; padding-left: 0.2em; } .litegraph.lite-search-item.generic_type { color: var(--input-text); filter: brightness(50%); + will-change: transform; +} + +@media only screen and (max-width: 450px) { + #comfy-settings-dialog .comfy-table tbody { + display: grid; + } + #comfy-settings-dialog .comfy-table tr { + display: grid; + } + #comfy-settings-dialog tr > td:first-child { + text-align: center; + border-bottom: none; + padding-bottom: 0; + } + #comfy-settings-dialog tr > td:not(:first-child) { + text-align: center; + border-top: none; + } }