Skip to content

Commit

Permalink
Make it easy for models to process the unet state dict on load.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Nov 21, 2023
1 parent 2dd5b4d commit ce67dcb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def load_model_weights(self, sd, unet_prefix=""):
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)

to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
print("unet missing:", m)
Expand Down
3 changes: 3 additions & 0 deletions comfy/supported_models_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def get_model(self, state_dict, prefix="", device=None):
def process_clip_state_dict(self, state_dict):
return state_dict

def process_unet_state_dict(self, state_dict):
return state_dict

def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
Expand Down

0 comments on commit ce67dcb

Please sign in to comment.