Skip to content

Commit

Permalink
Fix linear_to_conv2d_map to work with other DistilBERT model types
Browse files Browse the repository at this point in the history
  • Loading branch information
anentropic committed Apr 23, 2023
1 parent da64000 commit 181a16d
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 4 deletions.
26 changes: 23 additions & 3 deletions ane_transformers/huggingface/distilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#
import re

from ane_transformers.reference.layer_norm import LayerNormANE

Expand Down Expand Up @@ -520,14 +521,33 @@ def forward(
return ((loss, ) + output) if loss is not None else output


# Both patterns are applied with ``.match`` to state_dict keys in
# linear_to_conv2d_map. Because of the leading ``.*`` a key matches when it
# contains one of the listed names immediately followed by ``.weight``.
_WEIGHT_KEY_TEMPLATE = r".*({})\.weight"

# Names of the projections inside each transformer block.
_INTERNAL_PROJ_NAMES = (
    "q_lin",
    "k_lin",
    "v_lin",
    "out_lin",
    "lin1",
    "lin2",
)

# Names of the task-specific output heads.
_OUTPUT_PROJ_NAMES = (
    "classifier",
    "pre_classifier",
    "vocab_transform",
    "vocab_projector",
    "qa_outputs",
)

_INTERNAL_PROJ_RE = re.compile(
    _WEIGHT_KEY_TEMPLATE.format("|".join(_INTERNAL_PROJ_NAMES))
)
_OUTPUT_PROJ_RE = re.compile(
    _WEIGHT_KEY_TEMPLATE.format("|".join(_OUTPUT_PROJ_NAMES))
)


def linear_to_conv2d_map(state_dict, prefix, local_metadata, strict,
                         missing_keys, unexpected_keys, error_msgs):
    """Unsqueeze twice to map nn.Linear weights to nn.Conv2d weights.

    Every entry of ``state_dict`` whose key matches ``_INTERNAL_PROJ_RE`` or
    ``_OUTPUT_PROJ_RE`` and whose tensor is 2-D gets two trailing singleton
    dimensions appended, i.e. ``(out, in)`` becomes ``(out, in, 1, 1)``.
    Entries that are not 2-D (for example already converted weights) and
    keys that match neither pattern are left untouched.

    NOTE(review): the parameter list matches the signature PyTorch uses for
    ``load_state_dict`` pre-hooks; presumably this function is registered as
    such elsewhere -- confirm against the caller.

    Args:
        state_dict: Mapping from parameter name to tensor. Modified in place.
        prefix, local_metadata, strict, missing_keys, unexpected_keys,
            error_msgs: Accepted for signature compatibility; not used.

    Returns:
        None. The result is the in-place change to ``state_dict``.
    """
    # Only values of existing keys are replaced, so the mapping's size never
    # changes while it is being iterated.
    for key in state_dict:
        is_projection = (_INTERNAL_PROJ_RE.match(key)
                         or _OUTPUT_PROJ_RE.match(key))
        if not is_projection:
            continue
        weight = state_dict[key]
        if len(weight.shape) == 2:
            state_dict[key] = weight[:, :, None, None]
56 changes: 55 additions & 1 deletion ane_transformers/huggingface/test_distilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import logging
import numpy as np
import unittest
import time

import torch

Expand All @@ -32,6 +31,10 @@
("This is not what I expected!", "NEGATIVE"),
])

# Pre-trained model identifiers, one per task head, used as the checkpoint
# names in TestDistilBertLoadState.test_params (loaded from
# huggingface.co/models via from_pretrained).
MASKED_LM_MODEL = 'distilbert-base-uncased'
QUESTION_ANSWERING_MODEL = 'distilbert-base-uncased-distilled-squad'
TOKEN_CLASSIFICATION_MODEL = 'elastic/distilbert-base-uncased-finetuned-conll03-english'
MULTIPLE_CHOICE_MODEL = 'Gladiator/distilbert-base-uncased_swag_mqa'

class TestDistilBertForSequenceClassification(unittest.TestCase):
"""
Expand Down Expand Up @@ -191,5 +194,56 @@ def test_coreml_conversion_and_speedup(self):
)


class TestDistilBertLoadState(unittest.TestCase):
    """
    Check that each ANE DistilBERT variant can restore, through
    load_state_dict, the weights of its pre-trained reference model.
    """

    # One row per task head:
    # (checkpoint name, reference auto-model class, ANE model class).
    test_params = (
        (
            MASKED_LM_MODEL,
            transformers.AutoModelForMaskedLM,
            ane_transformers.DistilBertForMaskedLM,
        ),
        (
            QUESTION_ANSWERING_MODEL,
            transformers.AutoModelForQuestionAnswering,
            ane_transformers.DistilBertForQuestionAnswering,
        ),
        (
            TOKEN_CLASSIFICATION_MODEL,
            transformers.AutoModelForTokenClassification,
            ane_transformers.DistilBertForTokenClassification,
        ),
        (
            MULTIPLE_CHOICE_MODEL,
            transformers.AutoModelForMultipleChoice,
            ane_transformers.DistilBertForMultipleChoice,
        ),
    )

    @staticmethod
    def _fetch_reference_model(checkpoint, reference_cls):
        """Load ``checkpoint`` with ``reference_cls`` and put it in eval mode.

        Any failure while loading is re-raised as a RuntimeError chained to
        the original exception.
        """
        try:
            # Instantiate the reference model from an exemplar pre-trained
            # model hosted on huggingface.co/models
            pretrained = reference_cls.from_pretrained(
                checkpoint,
                return_dict=False,
                torchscript=True,
            )
            return pretrained.eval()
        except Exception as err:
            raise RuntimeError(
                "Failed to download reference model from huggingface.co/models!"
            ) from err

    def test_load_state(self):
        """Restore every reference checkpoint into its ANE counterpart.

        Each (checkpoint, reference class, ANE class) row runs in its own
        subTest, so one failing variant does not hide the others.
        """
        for checkpoint, reference_cls, ane_cls in self.test_params:
            with self.subTest(ane_model_cls=ane_cls):
                reference_model = self._fetch_reference_model(
                    checkpoint, reference_cls)
                logger.info("Downloaded reference model from huggingface.co/models")

                # Initialize an ANE equivalent model and restore the checkpoint
                reference_weights = reference_model.state_dict()
                test_model = ane_cls(reference_model.config).eval()
                test_model.load_state_dict(reference_weights)
                logger.info("Initialized and restored test model")


if __name__ == "__main__":
unittest.main()

0 comments on commit 181a16d

Please sign in to comment.