
Commit

Refactored WrapperExecutor code to remove need for WrapperClassExecutor (now gone), added sampler.sample wrapper (pending review, will likely keep but will see what hacks this could currently let me get rid of in ACN/ADE)

Kosinkadink committed Sep 27, 2024
1 parent 09cbd69 commit 0f7d379
Showing 2 changed files with 29 additions and 37 deletions.
52 changes: 19 additions & 33 deletions comfy/model_patcher.py
@@ -140,69 +140,55 @@ def init_callbacks(cls):
 class WrappersMP:
     OUTER_SAMPLE = "outer_sample"
     CALC_COND_BATCH = "calc_cond_batch"
+    SAMPLER_SAMPLE = "sampler_sample"
 
     @classmethod
     def init_wrappers(cls):
         return {
             cls.OUTER_SAMPLE: {None: []},
+            cls.SAMPLER_SAMPLE: {None: []},
             cls.CALC_COND_BATCH: {None: []},
         }
 
 class WrapperExecutor:
-    def __init__(self, original: Callable, wrappers: list[Callable], idx: int):
+    """Handles call stack of wrappers around a function in an ordered manner."""
+    def __init__(self, original: Callable, class_obj: object, wrappers: list[Callable], idx: int):
         self.original = original
+        self.class_obj = class_obj
         self.wrappers = wrappers.copy()
         self.idx = idx
         self.is_last = idx == len(wrappers)
 
     def __call__(self, *args, **kwargs):
+        """Calls the next wrapper in line or original function, whichever is appropriate."""
         new_executor = self._create_next_executor()
-        return new_executor._execute(*args, **kwargs)
+        return new_executor.execute(*args, **kwargs)
 
-    def _execute(self, *args, **kwargs):
+    def execute(self, *args, **kwargs):
+        """Used to initiate executor internally - DO NOT use this if you received executor in wrapper."""
         args = list(args)
         kwargs = dict(kwargs)
         if self.is_last:
+            if self.class_obj is None:
+                return self.original(*args, **kwargs)
             return self.original(*args, **kwargs)
         return self.wrappers[self.idx](self, *args, **kwargs)
 
-    def _create_next_executor(self):
+    def _create_next_executor(self) -> 'WrapperExecutor':
         new_idx = self.idx + 1
         if new_idx > len(self.wrappers):
             raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.")
-        return WrapperExecutor(self.original, self.wrappers, new_idx)
+        if self.class_obj is None:
+            return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
+        return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)
 
     @classmethod
-    def new_executor(cls, original: Callable, wrappers: list[Callable]):
-        return cls(original, wrappers, idx=0)
-
-class WrapperClassExecutor:
-    def __init__(self, original: Callable, wrappers: list[Callable], idx: int):
-        self.original = original
-        self.wrappers = wrappers.copy()
-        self.idx = idx
-        self.is_last = idx == len(wrappers)
-
-    def __call__(self, class_inst, *args, **kwargs):
-        new_executor = self._create_next_executor()
-        return new_executor._execute(class_inst, *args, **kwargs)
+    def new_executor(cls, original: Callable, wrappers: list[Callable], idx=0):
+        return cls(original, class_obj=None, wrappers=wrappers, idx=idx)
 
-    def _execute(self, class_inst, *args, **kwargs):
-        args = list(args)
-        kwargs = dict(kwargs)
-        if self.is_last:
-            return self.original(*args, **kwargs)
-        return self.wrappers[self.idx](self, class_inst, *args, **kwargs)
-
-    def _create_next_executor(self):
-        new_idx = self.idx + 1
-        if new_idx > len(self.wrappers):
-            raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.")
-        return WrapperClassExecutor(self.original, self.wrappers, new_idx)
-
     @classmethod
-    def new_executor(cls, original: Callable, wrappers: list[Callable]):
-        return cls(original, wrappers, idx=0)
+    def new_class_executor(cls, original: Callable, class_obj: object, wrappers: list[Callable], idx=0):
+        return cls(original, class_obj, wrappers, idx=idx)
 
 class AutoPatcherEjector:
     def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
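A minimal sketch (not part of the commit) of how the unified WrapperExecutor chains wrappers around a bound method. It assumes a ComfyUI checkout where this version of comfy.model_patcher is importable; ToySampler and logging_wrapper are made-up names for illustration only.

import comfy.model_patcher

class ToySampler:
    def sample(self, sigmas):
        print("original sample() called with", sigmas)
        return sigmas

def logging_wrapper(executor, *args, **kwargs):
    # Each wrapper receives the executor as its first argument; calling it runs
    # the next wrapper in line, or the original callable once the list is exhausted.
    print("before sample(); wrapped instance:", executor.class_obj)
    result = executor(*args, **kwargs)
    print("after sample()")
    return result

sampler = ToySampler()
executor = comfy.model_patcher.WrapperExecutor.new_class_executor(
    sampler.sample,     # original callable (already bound to the instance)
    sampler,            # exposed to wrappers via executor.class_obj
    [logging_wrapper],
)
samples = executor.execute([1.0, 0.5, 0.0])

The class_obj argument is what replaces the old WrapperClassExecutor: wrappers around a method can still reach the owning instance through executor.class_obj, while plain functions use new_executor with class_obj left as None.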
14 changes: 10 additions & 4 deletions comfy/samplers.py
@@ -187,7 +187,7 @@ def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
         outer_calc_cond_batch,
         model.current_patcher.get_all_wrappers(comfy.model_patcher.WrappersMP.CALC_COND_BATCH)
     )
-    return executor._execute(model, conds, x_in, timestep, model_options)
+    return executor.execute(model, conds, x_in, timestep, model_options)
 
 def outer_calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
     out_conds = []
@@ -771,7 +771,12 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
 
         extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}
 
-        samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
+        executor = comfy.model_patcher.WrapperExecutor.new_class_executor(
+            sampler.sample,
+            sampler,
+            self.model_patcher.get_all_wrappers(comfy.model_patcher.WrappersMP.SAMPLER_SAMPLE)
+        )
+        samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
         return self.inner_model.process_latent_out(samples.to(torch.float32))
 
     def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
@@ -806,11 +811,12 @@ def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
 
         try:
             comfy.sampler_helpers.prepare_model_patcher(self.model_patcher, self.conds)
-            executor = comfy.model_patcher.WrapperClassExecutor.new_executor(
+            executor = comfy.model_patcher.WrapperExecutor.new_class_executor(
                 self.outer_sample,
+                self,
                 self.model_patcher.get_all_wrappers(comfy.model_patcher.WrappersMP.OUTER_SAMPLE)
             )
-            output = executor._execute(self, noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
+            output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
         finally:
             self.model_patcher.restore_hook_patches()
 
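For context on the new SAMPLER_SAMPLE wrapper type, here is a hedged sketch of what such a wrapper could look like. The argument order mirrors the executor.execute(...) call in inner_sample above; the name sampler_sample_wrapper is hypothetical, and how a wrapper actually gets registered on the ModelPatcher so that get_all_wrappers returns it is not shown in this diff.

def sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise,
                           latent_image, denoise_mask, disable_pbar):
    # executor.class_obj is the sampler instance handed to new_class_executor in inner_sample.
    print(f"wrapping sample() of {type(executor.class_obj).__name__} over {len(sigmas)} sigmas")
    # Calling the executor continues the chain and ultimately invokes sampler.sample(...).
    return executor(guider, sigmas, extra_args, callback, noise,
                    latent_image, denoise_mask, disable_pbar)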
