From 17b1d6811823b522d9152c224873bf1dda441361 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 25 Jul 2024 18:15:09 -0400
Subject: [PATCH] Fix black images on hunyuan dit.

---
 comfy/ldm/hydit/attn_layers.py | 9 ++++++---
 comfy/ldm/hydit/models.py      | 8 +++++---
 comfy/supported_models.py      | 7 ++++++-
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/comfy/ldm/hydit/attn_layers.py b/comfy/ldm/hydit/attn_layers.py
index 8997d4f07fc..920b8428602 100644
--- a/comfy/ldm/hydit/attn_layers.py
+++ b/comfy/ldm/hydit/attn_layers.py
@@ -106,12 +106,14 @@ def __init__(self,
                  qk_norm=False,
                  attn_drop=0.0,
                  proj_drop=0.0,
+                 attn_precision=None,
                  device=None,
                  dtype=None,
                  operations=None,
                  ):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
+        self.attn_precision = attn_precision
         self.qdim = qdim
         self.kdim = kdim
         self.num_heads = num_heads
@@ -160,7 +162,7 @@ def forward(self, x, y, freqs_cis_img=None):
         k = k.transpose(-2, -3).contiguous()  # k -> B, L2, H, C - B, H, C, L2
         v = v.transpose(-2, -3).contiguous()

-        context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
+        context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)

         out = self.out_proj(context)  # context.reshape - B, L1, -1
         out = self.proj_drop(out)
@@ -174,8 +176,9 @@ class Attention(nn.Module):
     """
     We rename some layer names to align with flash attention
     """
-    def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., dtype=None, device=None, operations=None):
+    def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., attn_precision=None, dtype=None, device=None, operations=None):
         super().__init__()
+        self.attn_precision = attn_precision
         self.dim = dim
         self.num_heads = num_heads
         assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -207,7 +210,7 @@ def forward(self, x, freqs_cis_img=None):
             f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
         q, k = qq, kk

-        x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
+        x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
         x = self.out_proj(x)
         x = self.proj_drop(x)

diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py
index 1b2b11cd7c2..4bd7abab6e2 100644
--- a/comfy/ldm/hydit/models.py
+++ b/comfy/ldm/hydit/models.py
@@ -40,6 +40,7 @@ def __init__(self,
                  qk_norm=False,
                  norm_type="layer",
                  skip=False,
+                 attn_precision=None,
                  dtype=None,
                  device=None,
                  operations=None,
@@ -56,7 +57,7 @@ def __init__(self,

         # ========================= Self-Attention =========================
         self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
-        self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm, dtype=dtype, device=device, operations=operations)
+        self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)

         # ========================= FFN =========================
         self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
@@ -73,7 +74,7 @@ def __init__(self,

         # ========================= Cross-Attention =========================
         self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
-                                    qk_norm=qk_norm, dtype=dtype, device=device, operations=operations)
+                                    qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
         self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)

         # ========================= Skip Connection =========================
@@ -185,6 +186,7 @@ def __init__(self,
                  learn_sigma = True,
                  norm = "layer",
                  log_fn: callable = print,
+                 attn_precision=None,
                  dtype=None,
                  device=None,
                  operations=None,
@@ -246,7 +248,6 @@ def __init__(self,

         # Image embedding
         num_patches = self.x_embedder.num_patches
-        log_fn(f" Number of tokens: {num_patches}")

         # HUnYuanDiT Blocks
         self.blocks = nn.ModuleList([
@@ -258,6 +259,7 @@ def __init__(self,
                             qk_norm=qk_norm,
                             norm_type=self.norm,
                             skip=layer > depth // 2,
+                            attn_precision=attn_precision,
                             dtype=dtype,
                             device=device,
                             operations=operations,
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 677267b9a78..40417eb4daa 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -586,12 +586,15 @@ class HunyuanDiT(supported_models_base.BASE):
         "image_model": "hydit",
     }

+    unet_extra_config = {
+        "attn_precision": torch.float32,
+    }
+
     sampling_settings = {
         "linear_start": 0.00085,
         "linear_end": 0.018,
     }

-    unet_extra_config = {}
     latent_format = latent_formats.SDXL

     vae_key_prefix = ["vae."]
@@ -609,6 +612,8 @@ class HunyuanDiT1(HunyuanDiT):
         "image_model": "hydit1",
     }

+    unet_extra_config = {}
+
     sampling_settings = {
         "linear_start" : 0.00085,
         "linear_end" : 0.03,
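
Note on the mechanism (not part of the patch itself): black outputs from diffusion models are usually NaNs that appear when the attention scores overflow in fp16, and upcasting the Q.K^T/softmax step to fp32 avoids that overflow. The hunks above only thread an attn_precision override down to optimized_attention; the sketch below is an illustrative, assumed-plain-PyTorch attention path showing what such an override can do, and the helper name attention_with_precision is hypothetical, not ComfyUI's actual implementation.

    import torch

    def attention_with_precision(q, k, v, attn_precision=None):
        # Illustrative sketch: when attn_precision is torch.float32, run the
        # score/softmax step in fp32 even if q/k/v are fp16, then cast back.
        # q, k, v are assumed to be (batch, heads, seq_len, head_dim) tensors.
        if attn_precision == torch.float32:
            q, k = q.float(), k.float()
        scale = q.shape[-1] ** -0.5
        scores = (q @ k.transpose(-2, -1)) * scale     # fp32 scores, no fp16 overflow
        probs = torch.softmax(scores, dim=-1).to(v.dtype)
        return probs @ v

With the unet_extra_config change above, the HunyuanDiT model entry builds every block with attn_precision=torch.float32, while the HunyuanDiT1 subclass restores the empty default and keeps its previous behavior.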