diff --git a/llmfoundry/models/layers/mosaicbert_layers.py b/llmfoundry/models/layers/mosaicbert_layers.py
index add9b595a7..524c068963 100644
--- a/llmfoundry/models/layers/mosaicbert_layers.py
+++ b/llmfoundry/models/layers/mosaicbert_layers.py
@@ -49,6 +49,7 @@
 from llmfoundry.models.utils.bert_padding import (index_first_axis,
                                                   pad_input, unpad_input,
                                                   unpad_input_only)
+from llmfoundry.models.mosaicbert.configuration_mosaicbert import BertConfig
 
 try:
     from llmfoundry.models.layers.flash_attn_triton import \
@@ -75,7 +76,7 @@ class BertEmbeddings(nn.Module):
     This module ignores the `position_ids` input to the `forward` method.
     """
 
-    def __init__(self, config):
+    def __init__(self, config: BertConfig):
         super().__init__()
         self.word_embeddings = nn.Embedding(config.vocab_size,
                                             config.hidden_size,
@@ -155,13 +156,13 @@ class BertUnpadSelfAttention(nn.Module):
     See `forward` method for additional detail.
     """
 
-    def __init__(self, config):
+    def __init__(self, config: BertConfig):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
                 config, 'embedding_size'):
             raise ValueError(
                 f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention '
-                f'heads ({config.num_attention_heads})')
+                + f'heads ({config.num_attention_heads})')
 
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size /
@@ -252,7 +253,7 @@ class BertSelfOutput(nn.Module):
     BERT modules.
     """
 
-    def __init__(self, config):
+    def __init__(self, config: BertConfig):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.LayerNorm = nn.LayerNorm(config.hidden_size,
@@ -270,7 +271,7 @@ def forward(self, hidden_states: torch.Tensor,
 class BertUnpadAttention(nn.Module):
     """Chains attention, Dropout, and LayerNorm for MosaicBERT."""
 
-    def __init__(self, config):
+    def __init__(self, config: BertConfig):
         super().__init__()
         self.self = BertUnpadSelfAttention(config)
         self.output = BertSelfOutput(config)
@@ -321,7 +322,7 @@ class BertGatedLinearUnitMLP(nn.Module):
     parameter size, MosaicBERT typically offers a net higher throughput than a
     Hugging Face BERT built from the same `config`.
     """
-    def __init__(self, config):
+    def __init__(self, config: BertConfig):
         super().__init__()
         self.config = config
         self.gated_layers = nn.Linear(config.hidden_size,
@@ -357,7 +358,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 class BertLayer(nn.Module):
     """Composes the MosaicBERT attention and FFN blocks into a single layer."""
 
-    def __init__(self, config):
+    def __init__(self, config: BertConfig):
         super(BertLayer, self).__init__()
         self.attention = BertUnpadAttention(config)
         self.mlp = BertGatedLinearUnitMLP(config)
@@ -400,7 +401,7 @@ class BertEncoder(nn.Module):
     at padded tokens, and pre-computes attention biases to implement ALiBi.
     """
 
-    def __init__(self, config):
+    def __init__(self, config: BertConfig):
         super().__init__()
         layer = BertLayer(config)
         self.layer = nn.ModuleList(
@@ -547,7 +548,7 @@ def forward(
 
 class BertPooler(nn.Module):
 
-    def __init__(self, config):
+    def __init__(self, config: BertConfig):
         super(BertPooler, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.activation = nn.Tanh()
@@ -565,7 +566,7 @@ def forward(self,
 
 class BertPredictionHeadTransform(nn.Module):
 
-    def __init__(self, config):
+    def __init__(self, config: BertConfig):
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         if isinstance(config.hidden_act, str):
@@ -586,7 +587,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 ###################
 class BertLMPredictionHead(nn.Module):
 
-    def __init__(self, config, bert_model_embedding_weights):
+    def __init__(self, config: BertConfig, bert_model_embedding_weights):
         super().__init__()
         self.transform = BertPredictionHeadTransform(config)
         # The output weights are the same as the input embeddings, but there is