Fix black images on hunyuan dit.
comfyanonymous committed Jul 25, 2024
commit 17b1d68 · 1 parent 2f2f723
Showing 3 changed files with 17 additions and 7 deletions.
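
The change threads a new `attn_precision` argument through HunyuanDiT's attention layers and, for this model family, pins it to `torch.float32` via `unet_extra_config`. The likely failure mode behind the black images is fp16 overflow in the attention logits: fp16 saturates at 65504, oversized Q·K^T values become inf, and the softmax then produces NaNs that propagate into a black output. Below is a minimal, hypothetical sketch of what such an upcast option can look like; the function name, shapes, and upcast strategy are illustrative assumptions, not ComfyUI's actual `optimized_attention`.

```python
# Minimal sketch of the idea behind an attn_precision override, assuming plain
# scaled-dot-product attention. This is NOT ComfyUI's optimized_attention; the
# function name, shapes, and upcast strategy are illustrative assumptions.
import torch

def attention(q, k, v, attn_precision=None):
    # q, k, v: (batch, heads, tokens, dim_head)
    scale = q.shape[-1] ** -0.5
    if attn_precision == torch.float32:
        # Upcast the logit math so large Q.K^T values cannot overflow fp16;
        # inf logits would turn the softmax output into NaNs (black images).
        q, k = q.float(), k.float()
    sim = (q @ k.transpose(-2, -1)) * scale
    weights = sim.softmax(dim=-1).to(v.dtype)
    return weights @ v

# fp16 saturates at 65504, so oversized attention logits overflow to inf:
print(torch.tensor(70000.0).half())  # tensor(inf, dtype=torch.float16)

q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))
print(attention(q, k, v, attn_precision=torch.float32).shape)  # torch.Size([1, 8, 16, 64])
```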
comfy/ldm/hydit/attn_layers.py (6 additions, 3 deletions)

```diff
@@ -106,12 +106,14 @@ def __init__(self,
                  qk_norm=False,
                  attn_drop=0.0,
                  proj_drop=0.0,
+                 attn_precision=None,
                  device=None,
                  dtype=None,
                  operations=None,
                  ):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
+        self.attn_precision = attn_precision
         self.qdim = qdim
         self.kdim = kdim
         self.num_heads = num_heads
@@ -160,7 +162,7 @@ def forward(self, x, y, freqs_cis_img=None):
         k = k.transpose(-2, -3).contiguous() # k -> B, L2, H, C - B, H, C, L2
         v = v.transpose(-2, -3).contiguous()
 
-        context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
+        context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
 
         out = self.out_proj(context) # context.reshape - B, L1, -1
         out = self.proj_drop(out)
@@ -174,8 +176,9 @@ class Attention(nn.Module):
     """
     We rename some layer names to align with flash attention
     """
-    def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., dtype=None, device=None, operations=None):
+    def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., attn_precision=None, dtype=None, device=None, operations=None):
         super().__init__()
+        self.attn_precision = attn_precision
         self.dim = dim
         self.num_heads = num_heads
         assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
@@ -207,7 +210,7 @@ def forward(self, x, freqs_cis_img=None):
            f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
         q, k = qq, kk
 
-        x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
+        x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
         x = self.out_proj(x)
         x = self.proj_drop(x)
 
```
comfy/ldm/hydit/models.py (5 additions, 3 deletions)

```diff
@@ -40,6 +40,7 @@ def __init__(self,
                  qk_norm=False,
                  norm_type="layer",
                  skip=False,
+                 attn_precision=None,
                  dtype=None,
                  device=None,
                  operations=None,
@@ -56,7 +57,7 @@ def __init__(self,
 
         # ========================= Self-Attention =========================
         self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
-        self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm, dtype=dtype, device=device, operations=operations)
+        self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
 
         # ========================= FFN =========================
         self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
@@ -73,7 +74,7 @@ def __init__(self,
 
         # ========================= Cross-Attention =========================
         self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
-                                    qk_norm=qk_norm, dtype=dtype, device=device, operations=operations)
+                                    qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
         self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
 
         # ========================= Skip Connection =========================
@@ -185,6 +186,7 @@ def __init__(self,
                  learn_sigma = True,
                  norm = "layer",
                  log_fn: callable = print,
+                 attn_precision=None,
                  dtype=None,
                  device=None,
                  operations=None,
@@ -246,7 +248,6 @@ def __init__(self,
 
         # Image embedding
         num_patches = self.x_embedder.num_patches
-        log_fn(f" Number of tokens: {num_patches}")
 
         # HUnYuanDiT Blocks
         self.blocks = nn.ModuleList([
@@ -258,6 +259,7 @@
                             qk_norm=qk_norm,
                             norm_type=self.norm,
                             skip=layer > depth // 2,
+                            attn_precision=attn_precision,
                             dtype=dtype,
                             device=device,
                             operations=operations,
```
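
For reference, the plumbing above is plain constructor threading: the top-level `HunYuanDiT` model accepts `attn_precision`, hands it to every `HunYuanDiTBlock`, and each block passes it on to its `Attention`/`CrossAttention` modules, so one per-model setting reaches every `optimized_attention` call. The toy sketch below shows the same pattern with hypothetical `Model`/`Block` classes, not the real modules.

```python
# Toy illustration of threading a per-model option through nested modules.
# Model/Block are hypothetical stand-ins, not the HunyuanDiT classes.
import torch
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim, attn_precision=None):
        super().__init__()
        # The block only stores the setting; a real block would forward it
        # to its attention layers (as the diff above does).
        self.attn_precision = attn_precision
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)

class Model(nn.Module):
    def __init__(self, dim, depth, attn_precision=None):
        super().__init__()
        # One top-level kwarg fans out to every block.
        self.blocks = nn.ModuleList([Block(dim, attn_precision=attn_precision) for _ in range(depth)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

m = Model(dim=64, depth=2, attn_precision=torch.float32)
print(m(torch.randn(1, 4, 64)).shape)  # torch.Size([1, 4, 64])
```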
comfy/supported_models.py (6 additions, 1 deletion)

```diff
@@ -586,12 +586,15 @@ class HunyuanDiT(supported_models_base.BASE):
         "image_model": "hydit",
     }
 
+    unet_extra_config = {
+        "attn_precision": torch.float32,
+    }
+
     sampling_settings = {
         "linear_start": 0.00085,
         "linear_end": 0.018,
     }
 
-    unet_extra_config = {}
     latent_format = latent_formats.SDXL
 
     vae_key_prefix = ["vae."]
@@ -609,6 +612,8 @@ class HunyuanDiT1(HunyuanDiT):
         "image_model": "hydit1",
     }
 
+    unet_extra_config = {}
+
     sampling_settings = {
         "linear_start" : 0.00085,
         "linear_end" : 0.03,
```
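
Note that `HunyuanDiT1` subclasses `HunyuanDiT`, so without the redeclared `unet_extra_config = {}` it would inherit the new fp32 override as a class attribute; the redeclaration presumably keeps the v1.1 config on its previous behavior. A small, self-contained illustration of that class-attribute inheritance, using stand-in classes rather than the real `supported_models` entries:

```python
# Stand-in classes to show why the subclass redeclares unet_extra_config.
import torch

class HunyuanDiT:  # stand-in, not the real supported_models_base.BASE subclass
    unet_extra_config = {"attn_precision": torch.float32}

class HunyuanDiT1(HunyuanDiT):
    unet_extra_config = {}  # redeclared, as in the diff above

class InheritsOverride(HunyuanDiT):
    pass  # no redeclaration: the fp32 override leaks in from the parent

print(HunyuanDiT1.unet_extra_config)       # {}
print(InheritsOverride.unet_extra_config)  # {'attn_precision': torch.float32}
```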
