diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 1f8100698ad..59c50541382 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -319,12 +319,21 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False mem_counter = 0 patch_counter = 0 lowvram_counter = 0 - load_completely = [] + loading = [] for n, m in self.model.named_modules(): + if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"): + loading.append((comfy.model_management.module_size(m), n, m)) + + load_completely = [] + loading.sort(reverse=True) + for x in loading: + n = x[1] + m = x[2] + module_mem = x[0] + lowvram_weight = False if not full_load and hasattr(m, "comfy_cast_weights"): - module_mem = comfy.model_management.module_size(m) if mem_counter + module_mem >= lowvram_model_memory: lowvram_weight = True lowvram_counter += 1 @@ -356,9 +365,8 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False wipe_lowvram_weight(m) if hasattr(m, "weight"): - mem_used = comfy.model_management.module_size(m) - mem_counter += mem_used - load_completely.append((mem_used, n, m)) + mem_counter += module_mem + load_completely.append((module_mem, n, m)) load_completely.sort(reverse=True) for x in load_completely: