From e9589d6d9246d1ce5a810be1507ead39fff50e04 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 11 Aug 2024 08:50:34 -0400 Subject: [PATCH] Add a way to set model dtype and ops from load_checkpoint_guess_config. --- comfy/model_base.py | 27 ++++++++++++++++++++++++--- comfy/sd.py | 13 +++++++++---- comfy/supported_models_base.py | 19 +++++++++++++++++++ 3 files changed, 52 insertions(+), 7 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index cb69496493e..830bcc68c6f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1,3 +1,21 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + import torch import logging from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep @@ -77,10 +95,13 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod self.device = device if not unet_config.get("disable_unet_model_creation", False): - if self.manual_cast_dtype is not None: - operations = comfy.ops.manual_cast + if model_config.custom_operations is None: + if self.manual_cast_dtype is not None: + operations = comfy.ops.manual_cast + else: + operations = comfy.ops.disable_weight_init else: - operations = comfy.ops.disable_weight_init + operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) if comfy.model_management.force_channels_last(): self.diffusion_model.to(memory_format=torch.channels_last) diff --git a/comfy/sd.py b/comfy/sd.py index 68917324884..ee91ad53b2c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -498,14 +498,14 @@ class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.mo return (model, clip, vae) -def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): +def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}): sd = comfy.utils.load_torch_file(ckpt_path) - out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model) + out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) return out -def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True): +def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}): clip = None clipvision = None vae = None @@ -525,7 +525,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if weight_dtype is not None: unet_weight_dtype.append(weight_dtype) - unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype) + model_config.custom_operations = model_options.get("custom_operations", None) + unet_dtype = model_options.get("weight_dtype", None) + + if unet_dtype is None: + unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype) + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index bc0a7e31108..7a2152f915d 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -1,3 +1,21 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Comfy + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +""" + import torch from . import model_base from . import utils @@ -30,6 +48,7 @@ class BASE: memory_usage_factor = 2.0 manual_cast_dtype = None + custom_operations = None @classmethod def matches(s, unet_config, state_dict=None):