Skip to content

Commit

Permalink
SD1 and SD2 clip and tokenizer code is now more similar to the SDXL one.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Oct 27, 2023
1 parent 6ec3f12 commit e60ca69
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 30 deletions.
6 changes: 4 additions & 2 deletions comfy/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ def model_lora_keys_clip(model, key_map={}):

text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False
for b in range(32):
for b in range(32): #TODO: clean up
for c in LORA_CLIP_MAP:
k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
Expand All @@ -154,6 +154,8 @@ def model_lora_keys_clip(model, key_map={}):

k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
clip_l_present = True
Expand Down
41 changes: 39 additions & 2 deletions comfy/sd1_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def encode_token_weights(self, token_weight_pairs):
return z_empty.cpu(), first_pooled.cpu()
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()

class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last",
Expand Down Expand Up @@ -342,7 +342,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
embed_out = next(iter(values))
return embed_out

class SD1Tokenizer:
class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
Expand Down Expand Up @@ -454,3 +454,40 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):

def untokenize(self, token_weight_pair):
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))


class SD1Tokenizer:
    """Holds one SDTokenizer under a dynamic ``clip_<name>`` attribute.

    Mirrors the multi-encoder SDXL tokenizer layout so SD1/SD2 models expose
    the same dict-keyed interface with a single encoder.
    """

    def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
        self.clip_name = clip_name
        # Attribute name matches the state-dict prefix, e.g. "clip_l" / "clip_h".
        self.clip = "clip_{}".format(clip_name)
        setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))

    def tokenize_with_weights(self, text:str, return_word_ids=False):
        # Result is keyed by clip_name so callers handle SD1 and SDXL uniformly.
        inner = getattr(self, self.clip)
        return {self.clip_name: inner.tokenize_with_weights(text, return_word_ids)}

    def untokenize(self, token_weight_pair):
        inner = getattr(self, self.clip)
        return inner.untokenize(token_weight_pair)


class SD1ClipModel(torch.nn.Module):
    """Single-encoder CLIP wrapper exposing the SDXL-style multi-clip interface.

    The wrapped encoder lives under a dynamic ``clip_<name>`` attribute and all
    calls are forwarded to it.
    """

    def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel):
        super().__init__()
        self.clip_name = clip_name
        # Attribute name matches the state-dict prefix, e.g. "clip_l" / "clip_h".
        self.clip = "clip_{}".format(clip_name)
        setattr(self, self.clip, clip_model(device=device, dtype=dtype))

    def _encoder(self):
        # Look up the wrapped encoder module through its dynamic attribute name.
        return getattr(self, self.clip)

    def clip_layer(self, layer_idx):
        self._encoder().clip_layer(layer_idx)

    def reset_clip_layer(self):
        self._encoder().reset_clip_layer()

    def encode_token_weights(self, token_weight_pairs):
        # Token pairs arrive keyed by clip_name (same dict shape SDXL produces).
        pairs = token_weight_pairs[self.clip_name]
        out, pooled = self._encoder().encode_token_weights(pairs)
        return out, pooled

    def load_sd(self, sd):
        return self._encoder().load_sd(sd)
12 changes: 10 additions & 2 deletions comfy/sd2_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import os

class SD2ClipModel(sd1_clip.SD1ClipModel):
class SD2ClipHModel(sd1_clip.SDClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
if layer == "penultimate":
layer="hidden"
Expand All @@ -12,6 +12,14 @@ def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, la
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
self.empty_tokens = [[49406] + [49407] + [0] * 75]

class SD2Tokenizer(sd1_clip.SD1Tokenizer):
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
    """Tokenizer for the SD2.x ViT-H text encoder: 1024-dim embeddings, no end-token padding."""

    def __init__(self, tokenizer_path=None, embedding_directory=None):
        super().__init__(tokenizer_path=tokenizer_path,
                         pad_with_end=False,
                         embedding_directory=embedding_directory,
                         embedding_size=1024)

class SD2Tokenizer(sd1_clip.SD1Tokenizer):
    """SD2 tokenizer: the generic single-clip wrapper configured for clip_h."""

    def __init__(self, embedding_directory=None):
        super().__init__(clip_name="h",
                         tokenizer=SD2ClipHTokenizer,
                         embedding_directory=embedding_directory)

class SD2ClipModel(sd1_clip.SD1ClipModel):
    """SD2 text encoder: the generic single-clip wrapper configured for clip_h."""

    def __init__(self, device="cpu", dtype=None):
        super().__init__(clip_name="h",
                         clip_model=SD2ClipHModel,
                         device=device,
                         dtype=dtype)
29 changes: 7 additions & 22 deletions comfy/sdxl_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import os

class SDXLClipG(sd1_clip.SD1ClipModel):
class SDXLClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
if layer == "penultimate":
layer="hidden"
Expand All @@ -16,14 +16,14 @@ def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate"
def load_sd(self, sd):
return super().load_sd(sd)

class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
    """Tokenizer for the SDXL clip_g encoder: 1280-dim embeddings, no end-token padding."""

    def __init__(self, tokenizer_path=None, embedding_directory=None):
        super().__init__(tokenizer_path=tokenizer_path,
                         pad_with_end=False,
                         embedding_directory=embedding_directory,
                         embedding_size=1280,
                         embedding_key='clip_g')


class SDXLTokenizer(sd1_clip.SD1Tokenizer):
class SDXLTokenizer:
def __init__(self, embedding_directory=None):
self.clip_l = sd1_clip.SD1Tokenizer(embedding_directory=embedding_directory)
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)

def tokenize_with_weights(self, text:str, return_word_ids=False):
Expand All @@ -38,7 +38,7 @@ def untokenize(self, token_weight_pair):
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
self.clip_l.layer_norm_hidden_state = False
self.clip_g = SDXLClipG(device=device, dtype=dtype)

Expand All @@ -63,21 +63,6 @@ def load_sd(self, sd):
else:
return self.clip_l.load_sd(sd)

class SDXLRefinerClipModel(torch.nn.Module):
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.clip_g = SDXLClipG(device=device, dtype=dtype)

def clip_layer(self, layer_idx):
self.clip_g.clip_layer(layer_idx)

def reset_clip_layer(self):
self.clip_g.reset_clip_layer()

def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_g = token_weight_pairs["g"]
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
return g_out, g_pooled

def load_sd(self, sd):
return self.clip_g.load_sd(sd)
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
11 changes: 9 additions & 2 deletions comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,15 @@ def process_clip_state_dict(self, state_dict):
if ids.dtype == torch.float32:
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()

replace_prefix = {}
replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
return state_dict

def process_clip_state_dict_for_saving(self, state_dict):
    """Map runtime ``clip_l.*`` keys back to the ``cond_stage_model.*`` checkpoint prefix."""
    prefix_map = {"clip_l.": "cond_stage_model."}
    return utils.state_dict_prefix_replace(state_dict, prefix_map)

def clip_target(self):
    # Pair the SD1 tokenizer with its matching single-encoder CLIP model wrapper.
    tokenizer_cls = sd1_clip.SD1Tokenizer
    model_cls = sd1_clip.SD1ClipModel
    return supported_models_base.ClipTarget(tokenizer_cls, model_cls)

Expand All @@ -62,12 +69,12 @@ def model_type(self, state_dict, prefix=""):
return model_base.ModelType.EPS

def process_clip_state_dict(self, state_dict):
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
return state_dict

def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
replace_prefix[""] = "cond_stage_model.model."
replace_prefix["clip_h"] = "cond_stage_model.model"
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
return state_dict
Expand Down

0 comments on commit e60ca69

Please sign in to comment.