Long CLIP L support for SDXL, SD3 and Flux.
Use the *CLIPLoader nodes.
comfyanonymous committed Sep 15, 2024
1 parent 5e68a4c · commit e813abb
Showing 6 changed files with 36 additions and 17 deletions.
comfy/sd.py: 12 changes (5 additions, 7 deletions)
@@ -445,12 +445,8 @@ class EmptyClass:
             clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
         else:
             w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
-            if w is not None and w.shape[0] == 248:
-                clip_target.clip = comfy.text_encoders.long_clipl.LongClipModel
-                clip_target.tokenizer = comfy.text_encoders.long_clipl.LongClipTokenizer
-            else:
-                clip_target.clip = sd1_clip.SD1ClipModel
-                clip_target.tokenizer = sd1_clip.SD1Tokenizer
+            clip_target.clip = sd1_clip.SD1ClipModel
+            clip_target.tokenizer = sd1_clip.SD1Tokenizer
     elif len(clip_data) == 2:
         if clip_type == CLIPType.SD3:
             clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
@@ -475,10 +471,12 @@ class EmptyClass:
         clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer

     parameters = 0
+    tokenizer_data = {}
     for c in clip_data:
         parameters += comfy.utils.calculate_parameters(c)
+        tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)

-    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, model_options=model_options)
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
     for c in clip_data:
         m, u = clip.load_sd(c)
         if len(m) > 0:
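
For orientation, a minimal sketch of the option-threading pattern this hunk sets up, assuming ComfyUI's comfy package is importable; the wrapper name collect_text_encoder_options is invented for illustration:

# Sketch of the per-checkpoint option threading shown above (wrapper name invented).
import comfy.text_encoders.long_clipl

def collect_text_encoder_options(clip_data, model_options):
    tokenizer_data = {}
    for sd in clip_data:
        # Each state dict may swap the Long-CLIP classes into the option dicts;
        # checkpoints without a 248-row position embedding pass them through unchanged.
        tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(
            sd, tokenizer_data, model_options)
    return tokenizer_data, model_options
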
comfy/sd1_clip.py: 2 changes (2 additions, 0 deletions)
@@ -542,6 +542,7 @@ class SD1Tokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
         self.clip_name = clip_name
         self.clip = "clip_{}".format(self.clip_name)
+        tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
         setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))

     def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -570,6 +571,7 @@ def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", cl
         self.clip_name = clip_name
         self.clip = "clip_{}".format(self.clip_name)

+        clip_model = model_options.get("{}_class".format(self.clip), clip_model)
         setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))

         self.dtypes = set()
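
The two lines added in this file are the generic hook the rest of the commit builds on: before constructing the per-encoder tokenizer and text model, SD1Tokenizer and SD1ClipModel look for a replacement class registered under "clip_<name>_tokenizer_class" / "clip_<name>_class" in tokenizer_data / model_options and fall back to the default otherwise. A standalone sketch of that lookup (the classes below are placeholders, not ComfyUI code):

# Standalone sketch of the class-override lookup added above (placeholder classes).
class DefaultTokenizer:       # stand-in for sd1_clip.SDTokenizer
    pass

class LongTokenizer:          # stand-in for a Long-CLIP tokenizer
    pass

def pick_tokenizer_class(clip_name, tokenizer_data):
    clip = "clip_{}".format(clip_name)              # e.g. "clip_l"
    key = "{}_tokenizer_class".format(clip)         # e.g. "clip_l_tokenizer_class"
    return tokenizer_data.get(key, DefaultTokenizer)

print(pick_tokenizer_class("l", {}))                                           # DefaultTokenizer
print(pick_tokenizer_class("l", {"clip_l_tokenizer_class": LongTokenizer}))    # LongTokenizer
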
comfy/sdxl_clip.py: 9 changes (6 additions, 3 deletions)
@@ -22,7 +22,8 @@ def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data

 class SDXLTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
+        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
+        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
         self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)

     def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -40,7 +41,8 @@ def state_dict(self):
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None, model_options={}):
         super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
+        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
+        self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
         self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
         self.dtypes = set([dtype])

@@ -57,7 +59,8 @@ def encode_token_weights(self, token_weight_pairs):
         token_weight_pairs_l = token_weight_pairs["l"]
         g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
         l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
-        return torch.cat([l_out, g_out], dim=-1), g_pooled
+        cut_to = min(l_out.shape[1], g_out.shape[1])
+        return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled

     def load_sd(self, sd):
         if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
comfy/text_encoders/flux.py: 6 changes (4 additions, 2 deletions)
@@ -18,7 +18,8 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):

 class FluxTokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
+        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
+        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
         self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)

     def tokenize_with_weights(self, text:str, return_word_ids=False):
@@ -38,7 +39,8 @@ class FluxClipModel(torch.nn.Module):
     def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
         super().__init__()
         dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
-        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
+        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
+        self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
         self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
         self.dtypes = set([dtype, dtype_t5])

comfy/text_encoders/long_clipl.py: 15 changes (13 additions, 2 deletions)
@@ -6,9 +6,9 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
         super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

 class LongClipModel_(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}):
+    def __init__(self, *args, **kwargs):
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
-        super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
+        super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)

 class LongClipTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -17,3 +17,14 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
 class LongClipModel(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
         super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
+
+def model_options_long_clip(sd, tokenizer_data, model_options):
+    w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
+    if w is None:
+        w = sd.get("text_model.embeddings.position_embedding.weight", None)
+    if w is not None and w.shape[0] == 248:
+        tokenizer_data = tokenizer_data.copy()
+        model_options = model_options.copy()
+        tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
+        model_options["clip_l_class"] = LongClipModel_
+    return tokenizer_data, model_options
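
A hedged usage sketch of the new helper, assuming ComfyUI's comfy package is importable: given a state dict whose CLIP-L position embedding has 248 rows, it returns copies of both option dicts with the Long-CLIP classes registered, and leaves any other state dict's options untouched. The state dict below is fabricated for illustration; the key names come from the function above.

# Illustrative call into the helper defined above (fabricated state dict, real key names).
import torch
from comfy.text_encoders.long_clipl import model_options_long_clip

sd = {"text_model.embeddings.position_embedding.weight": torch.zeros(248, 768)}
tokenizer_data, model_options = model_options_long_clip(sd, {}, {})
print("clip_l_tokenizer_class" in tokenizer_data)   # True -> Long-CLIP tokenizer will be used
print("clip_l_class" in model_options)              # True -> Long-CLIP text model will be used
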
comfy/text_encoders/sd3_clip.py: 9 changes (6 additions, 3 deletions)
@@ -20,7 +20,8 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):

 class SD3Tokenizer:
     def __init__(self, embedding_directory=None, tokenizer_data={}):
-        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
+        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
+        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
         self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
         self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)

@@ -42,7 +43,8 @@ def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu
         super().__init__()
         self.dtypes = set()
         if clip_l:
-            self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
+            clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
+            self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
             self.dtypes.add(dtype)
         else:
             self.clip_l = None
@@ -95,7 +97,8 @@ def encode_token_weights(self, token_weight_pairs):
         if self.clip_g is not None:
             g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
             if lg_out is not None:
-                lg_out = torch.cat([lg_out, g_out], dim=-1)
+                cut_to = min(lg_out.shape[1], g_out.shape[1])
+                lg_out = torch.cat([lg_out[:,:cut_to], g_out[:,:cut_to]], dim=-1)
             else:
                 lg_out = torch.nn.functional.pad(g_out, (768, 0))
         else: