From f1a01c2c7e2678f34c72212d2a8640237f31fa43 Mon Sep 17 00:00:00 2001 From: Extraltodeus Date: Tue, 9 Jul 2024 22:20:49 +0200 Subject: [PATCH] Add sampler_pre_cfg_function (#3979) * Update samplers.py * Update model_patcher.py --- comfy/model_patcher.py | 9 +++++++++ comfy/samplers.py | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b949031e998..efac251ca90 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -57,6 +57,12 @@ def set_model_options_post_cfg_function(model_options, post_cfg_function, disabl model_options["disable_cfg1_optimization"] = True return model_options +def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_cfg1_optimization=False): + model_options["sampler_pre_cfg_function"] = model_options.get("sampler_pre_cfg_function", []) + [pre_cfg_function] + if disable_cfg1_optimization: + model_options["disable_cfg1_optimization"] = True + return model_options + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): self.size = size @@ -130,6 +136,9 @@ def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_opti def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization) + def set_model_sampler_pre_cfg_function(self, pre_cfg_function, disable_cfg1_optimization=False): + self.model_options = set_model_options_pre_cfg_function(self.model_options, pre_cfg_function, disable_cfg1_optimization) + def set_model_unet_function_wrapper(self, unet_wrapper_function: UnetWrapperFunction): self.model_options["model_function_wrapper"] = unet_wrapper_function diff --git a/comfy/samplers.py b/comfy/samplers.py index c0aa12916d3..bbf9219f7ee 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -275,6 +275,12 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option conds = [cond, uncond_] out = calc_cond_batch(model, conds, x, timestep, model_options) + + for fn in model_options.get("sampler_pre_cfg_function", []): + args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep, + "input": x, "sigma": timestep, "model": model, "model_options": model_options} + out = fn(args) + return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)