Skip to content

Commit

Permalink
Own BertModel implementation that works with lowvram.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jul 26, 2024
1 parent 25b51b1 commit a9ac56f
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 44 deletions.
139 changes: 139 additions & 0 deletions comfy/text_encoders/bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device

class BertAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__()

self.heads = heads
self.query = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.key = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.value = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)


def forward(self, x, mask=None, optimized_attention=None):
q = self.query(x)
k = self.key(x)
v = self.value(x)

out = optimized_attention(q, k, v, self.heads, mask)
return out

class BertOutput(torch.nn.Module):
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
self.LayerNorm = operations.LayerNorm(output_dim, eps=layer_norm_eps, dtype=dtype, device=device)
# self.dropout = nn.Dropout(0.0)

def forward(self, x, y):
x = self.dense(x)
# hidden_states = self.dropout(hidden_states)
x = self.LayerNorm(x + y)
return x

class BertAttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.self = BertAttention(embed_dim, heads, dtype, device, operations)
self.output = BertOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)

def forward(self, x, mask, optimized_attention):
y = self.self(x, mask, optimized_attention)
return self.output(y, x)

class BertIntermediate(torch.nn.Module):
def __init__(self, embed_dim, intermediate_dim, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(embed_dim, intermediate_dim, dtype=dtype, device=device)

def forward(self, x):
x = self.dense(x)
return torch.nn.functional.gelu(x)


class BertBlock(torch.nn.Module):
def __init__(self, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = BertAttentionBlock(embed_dim, heads, layer_norm_eps, dtype, device, operations)
self.intermediate = BertIntermediate(embed_dim, intermediate_dim, dtype, device, operations)
self.output = BertOutput(intermediate_dim, embed_dim, layer_norm_eps, dtype, device, operations)

def forward(self, x, mask, optimized_attention):
x = self.attention(x, mask, optimized_attention)
y = self.intermediate(x)
return self.output(y, x)

class BertEncoder(torch.nn.Module):
def __init__(self, num_layers, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.layer = torch.nn.ModuleList([BertBlock(embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations) for i in range(num_layers)])

def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layer) + intermediate_output

intermediate = None
for i, l in enumerate(self.layer):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate

class BertEmbeddings(torch.nn.Module):
def __init__(self, vocab_size, max_position_embeddings, type_vocab_size, pad_token_id, embed_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.word_embeddings = torch.nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id, dtype=dtype, device=device)
self.position_embeddings = torch.nn.Embedding(max_position_embeddings, embed_dim, dtype=dtype, device=device)
self.token_type_embeddings = torch.nn.Embedding(type_vocab_size, embed_dim, dtype=dtype, device=device)

self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)

def forward(self, input_tokens, token_type_ids=None):
x = self.word_embeddings(input_tokens)
x += self.position_embeddings.weight[:x.shape[1]]
if token_type_ids is not None:
x += self.token_type_embeddings(token_type_ids)
else:
x += self.token_type_embeddings.weight[0]
x = self.LayerNorm(x)
return x


class BertModel_(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
embed_dim = config_dict["hidden_size"]
layer_norm_eps = config_dict["layer_norm_eps"]

self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)

def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
x = self.embeddings(input_tokens)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

x, i = self.encoder(x, mask, intermediate_output)
return x, i


class BertModel(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.bert = BertModel_(config_dict, dtype, device, operations)
self.num_layers = config_dict["num_hidden_layers"]

def get_input_embeddings(self):
return self.bert.embeddings.word_embeddings

def set_input_embeddings(self, embeddings):
self.bert.embeddings.word_embeddings = embeddings

def forward(self, *args, **kwargs):
return self.bert(*args, **kwargs)
47 changes: 3 additions & 44 deletions comfy/text_encoders/hydit.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,15 @@
from comfy import sd1_clip
from transformers import T5TokenizerFast, BertTokenizer, BertModel, modeling_utils, BertConfig
from transformers import BertTokenizer
from .spiece_tokenizer import SPieceTokenizer
from .bert import BertModel
import comfy.text_encoders.t5
import os

import torch
import contextlib

@contextlib.contextmanager
def use_comfy_ops(ops, device=None, dtype=None):
old_torch_nn_linear = torch.nn.Linear
force_device = device
force_dtype = dtype
def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
if force_device is not None:
device = force_device
if force_dtype is not None:
dtype = force_dtype
return ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)

torch.nn.Linear = linear_with_dtype
try:
yield
finally:
torch.nn.Linear = old_torch_nn_linear


class RobertaWrapper(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = BertConfig(**config_dict)
with use_comfy_ops(operations, device, dtype):
with modeling_utils.no_init_weights():
self.bert = BertModel(config, add_pooling_layer=False)

self.num_layers = config.num_hidden_layers

def get_input_embeddings(self):
return self.bert.get_input_embeddings()

def set_input_embeddings(self, value):
return self.bert.set_input_embeddings(value)

def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
intermediate = None
out = self.bert(input_ids=input_tokens, output_hidden_states=intermediate_output is not None, attention_mask=attention_mask)
return out.last_hidden_state, intermediate, out.pooler_output

class HyditBertModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=RobertaWrapper, enable_attention_masks=True, return_attention_masks=True)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)

class HyditBertTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
Expand Down

0 comments on commit a9ac56f

Please sign in to comment.