From 2d810b081e3e992105a58b428a70cdd70779c85a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 21 Sep 2024 01:51:51 -0400 Subject: [PATCH] Add load_controlnet_state_dict function. --- comfy/controlnet.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 61a67f3f402..ff4385b337e 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -495,8 +495,8 @@ def convert_mistoline(sd): return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) -def load_controlnet(ckpt_path, model=None, model_options={}): - controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True) +def load_controlnet_state_dict(state_dict, model=None, model_options={}): + controlnet_data = state_dict if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT return load_controlnet_hunyuandit(controlnet_data, model_options=model_options) @@ -578,7 +578,7 @@ def load_controlnet(ckpt_path, model=None, model_options={}): else: net = load_t2i_adapter(controlnet_data, model_options=model_options) if net is None: - logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path)) + logging.error("error could not detect control model type.") return net if controlnet_config is None: @@ -633,14 +633,21 @@ class WeightsLoader(torch.nn.Module): if len(unexpected) > 0: logging.debug("unexpected controlnet keys: {}".format(unexpected)) - global_average_pooling = False - filename = os.path.splitext(ckpt_path)[0] - if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling - global_average_pooling = True - + global_average_pooling = model_options.get("global_average_pooling", False) control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control +def load_controlnet(ckpt_path, model=None, model_options={}): + if "global_average_pooling" not in model_options: + filename = os.path.splitext(ckpt_path)[0] + if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling + model_options["global_average_pooling"] = True + + cnet = load_controlnet_state_dict(comfy.utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options) + if cnet is None: + logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path)) + return cnet + class T2IAdapter(ControlBase): def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None): super().__init__(device)