From ce67dcbcdabe2edf1497e37ecf1b6f976a3ecdf6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 20 Nov 2023 22:27:36 -0500 Subject: [PATCH] Make it easy for models to process the unet state dict on load. --- comfy/model_base.py | 1 + comfy/supported_models_base.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 37bf24bb8c6..772e2693493 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -121,6 +121,7 @@ def load_model_weights(self, sd, unet_prefix=""): if k.startswith(unet_prefix): to_load[k[len(unet_prefix):]] = sd.pop(k) + to_load = self.model_config.process_unet_state_dict(to_load) m, u = self.diffusion_model.load_state_dict(to_load, strict=False) if len(m) > 0: print("unet missing:", m) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 88a1d7fde49..6dfae034303 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -53,6 +53,9 @@ def get_model(self, state_dict, prefix="", device=None): def process_clip_state_dict(self, state_dict): return state_dict + def process_unet_state_dict(self, state_dict): + return state_dict + def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {"": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix)