Skip to content

Commit

Permalink
Make deep shrink behave like it should.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Nov 16, 2023
1 parent 9f00a18 commit 7e3fe3a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
4 changes: 4 additions & 0 deletions comfy/ldm/modules/diffusionmodules/openaimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,10 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo
h = p(h, transformer_options)

hs.append(h)
if "input_block_patch_after_skip" in transformer_patches:
patch = transformer_patches["input_block_patch_after_skip"]
for p in patch:
h = p(h, transformer_options)

transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
Expand Down
3 changes: 3 additions & 0 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def set_model_attn2_output_patch(self, patch):
def set_model_input_block_patch(self, patch):
self.set_model_patch(patch, "input_block_patch")

def set_model_input_block_patch_after_skip(self, patch):
self.set_model_patch(patch, "input_block_patch_after_skip")

def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch")

Expand Down
8 changes: 6 additions & 2 deletions comfy_extras/nodes_model_downscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ def INPUT_TYPES(s):
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
"downscale_after_skip": ("BOOLEAN", {"default": True}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "_for_testing"

def patch(self, model, block_number, downscale_factor, start_percent, end_percent):
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip):
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item()
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item()

Expand All @@ -31,7 +32,10 @@ def output_block_patch(h, hsp, transformer_options):
return h, hsp

m = model.clone()
m.set_model_input_block_patch(input_block_patch)
if downscale_after_skip:
m.set_model_input_block_patch_after_skip(input_block_patch)
else:
m.set_model_input_block_patch(input_block_patch)
m.set_model_output_block_patch(output_block_patch)
return (m, )

Expand Down

0 comments on commit 7e3fe3a

Please sign in to comment.