Memory optimizations to allow bigger images #2015

Closed · wants to merge 5 commits
2 changes: 2 additions & 0 deletions comfy/cli_args.py
@@ -80,6 +80,7 @@ class LatentPreviewMethod(enum.Enum):
TAESD = "taesd"

parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-cpu", action="store_true", help="To use the CPU for preview (slow).")

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -99,6 +100,7 @@ class LatentPreviewMethod(enum.Enum):

parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")

parser.add_argument("--memory-estimation-multiplier", type=float, default=-1, help="Multiplier for the memory estimation.")

parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
13 changes: 10 additions & 3 deletions comfy/ldm/modules/attention.py
@@ -219,16 +219,21 @@ def attention_split(q, k, v, heads, mask=None):
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
modifier = 3
if args.memory_estimation_multiplier >= 0:
modifier = args.memory_estimation_multiplier
mem_required = tensor_size * modifier
steps = 1

max_steps = q.shape[1] - 1
while (q.shape[1] % max_steps) != 0:
max_steps -= 1

if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

- if steps > 64:
+ if steps > max_steps:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
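For reference, the suggested maximum resolution in that message is derived by inverting the memory estimate; with roughly 8 GiB reported free, for example, it would suggest a cap of about 1920x1920:

    import math
    mem_free_total = 8 * 1024**3
    max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
    print(max_res)  # 1920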
@@ -261,8 +266,10 @@ def attention_split(q, k, v, heads, mask=None):
cleared_cache = True
print("out of memory error, emptying cache and trying again")
continue
- steps *= 2
- if steps > 64:
+ steps += 1
+ while (q.shape[1] % steps) != 0 and steps < max_steps:
+ steps += 1
+ if steps > max_steps:
raise e
print("out of memory error, increasing steps and trying again", steps)
else:
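Taken together, these attention_split changes replace the old doubling schedule (1, 2, 4, ... with a hard cap of 64) by a search over step counts that evenly divide the sequence length, so slicing keeps scaling with very large inputs instead of giving up at a fixed cap. A standalone sketch of that search, with seq_len standing in for q.shape[1]:

    def next_step_count(seq_len, steps):
        # Finest usable split: the largest divisor of seq_len that is smaller than seq_len.
        max_steps = seq_len - 1
        while seq_len % max_steps != 0:
            max_steps -= 1
        # Advance to the next step count that divides seq_len evenly.
        steps += 1
        while seq_len % steps != 0 and steps < max_steps:
            steps += 1
        if steps > max_steps:
            raise RuntimeError("cannot split attention any further")
        return steps

    # e.g. for a 4096-token sequence: 1 -> 2 -> 4 -> ... -> 2048, then give up.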
13 changes: 11 additions & 2 deletions comfy/ldm/modules/diffusionmodules/model.py
@@ -7,6 +7,7 @@
from typing import Optional, Any

from comfy import model_management
from comfy.cli_args import args
import comfy.ops

if model_management.xformers_enabled_vae():
@@ -165,9 +166,15 @@ def slice_attention(q, k, v):
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
if args.memory_estimation_multiplier >= 0:
modifier = args.memory_estimation_multiplier
mem_required = tensor_size * modifier
steps = 1

max_steps = q.shape[1] - 1
while (q.shape[1] % max_steps) != 0:
max_steps -= 1

if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

@@ -186,8 +193,10 @@ def slice_attention(q, k, v):
break
except model_management.OOM_EXCEPTION as e:
model_management.soft_empty_cache(True)
- steps *= 2
- if steps > 128:
+ steps += 1
+ while (q.shape[1] % steps) != 0 and steps < max_steps:
+ steps += 1
+ if steps > max_steps:
raise e
print("out of memory error, increasing steps and trying again", steps)

39 changes: 35 additions & 4 deletions comfy/sd.py
@@ -221,6 +221,8 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
return samples

def decode(self, samples_in):
pixel_samples = None

try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
@@ -232,9 +234,23 @@ def decode(self, samples_in):
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)

except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)
pixel_samples = None

tile_size = 64
while tile_size >= 8:
overlap = tile_size // 4
print(f"Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding with tile size {tile_size} and overlap {overlap}.")
try:
pixel_samples = self.decode_tiled_(samples_in, tile_x=tile_size, tile_y=tile_size, overlap=overlap)
break
except model_management.OOM_EXCEPTION as e:
pass
tile_size -= 8

if pixel_samples is None:
raise e

pixel_samples = pixel_samples.cpu().movedim(1,-1)
return pixel_samples
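Rather than a single retry at the default tile size, decode now walks the latent tile size down from 64 to 8 in steps of 8, using an overlap of a quarter of the tile, and only re-raises the original out-of-memory error if even the smallest tile fails. The same shrink-until-it-fits pattern as a standalone helper, with decode_fn standing in for self.decode_tiled_ and oom_exc for model_management.OOM_EXCEPTION:

    def decode_with_shrinking_tiles(decode_fn, samples, start=64, floor=8, step=8,
                                    oom_exc=MemoryError):
        tile_size = start
        while tile_size >= floor:
            try:
                return decode_fn(samples, tile_x=tile_size, tile_y=tile_size,
                                 overlap=tile_size // 4)
            except oom_exc:
                tile_size -= step  # shrink and try again
        raise MemoryError("tiled decoding failed even at the smallest tile size")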
@@ -246,6 +262,8 @@ def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):

def encode(self, pixel_samples):
pixel_samples = pixel_samples.movedim(-1,1)
samples = None

try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
@@ -258,8 +276,21 @@ def encode(self, pixel_samples):
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()

except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
samples = self.encode_tiled_(pixel_samples)
samples = None

tile_size = 512
while tile_size >= 64:
overlap = tile_size // 8
print(f"Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding with tile size {tile_size} and overlap {overlap}.")
try:
samples = self.encode_tiled_(pixel_samples, tile_x=tile_size, tile_y=tile_size, overlap=overlap)
break
except model_management.OOM_EXCEPTION as e:
pass
tile_size -= 64

if samples is None:
raise e

return samples
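The encode fallback mirrors the decode one but works in pixel space: tiles start at 512 and shrink by 64 down to a floor of 64, with an overlap of one eighth of the tile. Assuming the usual 8x spatial compression of the Stable Diffusion VAE, the two schedules line up:

    # pixel-space encode tiles vs. latent-space decode tiles (assuming an 8x VAE downscale)
    # 512 px // 8 = 64 latent  (largest decode tile)
    #  64 px // 8 =  8 latent  (smallest decode tile)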

15 changes: 11 additions & 4 deletions latent_preview.py
@@ -6,6 +6,7 @@
from comfy.taesd.taesd import TAESD
import folder_paths
import comfy.utils
from comfy import model_management

MAX_PREVIEW_RESOLUTION = 512

@@ -18,11 +19,12 @@ def decode_latent_to_preview_image(self, preview_format, x0):
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)

class TAESDPreviewerImpl(LatentPreviewer):
- def __init__(self, taesd):
+ def __init__(self, taesd, device):
self.taesd = taesd
self.device = device

def decode_latent_to_preview(self, x0):
- x_sample = self.taesd.decode(x0[:1])[0].detach()
+ x_sample = self.taesd.decode(x0[:1].to(self.device))[0].detach()
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
@@ -49,6 +51,8 @@ def decode_latent_to_preview(self, x0):
def get_previewer(device, latent_format):
previewer = None
method = args.preview_method
if args.preview_cpu:
device = torch.device("cpu")
if method != LatentPreviewMethod.NoPreviews:
# TODO previewer methods
taesd_decoder_path = None
@@ -68,7 +72,7 @@ def get_previewer(device, latent_format):
if method == LatentPreviewMethod.TAESD:
if taesd_decoder_path:
taesd = TAESD(None, taesd_decoder_path).to(device)
- previewer = TAESDPreviewerImpl(taesd)
+ previewer = TAESDPreviewerImpl(taesd, device)
else:
print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))

@@ -91,7 +95,10 @@ def callback(step, x0, x, total_steps):

preview_bytes = None
if previewer:
- preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
+ try:
+ preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
+ except model_management.OOM_EXCEPTION as e:
+ pass
pbar.update_absolute(step + 1, total_steps, preview_bytes)
return callback
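Combined, these preview changes keep a failed preview from ever aborting a sampling run: TAESD can be pinned to the CPU with --preview-cpu, and an out-of-memory error while rendering a preview simply skips that frame. A compressed sketch of the resulting callback behaviour, with previewer, pbar and oom_exc standing in for the objects used in this file:

    def make_safe_callback(previewer, pbar, preview_format="JPEG", oom_exc=MemoryError):
        def callback(step, x0, x, total_steps):
            preview_bytes = None
            if previewer is not None:
                try:
                    preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
                except oom_exc:
                    pass  # skip this preview frame instead of killing the sampler
            pbar.update_absolute(step + 1, total_steps, preview_bytes)
        return callback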