Commit 3dbe4e4: fix test

JingyaHuang committed Feb 21, 2024
1 parent: e93c058
Showing 8 changed files with 89 additions and 16 deletions.
12 changes: 6 additions & 6 deletions docs/source/tutorials/sentence_transformers.mdx
@@ -94,38 +94,38 @@ You can compile CLIP models with Optimum Neuron either by using the `optimum-cli
* With CLI
```bash
-optimum-cli export neuron -m sentence-transformers/clip-ViT-B-32 --sequence_length 64 --batch_size 3 --num_channels 3 --height 224 --width 224 --task feature-extraction --library-name sentence_transformers --subfolder 0_CLIPModel clip_emb/
+optimum-cli export neuron -m sentence-transformers/clip-ViT-B-32 --sequence_length 64 --text_batch_size 3 --image_batch_size 1 --num_channels 3 --height 224 --width 224 --task feature-extraction --library-name sentence_transformers --subfolder 0_CLIPModel clip_emb/
```
* With `NeuronModelForSentenceTransformers` class
```python
from optimum.neuron import NeuronModelForSentenceTransformers

# [Compile]
model_id = "sentence-transformers/clip-ViT-B-32"

# configs for compiling model
input_shapes = {
    "num_channels": 3,
    "height": 224,
    "width": 224,
-    "batch_size": 1,
+    "text_batch_size": 3,
+    "image_batch_size": 1,
    "sequence_length": 64,
}

emb_model = NeuronModelForSentenceTransformers.from_pretrained(
-    model_id, subfolder="0_CLIPModel", export=True, library_name="sentence_transformers", **input_shapes
+    model_id, subfolder="0_CLIPModel", export=True, library_name="sentence_transformers", dynamic_batch_size=False, **input_shapes
)

# Save locally or upload to the HuggingFace Hub
-save_directory = "clip_emb"
+save_directory = "clip_emb/"
emb_model.save_pretrained(save_directory)
```

### Load compiled Sentence Transformers model and run inference

-```
+```python
from PIL import Image
from sentence_transformers import util
from transformers import CLIPProcessor
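The hunk ends at the imports, so the rest of the inference example is not shown here. A minimal sketch of how the flow continues, assuming the `clip_emb/` directory saved above and mirroring the new test added to tests/inference/test_modeling.py further down:

```python
from PIL import Image
from sentence_transformers import util
from transformers import CLIPProcessor

from optimum.neuron import NeuronModelForSentenceTransformers

# Load the compiled model saved in the previous step
emb_model = NeuronModelForSentenceTransformers.from_pretrained("clip_emb/")

# Fetch the same sample image the test uses
util.http_get(
    "https://github.com/UKPLab/sentence-transformers/raw/master/examples/applications/image-search/two_dogs_in_snow.jpg",
    "two_dogs_in_snow.jpg",
)

processor = CLIPProcessor.from_pretrained("sentence-transformers/clip-ViT-B-32", subfolder="0_CLIPModel")
inputs = processor(
    text=["Two dogs in the snow", "A cat on a table", "A picture of London at night"],
    images=Image.open("two_dogs_in_snow.jpg"),
    return_tensors="pt",
    padding=True,
)

outputs = emb_model(**inputs)
text_embeds = outputs["text_embeds"]    # one embedding per input text
image_embeds = outputs["image_embeds"]  # one embedding per input image
```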
10 changes: 10 additions & 0 deletions optimum/commands/export/neuronx.py
@@ -145,6 +145,16 @@ def parse_args_neuronx(parser: "ArgumentParser"):
        type=int,
        help=f"Batch size {doc_input}",
    )
+    input_group.add_argument(
+        "--text_batch_size",
+        type=int,
+        help=f"Batch size of text inputs {doc_input} (Only applied for multi-modal models)",
+    )
+    input_group.add_argument(
+        "--image_batch_size",
+        type=int,
+        help=f"Batch size of vision inputs {doc_input} (Only applied for multi-modal models)",
+    )
    input_group.add_argument(
        "--sequence_length",
        type=int,
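The two flags behave like any other optional integer argument. A standalone argparse sketch for illustration (hypothetical wiring; the real group is built inside parse_args_neuronx):

```python
import argparse

# Hypothetical, minimal reproduction of the new CLI flags
parser = argparse.ArgumentParser()
input_group = parser.add_argument_group("Input shapes")
input_group.add_argument("--text_batch_size", type=int, help="Batch size of text inputs (multi-modal models only)")
input_group.add_argument("--image_batch_size", type=int, help="Batch size of vision inputs (multi-modal models only)")

args = parser.parse_args(["--text_batch_size", "3", "--image_batch_size", "1"])
assert args.text_batch_size == 3 and args.image_batch_size == 1
```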
6 changes: 4 additions & 2 deletions optimum/exporters/neuron/__main__.py
@@ -129,9 +129,11 @@ def get_input_shapes_and_config_class(task: str, args: argparse.Namespace) -> Di

def normalize_sentence_transformers_input_shapes(args: argparse.Namespace) -> Dict[str, int]:
    args = vars(args) if isinstance(args, argparse.Namespace) else args
-    mandatory_axes = {"batch_size", "sequence_length"}
    if "clip" in args.get("model", "").lower():
-        mandatory_axes.update(["num_channels", "width", "height"])
+        mandatory_axes = {"text_batch_size", "image_batch_size", "sequence_length", "num_channels", "width", "height"}
+    else:
+        mandatory_axes = {"batch_size", "sequence_length"}

    if not mandatory_axes.issubset(set(args.keys())):
        raise AttributeError(
            f"Shape of {mandatory_axes} are mandatory for neuron compilation, while {mandatory_axes.difference(args.keys())} are not given."
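A sketch of how the updated helper resolves the mandatory axes, assuming it is called with a plain dict (hypothetical values; only the branching shown in the hunk is relied on):

```python
from optimum.exporters.neuron.__main__ import normalize_sentence_transformers_input_shapes

# CLIP checkpoints must now supply separate text/image batch sizes
clip_shapes = normalize_sentence_transformers_input_shapes({
    "model": "sentence-transformers/clip-ViT-B-32",
    "text_batch_size": 3,
    "image_batch_size": 1,
    "sequence_length": 64,
    "num_channels": 3,
    "height": 224,
    "width": 224,
})

# Text-only sentence-transformers models keep the old batch_size axis
text_shapes = normalize_sentence_transformers_input_shapes({
    "model": "sentence-transformers/all-MiniLM-L6-v2",
    "batch_size": 1,
    "sequence_length": 128,
})

# Omitting any mandatory axis raises the AttributeError shown above
```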
4 changes: 4 additions & 0 deletions optimum/exporters/neuron/base.py
@@ -144,6 +144,8 @@ def __init__(
        compiler_type: Optional[str] = None,
        compiler_version: Optional[str] = None,
        batch_size: Optional[int] = None,
+        text_batch_size: Optional[int] = None,
+        image_batch_size: Optional[int] = None,
        dynamic_batch_size: bool = False,
        sequence_length: Optional[int] = None,
        num_choices: Optional[int] = None,
@@ -176,6 +178,8 @@ def __init__(
        # To avoid using **kwargs.
        axes_values = {
            "batch_size": batch_size,
+            "text_batch_size": text_batch_size,
+            "image_batch_size": image_batch_size,
            "sequence_length": sequence_length,
            "num_choices": num_choices,
            "width": width,
2 changes: 1 addition & 1 deletion optimum/exporters/neuron/convert.py
@@ -167,7 +167,7 @@ def validate_model_outputs(
    input_shapes = {}
    for axis in config.mandatory_axes:
        input_shapes[axis] = getattr(config, axis)
-    if config.dynamic_batch_size is True:
+    if config.dynamic_batch_size is True and "batch_size" in input_shapes:
        input_shapes["batch_size"] *= 2

    # Reference outputs
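The added condition matters because CLIP configs now expose text_batch_size and image_batch_size instead of a single batch_size, so the dynamic-batching doubling has to be skipped when the key is absent. A self-contained sketch with assumed values:

```python
# Assumed input shapes for a compiled CLIP config after this change
input_shapes = {"text_batch_size": 3, "image_batch_size": 1, "sequence_length": 64}
dynamic_batch_size = True

# Without the '"batch_size" in input_shapes' guard this would raise KeyError
if dynamic_batch_size and "batch_size" in input_shapes:
    input_shapes["batch_size"] *= 2
```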
20 changes: 18 additions & 2 deletions optimum/exporters/neuron/model_configs.py
@@ -14,7 +14,7 @@
# limitations under the License.
"""Model specific Neuron configurations."""

-
+import copy
from typing import TYPE_CHECKING, Dict, List

import torch
@@ -23,6 +23,7 @@
from ...utils import (
    DummyInputGenerator,
    DummySeq2SeqDecoderTextInputGenerator,
+    DummyTextInputGenerator,
    DummyTimestepInputGenerator,
    DummyVisionInputGenerator,
    NormalizedConfig,
@@ -276,7 +277,7 @@ def outputs(self) -> List[str]:
class SentenceTransformersCLIPNeuronConfig(CLIPNeuronConfig):
    CUSTOM_MODEL_WRAPPER = SentenceTransformersCLIPNeuronWrapper
    ATOL_FOR_VALIDATION = 1e-3
-    INPUT_ARGS = ("batch_size", "sequence_length", "num_channels", "width", "height")
+    INPUT_ARGS = ("text_batch_size", "image_batch_size", "sequence_length", "num_channels", "width", "height")

    @property
    def outputs(self) -> List[str]:
@@ -285,6 +286,21 @@ def outputs(self) -> List[str]:
    def patch_model_for_export(self, model, dummy_inputs):
        return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))

+    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
+        for name, axis_dim in self._axes.items():
+            self._axes[name] = kwargs.pop(name, axis_dim)
+
+        self._validate_mandatory_axes()
+
+        other_axes = copy.deepcopy(self._axes)
+        text_batch_size = other_axes.pop("text_batch_size")
+        images_batch_size = other_axes.pop("image_batch_size")
+
+        return [
+            DummyTextInputGenerator(self.task, self._normalized_config, batch_size=text_batch_size, **other_axes),
+            DummyVisionInputGenerator(self.task, self._normalized_config, batch_size=images_batch_size, **other_axes),
+        ]


@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
class UNetNeuronConfig(VisionNeuronConfig):
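For reference, a sketch of the dummy tensors those two generators produce during tracing, with shapes assumed from the axes used above (illustrative only, not the generators' actual code):

```python
import torch

text_batch_size, sequence_length = 3, 64
image_batch_size, num_channels, height, width = 1, 3, 224, 224

# Text side comes from DummyTextInputGenerator, vision side from DummyVisionInputGenerator
dummy_inputs = {
    "input_ids": torch.zeros(text_batch_size, sequence_length, dtype=torch.long),
    "attention_mask": torch.ones(text_batch_size, sequence_length, dtype=torch.long),
    "pixel_values": torch.rand(image_batch_size, num_channels, height, width),
}
```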
10 changes: 7 additions & 3 deletions tests/inference/inference_utils.py
@@ -108,9 +108,13 @@ def _setup(self, model_args: Dict):
        model_args.pop("model_arch")
        model_args.pop("dynamic_batch_size", None)

-        model_id = (
-            self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch]
-        )
+        if model_arch in self.ARCH_MODEL_MAP:
+            model_id = self.ARCH_MODEL_MAP[model_arch]
+        elif model_arch in SENTENCE_TRANSFORMERS_MODEL_NAMES:
+            model_id = SENTENCE_TRANSFORMERS_MODEL_NAMES[model_arch]
+        else:
+            model_id = MODEL_NAMES[model_arch]

        set_seed(SEED)
        neuron_model = self.NEURON_MODEL_CLASS.from_pretrained(
            model_id, **model_args, export=True, dynamic_batch_size=dynamic_batch_size, **self.STATIC_INPUTS_SHAPES
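The new branching in _setup reads as a three-step lookup. A hypothetical helper that mirrors it (names invented for illustration):

```python
# Hypothetical helper mirroring the lookup order introduced above
def resolve_model_id(model_arch: str, arch_model_map: dict, st_names: dict, model_names: dict) -> str:
    if model_arch in arch_model_map:  # explicit per-test override wins
        return arch_model_map[model_arch]
    if model_arch in st_names:  # sentence-transformers checkpoints
        return st_names[model_arch]
    return model_names[model_arch]  # default transformers checkpoints

print(resolve_model_id("clip", {}, {"clip": "sentence-transformers/clip-ViT-B-32"}, {}))
```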
41 changes: 39 additions & 2 deletions tests/inference/test_modeling.py
@@ -20,7 +20,8 @@
import torch
from huggingface_hub.constants import default_cache_path
from parameterized import parameterized
-from sentence_transformers import SentenceTransformer
+from PIL import Image
+from sentence_transformers import SentenceTransformer, util
from transformers import (
    AutoModel,
    AutoModelForMaskedLM,
@@ -29,6 +30,7 @@
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
+    CLIPProcessor,
    PretrainedConfig,
    set_seed,
)
@@ -340,7 +342,7 @@ class NeuronModelForSentenceTransformersIntegrationTest(NeuronModelTestMixin):
    ATOL_FOR_VALIDATION = 1e-2
    SUPPORTED_ARCHITECTURES = ["transformer", "clip"]

-    @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True)
+    @parameterized.expand(["transformer"], skip_on_empty=True)
    @requires_neuronx
    def test_sentence_transformers_dyn_bs(self, model_arch):
        # Neuron model with dynamic batching
@@ -393,6 +395,41 @@ def test_sentence_transformers_dyn_bs(self, model_arch):

        gc.collect()

+    @parameterized.expand(["clip"], skip_on_empty=True)
+    @requires_neuronx
+    def test_sentence_transformers_clip(self, model_arch):
+        # Neuron CLIP model with static input shapes
+        model_id = SENTENCE_TRANSFORMERS_MODEL_NAMES[model_arch]
+        input_shapes = {
+            "num_channels": 3,
+            "height": 224,
+            "width": 224,
+            "text_batch_size": 3,
+            "image_batch_size": 1,
+            "sequence_length": 16,
+        }
+
+        neuron_model = self.NEURON_MODEL_CLASS.from_pretrained(
+            model_id, subfolder="0_CLIPModel", export=True, library_name="sentence_transformers", **input_shapes
+        )
+        self.assertIsInstance(neuron_model.model, torch.jit._script.ScriptModule)
+        self.assertIsInstance(neuron_model.config, PretrainedConfig)
+
+        texts = ["Two dogs in the snow", "A cat on a table", "A picture of London at night"]
+        util.http_get(
+            "https://github.com/UKPLab/sentence-transformers/raw/master/examples/applications/image-search/two_dogs_in_snow.jpg",
+            "two_dogs_in_snow.jpg",
+        )
+
+        processor = CLIPProcessor.from_pretrained(model_id, subfolder="0_CLIPModel")
+        inputs = processor(text=texts, images=Image.open("two_dogs_in_snow.jpg"), return_tensors="pt", padding=True)
+        outputs = neuron_model(**inputs)
+        self.assertIn("image_embeds", outputs)
+        self.assertIn("text_embeds", outputs)
+
+        gc.collect()


@is_inferentia_test
class NeuronModelForMaskedLMIntegrationTest(NeuronModelTestMixin):
