Commit
some tests and modeling
JingyaHuang committed Jul 9, 2024
1 parent 70cec4b commit 281c84e
Showing 6 changed files with 2,082 additions and 1,585 deletions.
11 changes: 11 additions & 0 deletions docs/source/package_reference/modeling.mdx
@@ -84,9 +84,20 @@ The following Neuron model classes are available for computer vision tasks.

## Audio

The following Neuron model classes are available for audio tasks (a short usage sketch follows the list).

### NeuronModelForAudioClassification
[[autodoc]] modeling.NeuronModelForAudioClassification

### NeuronModelForCTC
[[autodoc]] modeling.NeuronModelForCTC

### NeuronModelForAudioFrameClassification
[[autodoc]] modeling.NeuronModelForAudioFrameClassification

### NeuronModelForAudioXVector
[[autodoc]] modeling.NeuronModelForAudioXVector
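
As a quick orientation, a minimal usage sketch for one of these classes is shown below; the `superb/wav2vec2-base-superb-ks` checkpoint and the static input shapes are illustrative assumptions, not documented defaults.

```python
from transformers import AutoFeatureExtractor
from optimum.neuron import NeuronModelForAudioClassification

# Assumed checkpoint; any audio-classification model exportable to Neuron works similarly.
model_id = "superb/wav2vec2-base-superb-ks"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

# Neuron compilation needs static input shapes (the values here are illustrative).
model = NeuronModelForAudioClassification.from_pretrained(
    model_id, export=True, batch_size=1, audio_sequence_length=16000
)

audio = [0.0] * 16000  # one second of silence at 16 kHz, as a stand-in waveform
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
logits = model(**inputs).logits
predicted_label = model.config.id2label[logits.argmax(-1).item()]
```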

## Stable Diffusion

The following Neuron model classes are available for stable diffusion tasks.
4 changes: 4 additions & 0 deletions optimum/neuron/__init__.py
@@ -43,6 +43,8 @@
"NeuronModelForSemanticSegmentation",
"NeuronModelForObjectDetection",
"NeuronModelForCTC",
"NeuronModelForAudioClassification",
"NeuronModelForAudioFrameClassification",
],
"modeling_diffusion": [
"NeuronStableDiffusionPipelineBase",
@@ -84,6 +86,8 @@
NeuronModelForSentenceTransformers,
NeuronModelForSequenceClassification,
NeuronModelForTokenClassification,
NeuronModelForAudioClassification,
NeuronModelForAudioFrameClassification,
)
from .modeling_decoder import NeuronDecoderModel
from .modeling_diffusion import (
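These additions expose the new audio classes at the package root, so downstream code can import them directly; for instance, a trivial smoke check (assuming optimum-neuron is installed):

```python
from optimum.neuron import (
    NeuronModelForAudioClassification,
    NeuronModelForAudioFrameClassification,
)

print(NeuronModelForAudioClassification.__name__)
```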
180 changes: 171 additions & 9 deletions optimum/neuron/modeling.py
@@ -31,6 +31,9 @@
AutoModelForSemanticSegmentation,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoModelForAudioClassification,
AutoModelForAudioFrameClassification,
AutoModelForAudioXVector,
)
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation import (
@@ -104,14 +107,14 @@
Args:
pixel_values (`Union[torch.Tensor, None]` of shape `({0})`, defaults to `None`):
Pixel values corresponding to the images in the current batch.
- Pixel values can be obtained from encoded images using [`AutoFeatureExtractor`](https://huggingface.co/docs/transformers/autoclass_tutorial#autofeatureextractor).
+ Pixel values can be obtained from encoded images using [`AutoImageProcessor`](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoImageProcessor).
"""

NEURON_AUDIO_INPUTS_DOCSTRING = r"""
Args:
input_values (`torch.Tensor` of shape `({0})`):
Float values of the input raw speech waveform.
- Input values can be obtained from audio file loaded into an array using [`AutoFeatureExtractor`](https://huggingface.co/docs/transformers/autoclass_tutorial#autofeatureextractor).
+ Input values can be obtained from an audio file loaded into an array using [`AutoProcessor`](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoProcessor).
"""

FEATURE_EXTRACTION_EXAMPLE = r"""
@@ -556,7 +559,7 @@ def forward(
>>> from optimum.neuron import {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
- >>> model = {model_class}.from_pretrained("{checkpoint}", export=True)
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
>>> num_choices = 4
>>> first_sentence = ["Members of the procession walk down the street holding small horn brass instruments."] * num_choices
@@ -868,6 +871,154 @@ def forward(
return ModelOutput(logits=logits, pred_boxes=pred_boxes, last_hidden_state=last_hidden_state)


AUDIO_CLASSIFICATION_EXAMPLE = r"""
Example of audio classification:
```python
>>> from transformers import {processor_class}
>>> from optimum.neuron import {model_class}
>>> from datasets import load_dataset
>>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> # audio file is decoded on the fly
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
>>> logits = model(**inputs).logits
>>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
>>> predicted_label = model.config.id2label[predicted_class_ids]
```
Example using `transformers.pipeline`:
```python
>>> from transformers import {processor_class}, pipeline
>>> from optimum.neuron import {model_class}
>>> from datasets import load_dataset
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> ac = pipeline("audio-classification", model=model, feature_extractor=feature_extractor)
>>> pred = ac(dataset[0]["audio"]["array"])
```
"""


@add_start_docstrings(
    """
    Neuron Model with an audio classification head.
    """,
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForAudioClassification(NeuronTracedModel):
    """
    Neuron Model for audio classification, with a sequence classification head on top (a linear layer over the pooled output) for tasks
    like SUPERB Keyword Spotting.
    """

    auto_model_class = AutoModelForAudioClassification

    @add_start_docstrings_to_model_forward(
        NEURON_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + AUDIO_CLASSIFICATION_EXAMPLE.format(
            processor_class=_GENERIC_PROCESSOR,
            model_class="NeuronModelForAudioClassification",
            checkpoint="Jingya/wav2vec2-large-960h-lv60-self-neuronx-audio-classification",
        )
    )
    def forward(
        self,
        input_values: torch.Tensor,
        **kwargs,
    ):
        neuron_inputs = {"input_values": input_values}

        # run inference
        with self.neuron_padding_manager(neuron_inputs) as inputs:
            outputs = self.model(*inputs)  # shape: [batch_size, num_labels]
            outputs = self.remove_padding(
                outputs, dims=[0], indices=[input_values.shape[0]]
            )  # remove padding on batch_size (dim 0)

        logits = outputs[0]

        return SequenceClassifierOutput(logits=logits)
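
Neuron graphs are compiled with static input shapes, so `neuron_padding_manager` pads a smaller incoming batch up to the compiled batch size and `remove_padding` drops the extra rows from the outputs. The following is only a simplified sketch of that contract, not the actual `NeuronTracedModel` internals:

```python
import torch


def pad_batch(input_values: torch.Tensor, compiled_batch_size: int) -> torch.Tensor:
    """Zero-pad the batch dimension up to the static size the graph was compiled for."""
    missing = compiled_batch_size - input_values.shape[0]
    if missing <= 0:
        return input_values
    pad_rows = torch.zeros((missing, *input_values.shape[1:]), dtype=input_values.dtype)
    return torch.cat([input_values, pad_rows], dim=0)


def unpad_batch(outputs: torch.Tensor, real_batch_size: int) -> torch.Tensor:
    """Drop the output rows that correspond to padding inputs."""
    return outputs[:real_batch_size]
```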


AUDIO_FRAME_CLASSIFICATION_EXAMPLE = r"""
Example of audio frame classification:
```python
>>> from transformers import {processor_class}
>>> from optimum.neuron import {model_class}
>>> from datasets import load_dataset
>>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
>>> logits = model(**inputs).logits
>>> probabilities = torch.sigmoid(logits[0])
>>> labels = (probabilities > 0.5).long()
>>> labels[0].tolist()
```
"""


@add_start_docstrings(
    """
    Neuron Model with an audio frame classification head.
    """,
    NEURON_MODEL_START_DOCSTRING,
)
class NeuronModelForAudioFrameClassification(NeuronTracedModel):
    """
    Neuron Model with a frame classification head on top for tasks like Speaker Diarization.
    """

    auto_model_class = AutoModelForAudioFrameClassification

    @add_start_docstrings_to_model_forward(
        NEURON_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + AUDIO_FRAME_CLASSIFICATION_EXAMPLE.format(
            processor_class=_GENERIC_PROCESSOR,
            model_class="NeuronModelForAudioFrameClassification",
            checkpoint="Jingya/wav2vec2-base-superb-sd-neuronx",
        )
    )
    def forward(
        self,
        input_values: torch.Tensor,
        **kwargs,
    ):
        neuron_inputs = {"input_values": input_values}

        # run inference
        with self.neuron_padding_manager(neuron_inputs) as inputs:
            outputs = self.model(*inputs)  # shape: [batch_size, sequence_length, num_labels]
            outputs = self.remove_padding(
                outputs, dims=[0], indices=[input_values.shape[0]]
            )  # remove padding on batch_size (dim 0)

        logits = outputs[0]

        return TokenClassifierOutput(logits=logits)


CTC_EXAMPLE = r"""
Example of CTC:
@@ -882,18 +1033,28 @@ def forward(
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
- >>> input_shapes = {"batch_size": 1, "audio_sequence_length": 100000}
- >>> compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"}
- >>> model = {model_class}.from_pretrained("{checkpoint}", export=True, **input_shapes, **compiler_args)
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
>>> # audio file is decoded on the fly
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
- >>> with torch.no_grad():
- ...     logits = model(**inputs).logits
+ >>> logits = model(**inputs).logits
>>> predicted_ids = torch.argmax(logits, dim=-1)
>>> transcription = processor.batch_decode(predicted_ids)
```
Example using `transformers.pipeline`:
```python
>>> from transformers import {processor_class}, pipeline
>>> from optimum.neuron import {model_class}
>>> from datasets import load_dataset
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> asr = pipeline("automatic-speech-recognition", model=model, feature_extractor=processor.feature_extractor, tokenizer=processor.tokenizer)
"""


@@ -909,13 +1070,14 @@ class NeuronModelForCTC(NeuronTracedModel):
    """

    auto_model_class = AutoModelForCTC
    main_input_name = "input_values"

    @add_start_docstrings_to_model_forward(
        NEURON_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length")
        + CTC_EXAMPLE.format(
            processor_class=_GENERIC_PROCESSOR,
            model_class="NeuronModelForCTC",
-           checkpoint="facebook/wav2vec2-large-960h-lv60-self",
+           checkpoint="Jingya/wav2vec2-large-960h-lv60-self-neuronx-ctc",
        )
    )
    def forward(
16 changes: 16 additions & 0 deletions optimum/neuron/pipelines/transformers/base.py
@@ -32,6 +32,8 @@
TextClassificationPipeline,
TextGenerationPipeline,
TokenClassificationPipeline,
AutomaticSpeechRecognitionPipeline,
AudioClassificationPipeline,
)
from transformers import pipeline as transformers_pipeline
from transformers.feature_extraction_utils import PreTrainedFeatureExtractor
@@ -53,6 +55,8 @@
NeuronModelForSentenceTransformers,
NeuronModelForSequenceClassification,
NeuronModelForTokenClassification,
NeuronModelForCTC,
NeuronModelForAudioClassification,
)


@@ -114,6 +118,18 @@
"default": "apple/deeplabv3-mobilevit-small",
"type": "image",
},
"automatic-speech-recognition": {
"impl": AutomaticSpeechRecognitionPipeline,
"class": (NeuronModelForCTC,),
"default": "facebook/wav2vec2-large-960h-lv60-self",
"type": "multimodal",
},
"audio-classification": {
"impl": AudioClassificationPipeline,
"class": (NeuronModelForAudioClassification,),
"default": "facebook/wav2vec2-large-960h-lv60-self",
"type": "audio",
},
}
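
With these registry entries in place, both audio tasks become reachable through the `pipeline` factory that `optimum.neuron` exposes. A minimal sketch, assuming the default checkpoint above and a local audio file named `sample.wav`:

```python
from optimum.neuron import pipeline

# The task name must match a key registered in the task table above;
# the checkpoint and the input file are illustrative.
asr = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-large-960h-lv60-self")
pred = asr("sample.wav")
```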


1 change: 1 addition & 0 deletions tests/inference/inference_utils.py
@@ -60,6 +60,7 @@
"stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
"stable-diffusion-ip2p": "asntr/tiny-stable-diffusion-pix2pix-torch",
"stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
"wav2vec2": "hf-internal-testing/tiny-random-Wav2Vec2Model",
"xlm": "hf-internal-testing/tiny-random-XLMModel",
"xlm-roberta": "hf-internal-testing/tiny-xlm-roberta",
"yolos": "hf-internal-testing/tiny-random-YolosModel",
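A hedged sketch of how a test might consume this mapping; the `MODEL_NAMES` dict name and the export arguments are assumptions for illustration:

```python
from optimum.neuron import NeuronModelForCTC

# Assuming the mapping above is exposed as MODEL_NAMES in inference_utils.py.
MODEL_NAMES = {"wav2vec2": "hf-internal-testing/tiny-random-Wav2Vec2Model"}

neuron_model = NeuronModelForCTC.from_pretrained(
    MODEL_NAMES["wav2vec2"], export=True, batch_size=1, audio_sequence_length=16000
)
```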