Skip to content

Commit

Permalink
support phi
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Mar 9, 2024
1 parent 8f84127 commit caf2760
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 0 deletions.
13 changes: 13 additions & 0 deletions optimum/exporters/neuron/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
VisionNeuronConfig,
)
from .model_wrappers import (
NoCacheModelWrapper,
SentenceTransformersCLIPNeuronWrapper,
SentenceTransformersTransformerNeuronWrapper,
T5DecoderWrapper,
Expand Down Expand Up @@ -122,6 +123,18 @@ class MobileBertNeuronConfig(BertNeuronConfig):
pass


@register_in_tasks_manager("phi", *["feature-extraction", "text-classification", "token-classification"])
class PhiNeuronConfig(ElectraNeuronConfig):
CUSTOM_MODEL_WRAPPER = NoCacheModelWrapper

@property
def inputs(self) -> List[str]:
return ["input_ids", "attention_mask"]

def patch_model_for_export(self, model, dummy_inputs):
return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))


@register_in_tasks_manager("roformer", *COMMON_TEXT_TASKS)
class RoFormerNeuronConfig(ElectraNeuronConfig):
pass
Expand Down
13 changes: 13 additions & 0 deletions optimum/exporters/neuron/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,3 +374,16 @@ def forward(self, input_ids, pixel_values, attention_mask):
text_embeds = self.model[1:](text_embeds)

return (text_embeds, image_embeds)


class NoCacheModelWrapper(torch.nn.Module):
def __init__(self, model: "PreTrainedModel", input_names: List[str]):
super().__init__()
self.model = model
self.input_names = input_names

def forward(self, *input):
ordered_inputs = dict(zip(self.input_names, input))
outputs = self.model(use_cache=False, **ordered_inputs)

return outputs
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"flaubert": "flaubert/flaubert_small_cased",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mpnet": "hf-internal-testing/tiny-random-MPNetModel",
"phi": "hf-internal-testing/tiny-random-PhiModel",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",
"xlm": "hf-internal-testing/tiny-random-XLMModel",
Expand Down
1 change: 1 addition & 0 deletions tests/inference/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"latent-consistency": "echarlaix/tiny-random-latent-consistency",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mpnet": "hf-internal-testing/tiny-random-MPNetModel",
"phi": "hf-internal-testing/tiny-random-PhiModel",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",
"stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
Expand Down

0 comments on commit caf2760

Please sign in to comment.