Skip to content

Commit

Permalink
Code refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Nov 8, 2023
1 parent 2a23ba0 commit a527d0c
Showing 1 changed file with 10 additions and 12 deletions.
22 changes: 10 additions & 12 deletions comfy/ldm/modules/diffusionmodules/openaimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,12 @@ def __init__(self, dim):
def forward(self, t):
return timestep_embedding(t, self.dim)

def apply_control(h, control, name):
if control is not None and name in control and len(control[name]) > 0:
ctrl = control[name].pop()
if ctrl is not None:
h += ctrl
return h

class UNetModel(nn.Module):
"""
Expand Down Expand Up @@ -617,25 +623,17 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
h = forward_timestep_embed(module, h, emb, context, transformer_options)
if control is not None and 'input' in control and len(control['input']) > 0:
ctrl = control['input'].pop()
if ctrl is not None:
h += ctrl
h = apply_control(h, control, 'input')
hs.append(h)

transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
if control is not None and 'middle' in control and len(control['middle']) > 0:
ctrl = control['middle'].pop()
if ctrl is not None:
h += ctrl
h = apply_control(h, control, 'middle')

for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
hsp = hs.pop()
if control is not None and 'output' in control and len(control['output']) > 0:
ctrl = control['output'].pop()
if ctrl is not None:
hsp += ctrl
h = apply_control(h, control, 'output')

if "output_block_patch" in transformer_patches:
patch = transformer_patches["output_block_patch"]
Expand Down

0 comments on commit a527d0c

Please sign in to comment.