
Commit

pyright BertConfig typing
jacobfulano committed Aug 17, 2023
1 parent 1cd7c2e commit 20b3df1
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions llmfoundry/models/layers/mosaicbert_layers.py
@@ -49,6 +49,7 @@

from llmfoundry.models.utils.bert_padding import (index_first_axis, pad_input,
unpad_input, unpad_input_only)
+ from llmfoundry.models.mosaicbert.configuration_mosaicbert import BertConfig

try:
from llmfoundry.models.layers.flash_attn_triton import \
@@ -75,7 +76,7 @@ class BertEmbeddings(nn.Module):
This module ignores the `position_ids` input to the `forward` method.
"""

- def __init__(self, config):
+ def __init__(self, config: BertConfig):
super().__init__()
self.word_embeddings = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -155,13 +156,13 @@ class BertUnpadSelfAttention(nn.Module):
See `forward` method for additional detail.
"""

- def __init__(self, config):
+ def __init__(self, config: BertConfig):
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
config, 'embedding_size'):
raise ValueError(
f'The hidden size ({config.hidden_size}) is not a multiple of the number of attention '
- f'heads ({config.num_attention_heads})')
+ f'heads ({config.num_attention_heads})')

self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size /
@@ -252,7 +253,7 @@ class BertSelfOutput(nn.Module):
BERT modules.
"""

- def __init__(self, config):
+ def __init__(self, config: BertConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
@@ -270,7 +271,7 @@ def forward(self, hidden_states: torch.Tensor,
class BertUnpadAttention(nn.Module):
"""Chains attention, Dropout, and LayerNorm for MosaicBERT."""

- def __init__(self, config):
+ def __init__(self, config: BertConfig):
super().__init__()
self.self = BertUnpadSelfAttention(config)
self.output = BertSelfOutput(config)
@@ -321,7 +322,7 @@ class BertGatedLinearUnitMLP(nn.Module):
parameter size, MosaicBERT typically offers a net higher throughput than a Hugging Face BERT built from the same `config`.
"""

- def __init__(self, config):
+ def __init__(self, config: BertConfig):
super().__init__()
self.config = config
self.gated_layers = nn.Linear(config.hidden_size,
@@ -357,7 +358,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class BertLayer(nn.Module):
"""Composes the MosaicBERT attention and FFN blocks into a single layer."""

- def __init__(self, config):
+ def __init__(self, config: BertConfig):
super(BertLayer, self).__init__()
self.attention = BertUnpadAttention(config)
self.mlp = BertGatedLinearUnitMLP(config)
@@ -400,7 +401,7 @@ class BertEncoder(nn.Module):
at padded tokens, and pre-computes attention biases to implement ALiBi.
"""

- def __init__(self, config):
+ def __init__(self, config: BertConfig):
super().__init__()
layer = BertLayer(config)
self.layer = nn.ModuleList(
@@ -547,7 +548,7 @@ def forward(

class BertPooler(nn.Module):

- def __init__(self, config):
+ def __init__(self, config: BertConfig):
super(BertPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
@@ -565,7 +566,7 @@ def forward(self,

class BertPredictionHeadTransform(nn.Module):

- def __init__(self, config):
+ def __init__(self, config: BertConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str):
@@ -586,7 +587,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
###################
class BertLMPredictionHead(nn.Module):

- def __init__(self, config, bert_model_embedding_weights):
+ def __init__(self, config: BertConfig, bert_model_embedding_weights):
super().__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is

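As a quick illustration of what the annotation buys: a minimal sketch, not taken from this commit, assuming the MosaicBERT BertConfig keeps the standard Hugging Face BertConfig keyword arguments and defaults.

from llmfoundry.models.mosaicbert.configuration_mosaicbert import BertConfig
from llmfoundry.models.layers.mosaicbert_layers import BertEmbeddings

# Assumed: BertConfig forwards the usual Hugging Face kwargs (vocab_size,
# hidden_size, num_attention_heads, ...) to its parent class.
config = BertConfig(vocab_size=30522, hidden_size=768, num_attention_heads=12)

# With `config: BertConfig` on __init__, pyright can resolve attribute access
# such as config.hidden_size inside the module instead of treating config as Unknown.
embeddings = BertEmbeddings(config)

# Passing something that is not a BertConfig, e.g. BertEmbeddings({'hidden_size': 768}),
# would now be reported by pyright as an argument type error rather than passing silently.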