diff --git a/docs/source/tutorials/sentence_transformers.mdx b/docs/source/tutorials/sentence_transformers.mdx
index a88199ed1..495c0a9b3 100644
--- a/docs/source/tutorials/sentence_transformers.mdx
+++ b/docs/source/tutorials/sentence_transformers.mdx
@@ -94,7 +94,7 @@ You can compile CLIP models with Optimum Neuron either by using the `optimum-cli
 * With CLI
 
 ```bash
-optimum-cli export neuron -m sentence-transformers/clip-ViT-B-32 --sequence_length 64 --batch_size 3 --num_channels 3 --height 224 --width 224 --task feature-extraction --library-name sentence_transformers --subfolder 0_CLIPModel clip_emb/
+optimum-cli export neuron -m sentence-transformers/clip-ViT-B-32 --sequence_length 64 --text_batch_size 3 --image_batch_size 1 --num_channels 3 --height 224 --width 224 --task feature-extraction --library-name sentence_transformers --subfolder 0_CLIPModel clip_emb/
 ```
 
 * With `NeuronModelForSentenceTransformers` class
@@ -102,7 +102,6 @@ optimum-cli export neuron -m sentence-transformers/clip-ViT-B-32 --sequence_leng
 ```python
 from optimum.neuron import NeuronModelForSentenceTransformers
 
-# [Compile]
 model_id = "sentence-transformers/clip-ViT-B-32"
 
 # configs for compiling model
@@ -110,22 +109,23 @@ input_shapes = {
     "num_channels": 3,
     "height": 224,
     "width": 224,
-    "batch_size": 1,
+    "text_batch_size": 3,
+    "image_batch_size": 1,
     "sequence_length": 64,
 }
 
 emb_model = NeuronModelForSentenceTransformers.from_pretrained(
-    model_id, subfolder="0_CLIPModel", export=True, library_name="sentence_transformers", **input_shapes
+    model_id, subfolder="0_CLIPModel", export=True, library_name="sentence_transformers", dynamic_batch_size=False, **input_shapes
 )
 
 # Save locally or upload to the HuggingFace Hub
-save_directory = "clip_emb"
+save_directory = "clip_emb/"
 emb_model.save_pretrained(save_directory)
 ```
 
 ### Load compiled Sentence Transformers model and run inference
 
-```
+```python
 from PIL import Image
 from sentence_transformers import util
 from transformers import CLIPProcessor
diff --git a/optimum/commands/export/neuronx.py b/optimum/commands/export/neuronx.py
index 2e43015b5..cc659eeb6 100644
--- a/optimum/commands/export/neuronx.py
+++ b/optimum/commands/export/neuronx.py
@@ -145,6 +145,16 @@ def parse_args_neuronx(parser: "ArgumentParser"):
         type=int,
         help=f"Batch size {doc_input}",
     )
+    input_group.add_argument(
+        "--text_batch_size",
+        type=int,
+        help=f"Batch size of text inputs {doc_input} (Only applied for multi-modal models)",
+    )
+    input_group.add_argument(
+        "--image_batch_size",
+        type=int,
+        help=f"Batch size of vision inputs {doc_input} (Only applied for multi-modal models)",
+    )
     input_group.add_argument(
         "--sequence_length",
         type=int,
diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py
index 6edcc9de3..53884a289 100644
--- a/optimum/exporters/neuron/__main__.py
+++ b/optimum/exporters/neuron/__main__.py
@@ -129,9 +129,11 @@ def get_input_shapes_and_config_class(task: str, args: argparse.Namespace) -> Di
 
 def normalize_sentence_transformers_input_shapes(args: argparse.Namespace) -> Dict[str, int]:
     args = vars(args) if isinstance(args, argparse.Namespace) else args
-    mandatory_axes = {"batch_size", "sequence_length"}
     if "clip" in args.get("model", "").lower():
-        mandatory_axes.update(["num_channels", "width", "height"])
+        mandatory_axes = {"text_batch_size", "image_batch_size", "sequence_length", "num_channels", "width", "height"}
+    else:
+        mandatory_axes = {"batch_size", "sequence_length"}
+
     if not mandatory_axes.issubset(set(args.keys())):
         raise AttributeError(
             f"Shape of {mandatory_axes} are mandatory for neuron compilation, while {mandatory_axes.difference(args.keys())} are not given."
diff --git a/optimum/exporters/neuron/base.py b/optimum/exporters/neuron/base.py
index 303e56793..9340468a6 100644
--- a/optimum/exporters/neuron/base.py
+++ b/optimum/exporters/neuron/base.py
@@ -144,6 +144,8 @@ def __init__(
         compiler_type: Optional[str] = None,
         compiler_version: Optional[str] = None,
         batch_size: Optional[int] = None,
+        text_batch_size: Optional[int] = None,
+        image_batch_size: Optional[int] = None,
         dynamic_batch_size: bool = False,
         sequence_length: Optional[int] = None,
         num_choices: Optional[int] = None,
@@ -176,6 +178,8 @@ def __init__(
         # To avoid using **kwargs.
         axes_values = {
             "batch_size": batch_size,
+            "text_batch_size": text_batch_size,
+            "image_batch_size": image_batch_size,
             "sequence_length": sequence_length,
             "num_choices": num_choices,
             "width": width,
diff --git a/optimum/exporters/neuron/convert.py b/optimum/exporters/neuron/convert.py
index fb0bcadb9..26031e702 100644
--- a/optimum/exporters/neuron/convert.py
+++ b/optimum/exporters/neuron/convert.py
@@ -167,7 +167,7 @@ def validate_model_outputs(
     input_shapes = {}
     for axis in config.mandatory_axes:
         input_shapes[axis] = getattr(config, axis)
-    if config.dynamic_batch_size is True:
+    if config.dynamic_batch_size is True and "batch_size" in input_shapes:
         input_shapes["batch_size"] *= 2
 
     # Reference outputs
diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py
index 045589f3b..3d0808c1c 100644
--- a/optimum/exporters/neuron/model_configs.py
+++ b/optimum/exporters/neuron/model_configs.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Model specific Neuron configurations."""
 
-
+import copy
 from typing import TYPE_CHECKING, Dict, List
 
 import torch
@@ -23,6 +23,7 @@
 from ...utils import (
     DummyInputGenerator,
     DummySeq2SeqDecoderTextInputGenerator,
+    DummyTextInputGenerator,
     DummyTimestepInputGenerator,
     DummyVisionInputGenerator,
     NormalizedConfig,
@@ -276,7 +277,7 @@ def outputs(self) -> List[str]:
 class SentenceTransformersCLIPNeuronConfig(CLIPNeuronConfig):
     CUSTOM_MODEL_WRAPPER = SentenceTransformersCLIPNeuronWrapper
     ATOL_FOR_VALIDATION = 1e-3
-    INPUT_ARGS = ("batch_size", "sequence_length", "num_channels", "width", "height")
+    INPUT_ARGS = ("text_batch_size", "image_batch_size", "sequence_length", "num_channels", "width", "height")
 
     @property
     def outputs(self) -> List[str]:
@@ -285,6 +286,21 @@ def outputs(self) -> List[str]:
     def patch_model_for_export(self, model, dummy_inputs):
         return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))
 
+    def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
+        for name, axis_dim in self._axes.items():
+            self._axes[name] = kwargs.pop(name, axis_dim)
+
+        self._validate_mandatory_axes()
+
+        other_axes = copy.deepcopy(self._axes)
+        text_batch_size = other_axes.pop("text_batch_size")
+        images_batch_size = other_axes.pop("image_batch_size")
+
+        return [
+            DummyTextInputGenerator(self.task, self._normalized_config, batch_size=text_batch_size, **other_axes),
+            DummyVisionInputGenerator(self.task, self._normalized_config, batch_size=images_batch_size, **other_axes),
+        ]
+
 
 @register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
 class UNetNeuronConfig(VisionNeuronConfig):
diff --git a/tests/inference/inference_utils.py b/tests/inference/inference_utils.py
index 8f7e30eca..2330ca0ba 100644
--- a/tests/inference/inference_utils.py
+++ b/tests/inference/inference_utils.py
@@ -108,9 +108,13 @@ def _setup(self, model_args: Dict):
         model_args.pop("model_arch")
         model_args.pop("dynamic_batch_size", None)
 
-        model_id = (
-            self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch]
-        )
+        if model_arch in self.ARCH_MODEL_MAP:
+            model_id = self.ARCH_MODEL_MAP[model_arch]
+        elif model_arch in SENTENCE_TRANSFORMERS_MODEL_NAMES:
+            model_id = SENTENCE_TRANSFORMERS_MODEL_NAMES[model_arch]
+        else:
+            model_id = MODEL_NAMES[model_arch]
+
         set_seed(SEED)
         neuron_model = self.NEURON_MODEL_CLASS.from_pretrained(
             model_id, **model_args, export=True, dynamic_batch_size=dynamic_batch_size, **self.STATIC_INPUTS_SHAPES
diff --git a/tests/inference/test_modeling.py b/tests/inference/test_modeling.py
index 1987cb65d..42cbb2152 100644
--- a/tests/inference/test_modeling.py
+++ b/tests/inference/test_modeling.py
@@ -20,7 +20,8 @@
 import torch
 from huggingface_hub.constants import default_cache_path
 from parameterized import parameterized
-from sentence_transformers import SentenceTransformer
+from PIL import Image
+from sentence_transformers import SentenceTransformer, util
 from transformers import (
     AutoModel,
     AutoModelForMaskedLM,
@@ -29,6 +30,7 @@
     AutoModelForSequenceClassification,
     AutoModelForTokenClassification,
     AutoTokenizer,
+    CLIPProcessor,
     PretrainedConfig,
     set_seed,
 )
@@ -340,7 +342,7 @@ class NeuronModelForSentenceTransformersIntegrationTest(NeuronModelTestMixin):
     ATOL_FOR_VALIDATION = 1e-2
     SUPPORTED_ARCHITECTURES = ["transformer", "clip"]
 
-    @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True)
+    @parameterized.expand(["transformer"], skip_on_empty=True)
     @requires_neuronx
     def test_sentence_transformers_dyn_bs(self, model_arch):
         # Neuron model with dynamic batching
@@ -393,6 +395,41 @@ def test_sentence_transformers_dyn_bs(self, model_arch):
 
         gc.collect()
 
+    @parameterized.expand(["clip"], skip_on_empty=True)
+    @requires_neuronx
+    def test_sentence_transformers_clip(self, model_arch):
+
+        # Compile the CLIP model with static text/image batch sizes (no dynamic batching)
+        model_id = SENTENCE_TRANSFORMERS_MODEL_NAMES[model_arch]
+        input_shapes = {
+            "num_channels": 3,
+            "height": 224,
+            "width": 224,
+            "text_batch_size": 3,
+            "image_batch_size": 1,
+            "sequence_length": 16,
+        }
+
+        neuron_model = self.NEURON_MODEL_CLASS.from_pretrained(
+            model_id, subfolder="0_CLIPModel", export=True, library_name="sentence_transformers", **input_shapes
+        )
+        self.assertIsInstance(neuron_model.model, torch.jit._script.ScriptModule)
+        self.assertIsInstance(neuron_model.config, PretrainedConfig)
+
+        texts = ["Two dogs in the snow", "A cat on a table", "A picture of London at night"]
+        util.http_get(
+            "https://github.com/UKPLab/sentence-transformers/raw/master/examples/applications/image-search/two_dogs_in_snow.jpg",
+            "two_dogs_in_snow.jpg",
+        )
+
+        processor = CLIPProcessor.from_pretrained(model_id, subfolder="0_CLIPModel")
+        inputs = processor(text=texts, images=Image.open("two_dogs_in_snow.jpg"), return_tensors="pt", padding=True)
+        outputs = neuron_model(**inputs)
+        self.assertIn("image_embeds", outputs)
+        self.assertIn("text_embeds", outputs)
+
+        gc.collect()
+
 
 @is_inferentia_test
 class NeuronModelForMaskedLMIntegrationTest(NeuronModelTestMixin):
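Not part of the patch itself, but a minimal end-to-end sketch of what the changes above enable, stitched together from the updated tutorial hunk and the new `test_sentence_transformers_clip` test: compile the CLIP Sentence Transformers model with independent static batch sizes for the text and vision towers, then embed three captions and one image. The model ID, subfolder, input shapes, and output keys (`text_embeds`, `image_embeds`) are taken from the hunks above; the closing `util.cos_sim` call is an extra illustration using the generic `sentence_transformers.util` helper and is not introduced by this PR.

```python
from PIL import Image
from sentence_transformers import util
from transformers import CLIPProcessor

from optimum.neuron import NeuronModelForSentenceTransformers

model_id = "sentence-transformers/clip-ViT-B-32"

# Compile with independent static batch sizes for the text and vision towers.
input_shapes = {
    "num_channels": 3,
    "height": 224,
    "width": 224,
    "text_batch_size": 3,
    "image_batch_size": 1,
    "sequence_length": 64,
}
emb_model = NeuronModelForSentenceTransformers.from_pretrained(
    model_id,
    subfolder="0_CLIPModel",
    export=True,
    library_name="sentence_transformers",
    dynamic_batch_size=False,
    **input_shapes,
)

# Inputs must match the compiled batch sizes: 3 captions and 1 image.
util.http_get(
    "https://github.com/UKPLab/sentence-transformers/raw/master/examples/applications/image-search/two_dogs_in_snow.jpg",
    "two_dogs_in_snow.jpg",
)
processor = CLIPProcessor.from_pretrained(model_id, subfolder="0_CLIPModel")
inputs = processor(
    text=["Two dogs in the snow", "A cat on a table", "A picture of London at night"],
    images=Image.open("two_dogs_in_snow.jpg"),
    return_tensors="pt",
    padding=True,
)

outputs = emb_model(**inputs)

# Not in the diff: rank the captions against the image with the standard
# sentence-transformers cosine-similarity helper.
scores = util.cos_sim(outputs["image_embeds"], outputs["text_embeds"])
print(scores)  # shape (1, 3); "Two dogs in the snow" should score highest
```

Splitting the single `batch_size` axis into `text_batch_size` and `image_batch_size` is what lets one compiled graph embed a different number of captions than images, which the previous single-axis configuration could not express.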