diff --git a/llmfoundry/models/mosaicbert/modeling_mosaicbert.py b/llmfoundry/models/mosaicbert/modeling_mosaicbert.py index 3e4bf46832..e080d0fd96 100644 --- a/llmfoundry/models/mosaicbert/modeling_mosaicbert.py +++ b/llmfoundry/models/mosaicbert/modeling_mosaicbert.py @@ -92,7 +92,7 @@ class BertModel(BertPreTrainedModel): ``` """ - def __init__(self, config, add_pooling_layer=True): + def __init__(self, config: BertConfig, add_pooling_layer=True): super(BertModel, self).__init__(config) self.embeddings = BertEmbeddings(config) self.encoder = BertEncoder(config)