From bd9f5607244f43f5f60ea630e3b0557887643fcc Mon Sep 17 00:00:00 2001
From: shiimizu
Date: Tue, 9 Jan 2024 02:10:30 -0800
Subject: [PATCH] Support non-square resolutions in Self-Attention Guidance.

---
 comfy_extras/nodes_sag.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py
index 450ac3eeacd..12aa0542480 100644
--- a/comfy_extras/nodes_sag.py
+++ b/comfy_extras/nodes_sag.py
@@ -51,6 +51,9 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
     )
     return (out, sim)
 
+def ceildiv(a, b):
+    return -(a // -b)
+
 def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
     # reshape and GAP the attention map
     _, hw1, hw2 = attn.shape
@@ -58,8 +61,10 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
     attn = attn.reshape(b, -1, hw1, hw2)
     # Global Average Pool
     mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
-    ratio = math.ceil(math.sqrt(lh * lw / hw1))
-    mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
+    ratio = math.sqrt(lh * lw / hw1)
+    new_lh = lh // ratio if lh > lw else ceildiv(lh, ratio)
+    new_lw = lw // ratio if lw > lh else ceildiv(lw, ratio)
+    mid_shape = [int(new_lh), int(new_lw)]
     # Reshape
     mask = (
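
Note (reviewer sketch, not part of the patch): a minimal standalone check of the new
mid_shape arithmetic. compute_mid_shape is a hypothetical helper that mirrors the
patched lines in create_blur_map; the latent sizes below are illustrative and assume
an SD1.5-style middle-block attention map at 1/8 of the latent resolution. The old
code ceiled both the ratio and both dimensions, which could make the product of
mid_shape disagree with hw1 for non-square latents; the patch instead keeps the exact
ratio, floors the larger dimension, and ceils the smaller one.

import math

def ceildiv(a, b):
    # Ceiling division; also works when the divisor is a float.
    return -(a // -b)

def compute_mid_shape(lh, lw, hw1):
    # Mirrors the patched lines: keep the (possibly fractional) downsample
    # ratio, floor the larger latent dimension, ceil the smaller one.
    ratio = math.sqrt(lh * lw / hw1)
    new_lh = lh // ratio if lh > lw else ceildiv(lh, ratio)
    new_lw = lw // ratio if lw > lh else ceildiv(lw, ratio)
    return [int(new_lh), int(new_lw)]

# 768x512 image -> 96x64 latent; 12x8 = 96 attention tokens.
assert compute_mid_shape(96, 64, 96) == [12, 8]
# 1216x832 image -> 152x104 latent; 19x13 = 247 attention tokens.
assert compute_mid_shape(152, 104, 247) == [19, 13]

In both examples mid_shape[0] * mid_shape[1] equals hw1, so the later
mask.reshape(b, *mid_shape) in create_blur_map succeeds for these non-square cases.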