Skip to content

Commit

Permalink
Add a way to set model dtype and ops from load_checkpoint_guess_config.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Aug 11, 2024
1 parent 0d82a79 commit e9589d6
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 7 deletions.
27 changes: 24 additions & 3 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
Expand Down Expand Up @@ -77,10 +95,13 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
self.device = device

if not unet_config.get("disable_unet_model_creation", False):
if self.manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
if model_config.custom_operations is None:
if self.manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops.disable_weight_init
else:
operations = comfy.ops.disable_weight_init
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
if comfy.model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
Expand Down
13 changes: 9 additions & 4 deletions comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,14 +498,14 @@ class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.mo

return (model, clip, vae)

def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
sd = comfy.utils.load_torch_file(ckpt_path)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
return out

def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
clip = None
clipvision = None
vae = None
Expand All @@ -525,7 +525,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if weight_dtype is not None:
unet_weight_dtype.append(weight_dtype)

unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
model_config.custom_operations = model_options.get("custom_operations", None)
unet_dtype = model_options.get("weight_dtype", None)

if unet_dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)

manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

Expand Down
19 changes: 19 additions & 0 deletions comfy/supported_models_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

import torch
from . import model_base
from . import utils
Expand Down Expand Up @@ -30,6 +48,7 @@ class BASE:
memory_usage_factor = 2.0

manual_cast_dtype = None
custom_operations = None

@classmethod
def matches(s, unet_config, state_dict=None):
Expand Down

0 comments on commit e9589d6

Please sign in to comment.