Skip to content

Commit

Permalink
Fix OOMs happening in some cases.
Browse files Browse the repository at this point in the history
A cloned model patcher sometimes reported a model was loaded on a device
when it wasn't.
  • Loading branch information
comfyanonymous committed Aug 6, 2024
1 parent de17a97 commit b334605
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 11 deletions.
1 change: 1 addition & 0 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
self.latent_format = model_config.latent_format
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
self.device = device

if not unet_config.get("disable_unet_model_creation", False):
if self.manual_cast_dtype is not None:
Expand Down
2 changes: 1 addition & 1 deletion comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def model_memory(self):
return self.model.model_size()

def model_memory_required(self, device):
if device == self.model.current_device:
if device == self.model.current_loaded_device():
return 0
else:
return self.model_memory()
Expand Down
23 changes: 14 additions & 9 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
return model_options

class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
self.model = model
if not hasattr(self.model, 'device'):
logging.info("Model doesn't have a device attribute.")
self.model.device = offload_device
elif self.model.device is None:
self.model.device = offload_device

self.patches = {}
self.backup = {}
self.object_patches = {}
Expand All @@ -75,11 +81,6 @@ def __init__(self, model, load_device, offload_device, size=0, current_device=No
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device

self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False
self.lowvram_patch_counter = 0
Expand All @@ -92,7 +93,7 @@ def model_size(self):
return self.size

def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
Expand Down Expand Up @@ -302,7 +303,7 @@ def patch_model(self, device_to=None, patch_weights=True):

if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
self.model.device = device_to

return self.model

Expand Down Expand Up @@ -355,6 +356,7 @@ def __call__(self, weight):

self.model_lowvram = True
self.lowvram_patch_counter = patch_counter
self.model.device = device_to
return self.model

def calculate_weight(self, patches, weight, key):
Expand Down Expand Up @@ -551,10 +553,13 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):

if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
self.model.device = device_to

keys = list(self.object_patches_backup.keys())
for k in keys:
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])

self.object_patches_backup.clear()

def current_loaded_device(self):
    """Return the device the wrapped model currently resides on.

    Reads ``self.model.device``, which this commit keeps in sync in
    ``patch_model``/``unpatch_model`` (replacing the cloned patcher's
    stale ``current_device`` attribute that caused the reported OOMs).
    """
    return self.model.device
2 changes: 1 addition & 1 deletion comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
logging.debug("left over keys: {}".format(left_over))

if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
if inital_load_device != torch.device("cpu"):
logging.info("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
Expand Down

0 comments on commit b334605

Please sign in to comment.