Fix issue with gligen.
comfyanonymous committed Aug 18, 2023
commit b80c327 (1 parent: d6e4b34)
Showing 2 changed files with 14 additions and 12 deletions.
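
What the diff amounts to: every constructor in attention.py and openaimodel.py that previously defaulted operations=None now defaults to comfy.ops (newly imported at the top of attention.py). Callers that build these blocks without passing an operations namespace, which the GLIGEN path apparently does, therefore no longer reach operations.Linear(...) with operations still set to None. Below is a minimal, self-contained sketch of the pattern; the _DefaultOps class is an illustrative stand-in, not ComfyUI's comfy.ops, and the GEGLU forward body is the standard implementation added for completeness.

import torch.nn as nn

class _DefaultOps:
    # Illustrative stand-in for comfy.ops, which the diff uses to supply Linear layers.
    Linear = nn.Linear

class GEGLU(nn.Module):
    # Before the commit the default was operations=None, so constructing
    # GEGLU(320, 1280) without an explicit namespace raised
    # AttributeError: 'NoneType' object has no attribute 'Linear'.
    # A usable default namespace makes the no-argument case work.
    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=_DefaultOps):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * nn.functional.gelu(gate)

layer = GEGLU(320, 1280)  # no operations= passed; resolved by the new default

The same one-line default swap is repeated for every constructor shown in the hunks below.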
comfy/ldm/modules/attention.py (11 additions, 9 deletions)
@@ -16,6 +16,8 @@
 import xformers.ops
 
 from comfy.cli_args import args
+import comfy.ops
+
 # CrossAttn precision handling
 if args.dont_upcast_attention:
     print("disabling upcasting of attention")
@@ -51,7 +53,7 @@ def init_(tensor):
 
 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
+    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops):
         super().__init__()
         self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
 
@@ -61,7 +63,7 @@ def forward(self, x):
 
 
 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
@@ -147,7 +149,7 @@ def forward(self, x):
 
 
 class CrossAttentionBirchSan(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -244,7 +246,7 @@ def forward(self, x, context=None, value=None, mask=None):
 
 
 class CrossAttentionDoggettx(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -342,7 +344,7 @@ def forward(self, x, context=None, value=None, mask=None):
         return self.to_out(r2)
 
 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -398,7 +400,7 @@ def forward(self, x, context=None, value=None, mask=None):
 
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops):
         super().__init__()
         print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
               f"{heads} heads.")
@@ -449,7 +451,7 @@ def forward(self, x, context=None, value=None, mask=None):
         return self.to_out(out)
 
 class CrossAttentionPytorch(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=None):
+    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = default(context_dim, query_dim)
@@ -507,7 +509,7 @@ def forward(self, x, context=None, value=None, mask=None):
 
 class BasicTransformerBlock(nn.Module):
     def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
-                 disable_self_attn=False, dtype=None, device=None, operations=None):
+                 disable_self_attn=False, dtype=None, device=None, operations=comfy.ops):
         super().__init__()
         self.disable_self_attn = disable_self_attn
         self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
@@ -647,7 +649,7 @@ class SpatialTransformer(nn.Module):
     def __init__(self, in_channels, n_heads, d_head,
                  depth=1, dropout=0., context_dim=None,
                  disable_self_attn=False, use_linear=False,
-                 use_checkpoint=True, dtype=None, device=None, operations=None):
+                 use_checkpoint=True, dtype=None, device=None, operations=comfy.ops):
         super().__init__()
         if exists(context_dim) and not isinstance(context_dim, list):
             context_dim = [context_dim] * depth
comfy/ldm/modules/diffusionmodules/openaimodel.py (3 additions, 3 deletions)
@@ -70,7 +70,7 @@ class Upsample(nn.Module):
                  upsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -106,7 +106,7 @@ class Downsample(nn.Module):
                  downsampling occurs in the inner-two dimensions.
     """
 
-    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=None):
+    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
         super().__init__()
         self.channels = channels
         self.out_channels = out_channels or channels
@@ -156,7 +156,7 @@ def __init__(
         down=False,
         dtype=None,
         device=None,
-        operations=None
+        operations=comfy.ops
     ):
         super().__init__()
         self.channels = channels
