Add unfinished ImageOnlyCheckpointSave node to save a SVD checkpoint.
This node is unfinished; SVD checkpoints saved with this node will work with ComfyUI but not with anything else.
comfyanonymous committed Jan 18, 2024
1 parent fad02dc commit d76a04b
Showing 6 changed files with 99 additions and 50 deletions.
9 changes: 8 additions & 1 deletion comfy/clip_vision.py
@@ -1,4 +1,4 @@
-from .utils import load_torch_file, transformers_convert, common_upscale
+from .utils import load_torch_file, transformers_convert, common_upscale, state_dict_prefix_replace
 import os
 import torch
 import contextlib
@@ -41,9 +41,13 @@ def __init__(self, json_config):
         self.model.eval()

         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

     def load_sd(self, sd):
         return self.model.load_state_dict(sd, strict=False)

+    def get_sd(self):
+        return self.model.state_dict()
+
     def encode_image(self, image):
         comfy.model_management.load_model_gpu(self.patcher)
         pixel_values = clip_preprocess(image.to(self.load_device)).float()
@@ -76,6 +80,9 @@ def convert_to_transformers(sd, prefix):
         sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)

         sd = transformers_convert(sd, prefix, "vision_model.", 48)
+    else:
+        replace_prefix = {prefix: ""}
+        sd = state_dict_prefix_replace(sd, replace_prefix)
     return sd

 def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
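
For context, state_dict_prefix_replace just renames state-dict keys by prefix. A minimal sketch of the semantics — a stand-alone reimplementation for illustration, not the comfy.utils original, and the key shown is hypothetical:

def prefix_replace(sd, replace_prefix):
    # rename every key that starts with an old prefix to use the new one
    for old, new in replace_prefix.items():
        for k in [k for k in sd if k.startswith(old)]:
            sd[new + k[len(old):]] = sd.pop(k)
    return sd

sd = {"some_prefix.vision_model.embeddings.class_embedding": 0.0}  # hypothetical key
prefix_replace(sd, {"some_prefix.": ""})
# sd == {"vision_model.embeddings.class_embedding": 0.0}
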
21 changes: 15 additions & 6 deletions comfy/model_base.py
@@ -179,19 +179,28 @@ def process_latent_in(self, latent):
     def process_latent_out(self, latent):
         return self.latent_format.process_out(latent)

-    def state_dict_for_saving(self, clip_state_dict, vae_state_dict):
-        clip_state_dict = self.model_config.process_clip_state_dict_for_saving(clip_state_dict)
+    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
+        extra_sds = []
+        if clip_state_dict is not None:
+            extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
+        if vae_state_dict is not None:
+            extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
+        if clip_vision_state_dict is not None:
+            extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
+
         unet_state_dict = self.diffusion_model.state_dict()
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
-        vae_state_dict = self.model_config.process_vae_state_dict_for_saving(vae_state_dict)

         if self.get_dtype() == torch.float16:
-            clip_state_dict = utils.convert_sd_to(clip_state_dict, torch.float16)
-            vae_state_dict = utils.convert_sd_to(vae_state_dict, torch.float16)
+            extra_sds = map(lambda sd: utils.convert_sd_to(sd, torch.float16), extra_sds)

         if self.model_type == ModelType.V_PREDICTION:
             unet_state_dict["v_pred"] = torch.tensor([])

-        return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
+        for sd in extra_sds:
+            unet_state_dict.update(sd)
+
+        return unet_state_dict

     def set_inpaint(self):
         self.inpaint_model = True
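
The net effect: the saved checkpoint is one flat state dict built around the UNet, with each optional component re-prefixed and merged in. An illustrative sketch of the resulting key layout — only the "model.diffusion_model." and "cond_stage_model." prefixes appear in this diff; the VAE prefix is the usual ldm convention, and the clip_vision prefix depends on the model config:

checkpoint = {
    "model.diffusion_model.input_blocks.0.0.weight": ...,                    # unet, always present
    "first_stage_model.decoder.conv_in.weight": ...,                         # vae, if passed
    "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": ...,  # clip, if passed
    # clip_vision keys get whatever self.clip_vision_prefix is for the architecture
}
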
13 changes: 10 additions & 3 deletions comfy/sd.py
@@ -534,7 +534,14 @@ def load_unet(unet_path):
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
     return model

-def save_checkpoint(output_path, model, clip, vae, metadata=None):
-    model_management.load_models_gpu([model, clip.load_model()])
-    sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd())
+def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None):
+    clip_sd = None
+    load_models = [model]
+    if clip is not None:
+        load_models.append(clip.load_model())
+        clip_sd = clip.get_sd()
+
+    model_management.load_models_gpu(load_models)
+    clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
+    sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
     comfy.utils.save_torch_file(sd, output_path, metadata=metadata)
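
A hedged usage sketch of the new signature; model, vae and clip_vision stand in for whatever the loader nodes return. For an image-only SVD model there is no text encoder, so clip stays None and the CLIP vision weights are passed instead. Note that vae is still effectively required despite its None default, since vae.get_sd() is called unconditionally.

import comfy.sd

# model, vae, clip_vision: assumed outputs of a loader such as
# ImageOnlyCheckpointLoader (see nodes_video_model.py below)
comfy.sd.save_checkpoint("output/svd_image_only.safetensors", model,
                         clip=None, vae=vae, clip_vision=clip_vision)
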
6 changes: 6 additions & 0 deletions comfy/supported_models_base.py
@@ -65,6 +65,12 @@ def process_clip_state_dict_for_saving(self, state_dict):
         replace_prefix = {"": "cond_stage_model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)

+    def process_clip_vision_state_dict_for_saving(self, state_dict):
+        replace_prefix = {}
+        if self.clip_vision_prefix is not None:
+            replace_prefix[""] = self.clip_vision_prefix
+        return utils.state_dict_prefix_replace(state_dict, replace_prefix)
+
     def process_unet_state_dict_for_saving(self, state_dict):
         replace_prefix = {"": "model.diffusion_model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
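
Note the direction of the mapping: the loaders strip clip_vision_prefix from keys, while this save path maps the empty string to clip_vision_prefix, which prepends it to every key. A tiny sketch — the SVD prefix shown is an assumption for illustration, not something stated in this diff:

replace_prefix = {"": "conditioner.embedders.0.open_clip."}  # assumed SVD prefix
key = "vision_model.embeddings.class_embedding"
saved_key = replace_prefix[""] + key
# "conditioner.embedders.0.open_clip.vision_model.embeddings.class_embedding"
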
83 changes: 43 additions & 40 deletions comfy_extras/nodes_model_merging.py
@@ -119,6 +119,48 @@ def merge(self, model1, model2, **kwargs):
             m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
         return (m, )

+def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
+    full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir)
+    prompt_info = ""
+    if prompt is not None:
+        prompt_info = json.dumps(prompt)
+
+    metadata = {}
+
+    enable_modelspec = True
+    if isinstance(model.model, comfy.model_base.SDXL):
+        metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
+    elif isinstance(model.model, comfy.model_base.SDXLRefiner):
+        metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
+    else:
+        enable_modelspec = False
+
+    if enable_modelspec:
+        metadata["modelspec.sai_model_spec"] = "1.0.0"
+        metadata["modelspec.implementation"] = "sgm"
+        metadata["modelspec.title"] = "{} {}".format(filename, counter)
+
+    #TODO:
+    # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
+    # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
+    # "v2-inpainting"
+
+    if model.model.model_type == comfy.model_base.ModelType.EPS:
+        metadata["modelspec.predict_key"] = "epsilon"
+    elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
+        metadata["modelspec.predict_key"] = "v"
+
+    if not args.disable_metadata:
+        metadata["prompt"] = prompt_info
+        if extra_pnginfo is not None:
+            for x in extra_pnginfo:
+                metadata[x] = json.dumps(extra_pnginfo[x])
+
+    output_checkpoint = f"{filename}_{counter:05}_.safetensors"
+    output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
+
+    comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
+
 class CheckpointSave:
     def __init__(self):
         self.output_dir = folder_paths.get_output_directory()
@@ -137,46 +179,7 @@ def INPUT_TYPES(s):
     CATEGORY = "advanced/model_merging"

     def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
-        full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
-        prompt_info = ""
-        if prompt is not None:
-            prompt_info = json.dumps(prompt)
-
-        metadata = {}
-
-        enable_modelspec = True
-        if isinstance(model.model, comfy.model_base.SDXL):
-            metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
-        elif isinstance(model.model, comfy.model_base.SDXLRefiner):
-            metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
-        else:
-            enable_modelspec = False
-
-        if enable_modelspec:
-            metadata["modelspec.sai_model_spec"] = "1.0.0"
-            metadata["modelspec.implementation"] = "sgm"
-            metadata["modelspec.title"] = "{} {}".format(filename, counter)
-
-        #TODO:
-        # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
-        # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
-        # "v2-inpainting"
-
-        if model.model.model_type == comfy.model_base.ModelType.EPS:
-            metadata["modelspec.predict_key"] = "epsilon"
-        elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION:
-            metadata["modelspec.predict_key"] = "v"
-
-        if not args.disable_metadata:
-            metadata["prompt"] = prompt_info
-            if extra_pnginfo is not None:
-                for x in extra_pnginfo:
-                    metadata[x] = json.dumps(extra_pnginfo[x])
-
-        output_checkpoint = f"{filename}_{counter:05}_.safetensors"
-        output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
-
-        comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
+        save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
         return {}

 class CLIPSave:
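
For reference, a hedged example of the metadata the shared helper builds for an SDXL base checkpoint saved with the default filename prefix; the title and prompt values are illustrative:

metadata = {
    "modelspec.sai_model_spec": "1.0.0",
    "modelspec.architecture": "stable-diffusion-xl-v1-base",
    "modelspec.implementation": "sgm",
    "modelspec.title": "ComfyUI 1",        # "{} {}".format(filename, counter)
    "modelspec.predict_key": "epsilon",    # EPS model type
    "prompt": '{"illustrative": "serialized workflow JSON"}',
}
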
17 changes: 17 additions & 0 deletions comfy_extras/nodes_video_model.py
@@ -3,6 +3,7 @@
 import comfy.utils
 import comfy.sd
 import folder_paths
+import comfy_extras.nodes_model_merging


 class ImageOnlyCheckpointLoader:
@@ -78,10 +79,26 @@ def linear_cfg(args):
         m.set_model_sampler_cfg_function(linear_cfg)
         return (m, )

+class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave):
+    CATEGORY = "_for_testing"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "clip_vision": ("CLIP_VISION",),
+                              "vae": ("VAE",),
+                              "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),},
+                "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
+
+    def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
+        comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
+        return {}
+
 NODE_CLASS_MAPPINGS = {
     "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
     "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
     "VideoLinearCFGGuidance": VideoLinearCFGGuidance,
+    "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
 }

 NODE_DISPLAY_NAME_MAPPINGS = {
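
Subclassing CheckpointSave reuses its __init__ (which sets self.output_dir) and its node plumbing; only the inputs, the category, and the save call change. A hedged sketch of driving the new node directly, outside the graph executor:

node = ImageOnlyCheckpointSave()
# model, clip_vision, vae: assumed outputs of ImageOnlyCheckpointLoader
node.save(model, clip_vision, vae, filename_prefix="checkpoints/ComfyUI_SVD")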
