Modified ControlNet/T2IAdapter get_control function to receive transformer_options as an additional parameter; made the model_options stored in extra_args in inner_sample a clone of the original model_options instead of the same reference
Kosinkadink committed Sep 25, 2024
1 parent d3229cb commit fd2d572
Showing 2 changed files with 8 additions and 8 deletions.
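
The second half of the commit message targets a classic aliasing pitfall: handing extra_args the same dict object as self.model_options means any in-place edit made during sampling leaks back upstream, and even a shallow copy still shares its nested dicts. A minimal illustration of the pitfall (plain Python, not code from this repo):

    options = {"transformer_options": {"patches": []}}

    shallow = dict(options)   # top-level copy, but the nested dict is still shared
    shallow["transformer_options"]["patches"].append("p")

    print(options["transformer_options"]["patches"])  # ['p'] -- the mutation leaked back

Hence the dedicated clone helper in the samplers.py hunk below rather than a plain copy.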
8 changes: 4 additions & 4 deletions comfy/controlnet.py
@@ -201,10 +201,10 @@ def __init__(self, control_model=None, global_average_pooling=False, compression
         self.strength_type = strength_type
         self.concat_mask = concat_mask

-    def get_control(self, x_noisy, t, cond, batched_number):
+    def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
         control_prev = None
         if self.previous_controlnet is not None:
-            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
+            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)

         if self.timestep_range is not None:
             if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
@@ -674,10 +674,10 @@ def scale_image_to(self, width, height):
         height = math.ceil(height / unshuffle_amount) * unshuffle_amount
         return width, height

-    def get_control(self, x_noisy, t, cond, batched_number):
+    def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
         control_prev = None
         if self.previous_controlnet is not None:
-            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
+            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number, transformer_options)

         if self.timestep_range is not None:
             if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
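Because get_control recurses down the previous_controlnet chain, this is a breaking signature change for any code outside the repo that overrides it: every implementation must now accept and forward the extra argument. A hypothetical custom control sketched against the new signature (the class name and body are illustrative, not from this commit):

    from comfy.controlnet import ControlBase

    class MyCustomControl(ControlBase):
        def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
            control_prev = None
            if self.previous_controlnet is not None:
                # Forward transformer_options so the entire chain sees the same dict.
                control_prev = self.previous_controlnet.get_control(
                    x_noisy, t, cond, batched_number, transformer_options)
            # ... run this control's model here and merge its outputs with control_prev ...
            return control_prev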
8 changes: 4 additions & 4 deletions comfy/samplers.py
@@ -271,9 +271,6 @@ def outer_calc_cond_batch(model: 'BaseModel', conds: List[List[Dict]], x_in: tor
         c = cond_cat(c)
         timestep_ = torch.cat([timestep] * batch_chunks)

-        if control is not None:
-            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
-
         transformer_options = {}
         if 'transformer_options' in model_options:
             transformer_options = model_options['transformer_options'].copy()
@@ -295,6 +292,9 @@ def outer_calc_cond_batch(model: 'BaseModel', conds: List[List[Dict]], x_in: tor
         c['transformer_options'] = transformer_options

+        if control is not None:
+            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
+
         if 'model_function_wrapper' in model_options:
             output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
         else:
@@ -769,7 +769,7 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mas
         self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)

-        extra_args = {"model_options": self.model_options, "seed":seed}
+        extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}

         samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
         return self.inner_model.process_latent_out(samples.to(torch.float32))
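create_model_options_clone (from comfy.model_patcher) replaces the shared reference with a copy, so whatever a sampler or wrapper mutates on extra_args["model_options"] during one run stays local to that run. An illustrative use, assuming the helper also copies the nested transformer_options dict (the key below is made up):

    extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options),
                  "seed": seed}

    # Mutating the clone no longer leaks upstream:
    extra_args["model_options"]["transformer_options"]["example_flag"] = True  # hypothetical key
    assert "example_flag" not in self.model_options.get("transformer_options", {})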
