Skip to content

Commit

Permalink
Merge branch 'master' into execution_model_inversion
Browse files Browse the repository at this point in the history
  • Loading branch information
guill committed Aug 10, 2024
2 parents fd7229e + 6678d5c commit e73c917
Show file tree
Hide file tree
Showing 16 changed files with 899 additions and 135 deletions.
37 changes: 37 additions & 0 deletions .github/workflows/pullrequest-ci-run.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# This is the GitHub Workflow that drives full-GPU-enabled tests of pull requests to ComfyUI, when the 'Run-CI-Test' label is added
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Pull Request CI Workflow Runs
on:
pull_request_target:
types: [labeled]

jobs:
pr-test-stable:
if: ${{ github.event.label.name == 'Run-CI-Test' }}
strategy:
fail-fast: false
matrix:
os: [macos, linux, windows]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
include:
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
- os: windows
runner_label: [self-hosted, win]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
95 changes: 95 additions & 0 deletions .github/workflows/test-ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# This is the GitHub Workflow that drives automatic full-GPU-enabled tests of all new commits to the master branch of ComfyUI
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Full Comfy CI Workflow Runs
on:
push:
branches:
- master
paths-ignore:
- 'app/**'
- 'input/**'
- 'output/**'
- 'notebooks/**'
- 'script_examples/**'
- '.github/**'
- 'web/**'
workflow_dispatch:

jobs:
test-stable:
strategy:
fail-fast: false
matrix:
os: [macos, linux, windows]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["stable"]
include:
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
- os: windows
runner_label: [self-hosted, win]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}

test-win-nightly:
strategy:
fail-fast: true
matrix:
os: [windows]
python_version: ["3.9", "3.10", "3.11", "3.12"]
cuda_version: ["12.1"]
torch_version: ["nightly"]
include:
- os: windows
runner_label: [self-hosted, win]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}

test-unix-nightly:
strategy:
fail-fast: false
matrix:
os: [macos, linux]
python_version: ["3.11"]
cuda_version: ["12.1"]
torch_version: ["nightly"]
include:
- os: macos
runner_label: [self-hosted, macOS]
flags: "--use-pytorch-cross-attention"
- os: linux
runner_label: [self-hosted, Linux]
flags: ""
runs-on: ${{ matrix.runner_label }}
steps:
- name: Test Workflows
uses: comfy-org/comfy-action@main
with:
os: ${{ matrix.os }}
python_version: ${{ matrix.python_version }}
torch_version: ${{ matrix.torch_version }}
google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
comfyui_flags: ${{ matrix.flags }}
88 changes: 73 additions & 15 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,24 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""


import torch
from enum import Enum
import math
import os
import logging
Expand All @@ -13,7 +33,7 @@
import comfy.t2i_adapter.adapter
import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit

import comfy.ldm.hydit.controlnet

def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
Expand All @@ -33,6 +53,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
else:
return torch.cat([tensor] * batched_number, dim=0)

class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2

class ControlBase:
def __init__(self, device=None):
self.cond_hint_original = None
Expand All @@ -51,6 +75,8 @@ def __init__(self, device=None):
device = comfy.model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT

def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
self.cond_hint_original = cond_hint
Expand Down Expand Up @@ -93,6 +119,8 @@ def copy_to(self, c):
c.latent_format = self.latent_format
c.extra_args = self.extra_args.copy()
c.vae = self.vae
c.extra_conds = self.extra_conds.copy()
c.strength_type = self.strength_type

def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
Expand All @@ -113,7 +141,10 @@ def control_merge(self, control, control_prev, output_dtype):

if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
applied_to.add(x)
x *= self.strength
if self.strength_type == StrengthType.CONSTANT:
x *= self.strength
elif self.strength_type == StrengthType.LINEAR_UP:
x *= (self.strength ** float(len(control_output) - i))

if x.dtype != output_dtype:
x = x.to(output_dtype)
Expand Down Expand Up @@ -142,7 +173,7 @@ def set_extra_arg(self, argument, value=None):


class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
super().__init__(device)
self.control_model = control_model
self.load_device = load_device
Expand All @@ -154,6 +185,8 @@ def __init__(self, control_model=None, global_average_pooling=False, compression
self.model_sampling_current = None
self.manual_cast_dtype = manual_cast_dtype
self.latent_format = latent_format
self.extra_conds += extra_conds
self.strength_type = strength_type

def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
Expand Down Expand Up @@ -191,13 +224,16 @@ def get_control(self, x_noisy, t, cond, batched_number):
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

context = cond.get('crossattn_controlnet', cond['c_crossattn'])
y = cond.get('y', None)
if y is not None:
y = y.to(dtype)
extra = self.extra_args.copy()
for c in self.extra_conds:
temp = cond.get(c, None)
if temp is not None:
extra[c] = temp.to(dtype)

timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
return self.control_merge(control, control_prev, output_dtype)

def copy(self):
Expand Down Expand Up @@ -338,12 +374,8 @@ def get_models(self):
def inference_memory_requirements(self, dtype):
return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

def load_controlnet_mmdit(sd):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
def controlnet_config(sd):
model_config = comfy.model_detection.model_config_from_unet(sd, "", True)

supported_inference_dtypes = model_config.supported_inference_dtypes

Expand All @@ -356,23 +388,49 @@ def load_controlnet_mmdit(sd):
else:
operations = comfy.ops.disable_weight_init

control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
return model_config, operations, load_device, unet_dtype, manual_cast_dtype

def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)

if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))

if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model

def load_controlnet_mmdit(sd):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]

control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)

latent_format = comfy.latent_formats.SD3()
latent_format.shift_factor = 0 #SD3 controlnet weirdness
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control

def load_controlnet_hunyuandit(controlnet_data):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)

control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data)

latent_format = comfy.latent_formats.SDXL()
extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
return control

def load_controlnet(ckpt_path, model=None):
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data)

if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)

Expand Down
Loading

0 comments on commit e73c917

Please sign in to comment.