
Commit a5af64d
Revert "Not sure if this actually changes anything but it can't hurt."
This reverts commit 34608de.
comfyanonymous committed Aug 14, 2024
1 parent 3e52e03 commit a5af64d
Showing 3 changed files with 14 additions and 15 deletions.
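All three diffs below revert the same pattern: the in-place tensor ops introduced in 34608de (add_, addcmul_, +=, and the mod.scale += 1 trick) go back to their out-of-place equivalents. As a minimal sketch of the behavioral difference (illustrative, not part of this commit): an in-place op mutates the existing storage, so every other reference to that tensor sees the change, while an out-of-place op rebinds the name to a fresh tensor.

    import torch

    y = torch.ones(3)
    vec = torch.zeros(3)
    alias = vec               # second reference to the same storage

    vec.add_(y)               # in-place: mutates the shared storage
    print(alias)              # tensor([1., 1., 1.]) -- alias sees the update

    vec = torch.zeros(3)
    alias = vec
    vec = vec + y             # out-of-place: allocates a new tensor
    print(alias)              # tensor([0., 0., 0.]) -- alias untouched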
comfy/ldm/flux/controlnet_xlabs.py (2 additions, 2 deletions)
@@ -64,8 +64,8 @@ def forward_orig(
         img = img + controlnet_cond
         vec = self.time_in(timestep_embedding(timesteps, 256))
         if self.params.guidance_embed:
-            vec.add_(self.guidance_in(timestep_embedding(guidance, 256)))
-        vec.add_(self.vector_in(y))
+            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+        vec = vec + self.vector_in(y)
         txt = self.txt_in(txt)

         ids = torch.cat((txt_ids, img_ids), dim=1)
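Editorial note: in this hunk vec.add_(...) only mutated the fresh tensor returned by self.time_in(...), so the revert looks defensive rather than a fix for a known bug, matching the reverted commit's own message. One concrete way in-place updates can bite, sketched under the assumption that such a block ever runs with autograd enabled (e.g. fine-tuning): mutating a tensor that autograd saved for the backward pass trips PyTorch's version counter.

    import torch

    w = torch.randn(3, requires_grad=True)
    vec = w.exp()             # exp() saves its output for the backward pass
    out = vec.sum()

    vec.add_(1.0)             # in-place edit of a tensor autograd saved
    try:
        out.backward()        # raises: modified by an inplace operation
    except RuntimeError as e:
        print("autograd error:", e)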
comfy/ldm/flux/layers.py (10 additions, 11 deletions)
@@ -150,16 +150,14 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):

         # prepare image for attention
         img_modulated = self.img_norm1(img)
-        img_mod1.scale += 1
-        img_modulated = img_mod1.scale * img_modulated + img_mod1.shift
+        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
-        txt_mod1.scale += 1
-        txt_modulated = txt_mod1.scale * txt_modulated + txt_mod1.shift
+        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -172,12 +170,12 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

         # calculate the img bloks
-        img = img.addcmul(img_mod1.gate, self.img_attn.proj(img_attn))
-        img.addcmul_(img_mod2.gate, self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift))
+        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

         # calculate the txt bloks
-        txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
-        txt.addcmul_(txt_mod2.gate, self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift))
+        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)

         if txt.dtype == torch.float16:
             txt = txt.clip(-65504, 65504)
@@ -223,16 +221,17 @@ def __init__(

     def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
         mod, _ = self.modulation(vec)
-        mod.scale += 1
-        qkv, mlp = torch.split(self.linear1(mod.scale * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)

         # compute attention
         attn = attention(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
-        x.addcmul_(mod.gate, self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)))
+        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        x += mod.gate * output
         if x.dtype == torch.float16:
             x = x.clip(-65504, 65504)
         return x
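Editorial note: both rewrites in this file are numerically identical to what they replace; only the mutation goes away. mod.scale += 1 followed by mod.scale * x + mod.shift bakes the +1 into the modulation tensor itself, so any later reader of mod.scale would see the incremented value, whereas (1 + mod.scale) * x + mod.shift leaves it intact; likewise x.addcmul_(gate, y) is the fused in-place form of x = x + gate * y. A quick equivalence check (illustrative, not part of the commit):

    import torch

    x = torch.randn(2, 4)
    scale, shift, gate = torch.randn(3, 2, 4).unbind(0)
    y = torch.randn(2, 4)

    s = scale.clone()
    s += 1                                # reverted form mutates the scale tensor
    a = s * x + shift
    b = (1 + scale) * x + shift           # restored form leaves `scale` intact
    print(torch.allclose(a, b))           # True

    c = x.clone().addcmul_(gate, y)       # fused in-place: x += gate * y
    d = x + gate * y
    print(torch.allclose(c, d))           # True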
comfy/ldm/flux/model.py (2 additions, 2 deletions)
@@ -106,9 +106,9 @@ def forward_orig(
         if self.params.guidance_embed:
             if guidance is None:
                 raise ValueError("Didn't get guidance strength for guidance distilled model.")
-            vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)))
+            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

-        vec.add_(self.vector_in(y))
+        vec = vec + self.vector_in(y)
         txt = self.txt_in(txt)

         ids = torch.cat((txt_ids, img_ids), dim=1)
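Editorial note: the commit message's "not sure if this actually changes anything" cuts both ways; in-place ops skip one allocation per call, but PyTorch's caching allocator usually makes the out-of-place form comparably cheap. A crude CPU micro-benchmark sketch (timings illustrative only, tensor sizes made up, not part of the commit):

    import time
    import torch

    x = torch.randn(4096, 4096)
    y = torch.randn(4096, 4096)

    t0 = time.perf_counter()
    for _ in range(100):
        x.add_(y)             # in-place update
    t1 = time.perf_counter()
    for _ in range(100):
        x = x + y             # out-of-place update
    t2 = time.perf_counter()
    print(f"in-place: {t1 - t0:.3f}s   out-of-place: {t2 - t1:.3f}s")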
