Skip to content

Commit

Permalink
Add: advanced->model->ModelSamplingDiscrete node.
Browse files Browse the repository at this point in the history
This allows changing the sampling parameters of the model (eps or vpred)
or set the model to use zsnr.
  • Loading branch information
comfyanonymous committed Nov 7, 2023
1 parent d07cd44 commit 844dbf9
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 0 deletions.
17 changes: 17 additions & 0 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def __init__(self, model, load_device, offload_device, size=0, current_device=No
self.model = model
self.patches = {}
self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
Expand Down Expand Up @@ -91,6 +93,9 @@ def set_model_attn2_output_patch(self, patch):
def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch")

def add_object_patch(self, name, obj):
self.object_patches[name] = obj

def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
Expand Down Expand Up @@ -150,6 +155,12 @@ def model_state_dict(self, filter_prefix=None):
return sd

def patch_model(self, device_to=None):
for k in self.object_patches:
old = getattr(self.model, k)
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
setattr(self.model, k, self.object_patches[k])

model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
Expand Down Expand Up @@ -290,3 +301,9 @@ def unpatch_model(self, device_to=None):
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to

keys = list(self.object_patches_backup.keys())
for k in keys:
setattr(self.model, k, self.object_patches_backup[k])

self.object_patches_backup = {}
2 changes: 2 additions & 0 deletions comfy/model_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
self.set_sigmas(sigmas)

def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())

Expand Down
57 changes: 57 additions & 0 deletions comfy_extras/nodes_model_advanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import folder_paths
import comfy.sd
import comfy.model_sampling


def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
alphas_bar_sqrt = alphas_cumprod.sqrt()

# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

# Shift so the last timestep is zero.
alphas_bar_sqrt -= (alphas_bar_sqrt_T)

# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5

class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction"],),
"zsnr": ("BOOLEAN", {"default": False}),
}}

RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "advanced/model"

def patch(self, model, sampling, zsnr):
m = model.clone()

if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION

class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type):
pass

model_sampling = ModelSamplingAdvanced()
if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
m.add_object_patch("model_sampling", model_sampling)
return (m, )

NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
}
1 change: 1 addition & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1798,6 +1798,7 @@ def init_custom_nodes():
"nodes_freelunch.py",
"nodes_custom_sampler.py",
"nodes_hypertile.py",
"nodes_model_advanced.py",
]

for node_file in extras_files:
Expand Down

0 comments on commit 844dbf9

Please sign in to comment.