From 181a16ddf7bcdc6c5340c5336d477a89f77be60a Mon Sep 17 00:00:00 2001 From: Anentropic Date: Sun, 23 Apr 2023 19:28:53 +0100 Subject: [PATCH] fix linear_to_conv2d_map to work with other distilbert model types --- ane_transformers/huggingface/distilbert.py | 26 ++++++++- .../huggingface/test_distilbert.py | 56 ++++++++++++++++++- 2 files changed, 78 insertions(+), 4 deletions(-) diff --git a/ane_transformers/huggingface/distilbert.py b/ane_transformers/huggingface/distilbert.py index 4845c22..0f75780 100644 --- a/ane_transformers/huggingface/distilbert.py +++ b/ane_transformers/huggingface/distilbert.py @@ -2,6 +2,7 @@ # For licensing see accompanying LICENSE.md file. # Copyright (C) 2022 Apple Inc. All Rights Reserved. # +import re from ane_transformers.reference.layer_norm import LayerNormANE @@ -520,14 +521,33 @@ def forward( return ((loss, ) + output) if loss is not None else output +_INTERNAL_PROJ_RE = re.compile(r".*({})\.weight".format( + "|".join([ + "q_lin", "k_lin", "v_lin", "out_lin", "lin1", "lin2", + ])) +) +_OUTPUT_PROJ_RE = re.compile( + r".*({})\.weight".format( + "|".join( + [ + "classifier", + "pre_classifier", + "vocab_transform", + "vocab_projector", + "qa_outputs", + ] + ) + ) +) + + def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): """ Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights """ for k in state_dict: - is_internal_proj = all(substr in k for substr in ['lin', '.weight']) - is_output_proj = all(substr in k - for substr in ['classifier', '.weight']) + is_internal_proj = _INTERNAL_PROJ_RE.match(k) + is_output_proj = _OUTPUT_PROJ_RE.match(k) if is_internal_proj or is_output_proj: if len(state_dict[k].shape) == 2: state_dict[k] = state_dict[k][:, :, None, None] diff --git a/ane_transformers/huggingface/test_distilbert.py b/ane_transformers/huggingface/test_distilbert.py index b14f14d..7e43a07 100644 --- a/ane_transformers/huggingface/test_distilbert.py +++ b/ane_transformers/huggingface/test_distilbert.py @@ -9,7 +9,6 @@ import logging import numpy as np import unittest -import time import torch @@ -32,6 +31,10 @@ ("This is not what I expected!", "NEGATIVE"), ]) +MASKED_LM_MODEL = 'distilbert-base-uncased' +QUESTION_ANSWERING_MODEL = 'distilbert-base-uncased-distilled-squad' +TOKEN_CLASSIFICATION_MODEL = 'elastic/distilbert-base-uncased-finetuned-conll03-english' +MULTIPLE_CHOICE_MODEL = 'Gladiator/distilbert-base-uncased_swag_mqa' class TestDistilBertForSequenceClassification(unittest.TestCase): """ @@ -191,5 +194,56 @@ def test_coreml_conversion_and_speedup(self): ) +class TestDistilBertLoadState(unittest.TestCase): + """ + Test load_state_dict compatibility. + """ + + test_params = ( + ( + MASKED_LM_MODEL, + transformers.AutoModelForMaskedLM, + ane_transformers.DistilBertForMaskedLM, + ), + ( + QUESTION_ANSWERING_MODEL, + transformers.AutoModelForQuestionAnswering, + ane_transformers.DistilBertForQuestionAnswering, + ), + ( + TOKEN_CLASSIFICATION_MODEL, + transformers.AutoModelForTokenClassification, + ane_transformers.DistilBertForTokenClassification, + ), + ( + MULTIPLE_CHOICE_MODEL, + transformers.AutoModelForMultipleChoice, + ane_transformers.DistilBertForMultipleChoice, + ), + ) + + def test_load_state(self): + for model_name, auto_model_cls, ane_model_cls in self.test_params: + with self.subTest(ane_model_cls=ane_model_cls): + try: + # Instantiate the reference model from an exemplar pre-trained + # model hosted on huggingface.co/models + reference_model = auto_model_cls.from_pretrained( + model_name, + return_dict=False, + torchscript=True, + ).eval() + except Exception as e: + raise RuntimeError( + "Failed to download reference model from huggingface.co/models!" + ) from e + logger.info("Downloaded reference model from huggingface.co/models") + + # Initialize an ANE equivalent model and restore the checkpoint + test_model = ane_model_cls(reference_model.config).eval() + test_model.load_state_dict(reference_model.state_dict()) + logger.info("Initialized and restored test model") + + if __name__ == "__main__": unittest.main()