From e0cbf7dea2531603ecd0f36dfce0657baa4310ec Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Mon, 26 Feb 2024 10:44:18 +0100
Subject: [PATCH] Gemma ONNX export & ORT support (#1714)

* gemma onnx export

* fix tests

* fix model

* fix
---
 .../run_image_classification.py              |  1 +
 optimum/exporters/onnx/model_configs.py      |  7 ++++
 optimum/exporters/onnx/utils.py              |  1 +
 optimum/exporters/tasks.py                   |  8 ++++
 optimum/onnxruntime/modeling_decoder.py      |  7 +++-
 optimum/utils/__init__.py                    |  1 +
 optimum/utils/input_generators.py            | 38 +++++++++++++++++++
 optimum/utils/normalized_config.py           |  7 ++--
 tests/exporters/exporters_utils.py           |  2 +
 tests/onnxruntime/test_modeling.py           | 13 ++++---
 tests/onnxruntime/utils_onnxruntime_tests.py | 17 +++++----
 11 files changed, 85 insertions(+), 17 deletions(-)

diff --git a/examples/onnxruntime/training/image-classification/run_image_classification.py b/examples/onnxruntime/training/image-classification/run_image_classification.py
index c5d5aabe27..0582efad2a 100644
--- a/examples/onnxruntime/training/image-classification/run_image_classification.py
+++ b/examples/onnxruntime/training/image-classification/run_image_classification.py
@@ -51,6 +51,7 @@
 from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments
 
+
 """ Fine-tuning a 🤗 Transformers model for image classification"""
 
 logger = logging.getLogger(__name__)
 
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index e5715ee5a0..0e37259f7b 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -37,6 +37,7 @@
     DummyVisionEncoderDecoderPastKeyValuesGenerator,
     DummyVisionInputGenerator,
     FalconDummyPastKeyValuesGenerator,
+    GemmaDummyPastKeyValuesGenerator,
     GPTBigCodeDummyPastKeyValuesGenerator,
     MistralDummyPastKeyValuesGenerator,
     NormalizedConfig,
@@ -240,6 +241,11 @@ class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
+class GemmaOnnxConfig(LlamaOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
+
+
 class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py
index 4e2260e0a3..6802289203 100644
--- a/optimum/exporters/onnx/utils.py
+++ b/optimum/exporters/onnx/utils.py
@@ -69,6 +69,7 @@
 MODEL_TYPES_REQUIRING_POSITION_IDS = {
     "codegen",
     "falcon",
+    "gemma",
     "gpt2",
     "gpt-bigcode",
     "gpt-neo",
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index e8e8af2bce..8dda364e4e 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -592,6 +592,14 @@ class TasksManager:
             onnx="FlaubertOnnxConfig",
             tflite="FlaubertTFLiteConfig",
         ),
+        "gemma": supported_tasks_mapping(
+            "feature-extraction",
+            "feature-extraction-with-past",
+            "text-generation",
+            "text-generation-with-past",
+            "text-classification",
+            onnx="GemmaOnnxConfig",
+        ),
         "glpn": supported_tasks_mapping(
             "feature-extraction",
             "depth-estimation",
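Note (not part of the patch): with `gemma` registered in `TasksManager` as above, the standard export entry points pick it up unchanged. A minimal sketch, using the `google/gemma-2b` checkpoint that the tests below reference (the weights are gated, so any other Gemma checkpoint works the same way):

```python
# Minimal export sketch for the newly registered "gemma" model type.
from optimum.exporters.onnx import main_export

main_export(
    "google/gemma-2b",
    output="gemma_onnx/",
    # One of the tasks registered for "gemma" above; "-with-past" exports the
    # decoder with KV-cache inputs/outputs.
    task="text-generation-with-past",
)
```

The CLI equivalent is `optimum-cli export onnx --model google/gemma-2b --task text-generation-with-past gemma_onnx/`.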
diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index f752051ac7..0f3a525e9a 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -334,11 +334,14 @@ def prepare_past_key_values(
         # Generate dummy past for the first forward if uses a merged decoder
         if past_key_values is None:
             batch_size = input_ids.shape[0]
-            if self.model_type in {"mistral", "llama"}:
+            embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
+            if self.model_type == "gemma":
+                num_attention_heads = self.normalized_config.num_key_value_heads
+                embed_size_per_head = self.normalized_config.head_dim
+            elif self.model_type in {"mistral", "llama"}:
                 num_attention_heads = self.normalized_config.num_key_value_heads
             else:
                 num_attention_heads = self.normalized_config.num_attention_heads
-            embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
 
             dtype = constructor.float16 if self.use_fp16 else constructor.float32
 
diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py
index 889edbbbd4..b4e4212179 100644
--- a/optimum/utils/__init__.py
+++ b/optimum/utils/__init__.py
@@ -63,6 +63,7 @@
     DummyVisionEncoderDecoderPastKeyValuesGenerator,
     DummyVisionInputGenerator,
     FalconDummyPastKeyValuesGenerator,
+    GemmaDummyPastKeyValuesGenerator,
     GPTBigCodeDummyPastKeyValuesGenerator,
     MistralDummyPastKeyValuesGenerator,
     MultiQueryPastKeyValuesGenerator,
diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py
index 0c82808131..7f6df3e723 100644
--- a/optimum/utils/input_generators.py
+++ b/optimum/utils/input_generators.py
@@ -1079,6 +1079,44 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
         ]
 
 
+class GemmaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+        )
+        self.num_key_value_heads = normalized_config.num_key_value_heads
+        self.head_dim = normalized_config.head_dim
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        shape = (
+            self.batch_size,
+            self.num_key_value_heads,
+            self.sequence_length,
+            self.head_dim,
+        )
+        return [
+            (
+                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+            )
+            for _ in range(self.num_layers)
+        ]
+
+
 class DummySpeechT5InputGenerator(DummyInputGenerator):
     SUPPORTED_INPUT_NAMES = ("output_sequence", "speaker_embeddings", "spectrogram")
 
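Note on why Gemma gets a dedicated past-KV generator instead of reusing `MistralDummyPastKeyValuesGenerator`: Gemma configs carry an explicit `head_dim` that need not equal `hidden_size // num_attention_heads` (for instance, gemma-7b uses `head_dim=256` with `hidden_size=3072` and 16 heads, where the quotient would be 192), so both the KV head count and the head size must be read directly from the config. A quick sketch of the shapes it produces, using the tiny test checkpoint registered in the test files below:

```python
# Shape sketch for the new generator; the checkpoint is the tiny random Gemma
# model used by the tests in this patch.
from transformers import AutoConfig

from optimum.utils import GemmaDummyPastKeyValuesGenerator
from optimum.utils.normalized_config import NormalizedConfigManager

config = AutoConfig.from_pretrained("fxmarty/tiny-random-GemmaForCausalLM")
# Resolves to NormalizedTextConfigWithGQA per the registration below.
normalized_config = NormalizedConfigManager.get_normalized_config_class("gemma")(config)

generator = GemmaDummyPastKeyValuesGenerator(
    task="text-generation-with-past",
    normalized_config=normalized_config,
    batch_size=2,
    sequence_length=16,
)
past_key_values = generator.generate("past_key_values", framework="pt")
# One (key, value) pair per layer, each of shape
# (batch_size, num_key_value_heads, sequence_length, head_dim).
print(len(past_key_values), past_key_values[0][0].shape)
```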
diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py
index 153a2e3254..f77978985d 100644
--- a/optimum/utils/normalized_config.py
+++ b/optimum/utils/normalized_config.py
@@ -228,18 +228,20 @@ class NormalizedConfigManager:
         "donut-swin": NormalizedVisionConfig,
         "electra": NormalizedTextConfig,
         "encoder-decoder": NormalizedEncoderDecoderConfig,
+        "gemma": NormalizedTextConfigWithGQA,
         "gpt2": GPT2LikeNormalizedTextConfig,
         "gpt-bigcode": GPTBigCodeNormalizedTextConfig,
         "gpt-neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
         "gpt-neox": NormalizedTextConfig,
-        "llama": NormalizedTextConfigWithGQA,
         "gptj": GPT2LikeNormalizedTextConfig,
         "imagegpt": GPT2LikeNormalizedTextConfig,
+        "llama": NormalizedTextConfigWithGQA,
         "longt5": T5LikeNormalizedTextConfig,
         "marian": BartLikeNormalizedTextConfig,
         "mbart": BartLikeNormalizedTextConfig,
         "mistral": NormalizedTextConfigWithGQA,
         "mixtral": NormalizedTextConfigWithGQA,
+        "mpt": MPTNormalizedTextConfig,
         "mt5": T5LikeNormalizedTextConfig,
         "m2m-100": BartLikeNormalizedTextConfig,
         "nystromformer": NormalizedTextConfig,
@@ -255,12 +257,11 @@ class NormalizedConfigManager:
         "splinter": NormalizedTextConfig,
         "t5": T5LikeNormalizedTextConfig,
         "trocr": TrOCRLikeNormalizedTextConfig,
-        "whisper": WhisperLikeNormalizedTextConfig,
         "vision-encoder-decoder": NormalizedEncoderDecoderConfig,
         "vit": NormalizedVisionConfig,
+        "whisper": WhisperLikeNormalizedTextConfig,
         "xlm-roberta": NormalizedTextConfig,
         "yolos": NormalizedVisionConfig,
-        "mpt": MPTNormalizedTextConfig,
     }
 
     @classmethod
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index 95b069cecb..948bb7001e 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -92,6 +92,7 @@
         "fxmarty/tiny-testing-falcon-alibi": ["text-generation", "text-generation-with-past"],
     },
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
+    "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
     "glpn": "hf-internal-testing/tiny-random-GLPNModel",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
     "gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
@@ -217,6 +218,7 @@
     "electra": "google/electra-base-generator",
     "encoder-decoder": "patrickvonplaten/bert2bert_cnn_daily_mail",
     "flaubert": "hf-internal-testing/tiny-random-flaubert",  # TODO
+    "gemma": "google/gemma-2b",
     "gpt2": "gpt2",
     "gpt-neo": "EleutherAI/gpt-neo-125M",
     "gpt-neox": "EleutherAI/gpt-neox-20b",
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index 78a2c4b08a..0e65dd8b21 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -2236,6 +2236,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
         "bloom",
         "codegen",
         "falcon",
+        "gemma",
         "gpt2",
         "gpt_bigcode",
         "gpt_neo",
@@ -2300,7 +2301,9 @@ def test_merge_from_onnx_and_save(self, model_arch):
         model_id = MODEL_NAMES[model_arch]
         task = "text-generation-with-past"
 
-        if task not in TasksManager.get_supported_tasks_for_model_type(model_arch.replace("_", "-"), exporter="onnx"):
+        if task not in TasksManager.get_supported_tasks_for_model_type(
+            model_arch.replace("_", "-"), exporter="onnx", library_name="transformers"
+        ):
             self.skipTest("Unsupported export case")
 
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -3626,7 +3629,7 @@ def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str):
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_merge_from_transformers_and_save(self, model_arch):
         if "text2text-generation-with-past" not in TasksManager.get_supported_tasks_for_model_type(
-            model_arch.replace("_", "-"), exporter="onnx"
+            model_arch.replace("_", "-"), exporter="onnx", library_name="transformers"
         ):
             self.skipTest("Unsupported -with-past export case")
 
@@ -3656,7 +3659,9 @@ def test_merge_from_onnx_and_save(self, model_arch):
         task = "text2text-generation-with-past"
 
-        if task not in TasksManager.get_supported_tasks_for_model_type(model_arch.replace("_", "-"), exporter="onnx"):
+        if task not in TasksManager.get_supported_tasks_for_model_type(
+            model_arch.replace("_", "-"), exporter="onnx", library_name="transformers"
+        ):
             self.skipTest("Unsupported export case")
 
         model_ids = self._get_model_ids(model_arch)
         for model_id in model_ids:
@@ -4192,7 +4195,7 @@ def _generate_random_audio_data(self):
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_merge_from_transformers_and_save(self, model_arch):
         if "automatic-speech-recognition-with-past" not in TasksManager.get_supported_tasks_for_model_type(
-            model_arch.replace("_", "-"), exporter="onnx"
+            model_arch.replace("_", "-"), exporter="onnx", library_name="transformers"
         ):
             self.skipTest("Unsupported -with-past export case")
 
@@ -4214,7 +4217,9 @@ def test_merge_from_onnx_and_save(self, model_arch):
         task = "automatic-speech-recognition-with-past"
 
-        if task not in TasksManager.get_supported_tasks_for_model_type(model_arch.replace("_", "-"), exporter="onnx"):
+        if task not in TasksManager.get_supported_tasks_for_model_type(
+            model_arch.replace("_", "-"), exporter="onnx", library_name="transformers"
+        ):
             self.skipTest("Unsupported export case")
 
         with tempfile.TemporaryDirectory() as tmpdir:
             main_export(model_id, tmpdir, task=task)
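For context on the `library_name` argument these tests now pass through: `get_supported_tasks_for_model_type` resolves the task set per library, and the skip checks boil down to a membership test like the following sketch:

```python
# Sketch of the support check the tests above rely on.
from optimum.exporters.tasks import TasksManager

supported_tasks = TasksManager.get_supported_tasks_for_model_type(
    "gemma", exporter="onnx", library_name="transformers"
)
# With the registration in optimum/exporters/tasks.py this includes
# "text-generation-with-past", so the merged-decoder tests run for Gemma.
assert "text-generation-with-past" in supported_tasks
```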
diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py
index 6a17092756..d444dde6ae 100644
--- a/tests/onnxruntime/utils_onnxruntime_tests.py
+++ b/tests/onnxruntime/utils_onnxruntime_tests.py
@@ -97,6 +97,7 @@
     },
     "falcon": "fxmarty/really-tiny-falcon-testing",
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
+    "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
     "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
     "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
@@ -176,22 +177,24 @@ def _setup(self, model_args: Dict):
         model_arch = model_args["model_arch"]
         model_arch_and_params = model_args["test_name"]
 
+        model_ids = MODEL_NAMES[model_arch]
+        if isinstance(model_ids, dict):
+            model_ids = list(model_ids.keys())
+        else:
+            model_ids = [model_ids]
+
         # TODO: this should actually be checked in ORTModel!
         task = self.TASK
         if "use_cache" in model_args and model_args["use_cache"] is True:
             task = task + "-with-past"
 
+        library_name = TasksManager.infer_library_from_model(model_ids[0])
+
         if "use_cache" in model_args and task not in TasksManager.get_supported_tasks_for_model_type(
-            model_arch.replace("_", "-"), exporter="onnx"
+            model_arch.replace("_", "-"), exporter="onnx", library_name=library_name
         ):
             self.skipTest("Unsupported export case")
 
-        model_ids = MODEL_NAMES[model_arch]
-        if isinstance(model_ids, dict):
-            model_ids = list(model_ids.keys())
-        else:
-            model_ids = [model_ids]
-
         if model_arch_and_params not in self.onnx_model_dirs:
             self.onnx_model_dirs[model_arch_and_params] = {}
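Not part of the patch, but a quick end-to-end check of what it enables, mirroring what the `ORTModelForCausalLM` integration tests exercise (the checkpoint is the tiny random Gemma model registered in `MODEL_NAMES` above):

```python
# End-to-end sketch: export Gemma on the fly and generate with ONNX Runtime.
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM

model_id = "fxmarty/tiny-random-GemmaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True runs the ONNX export path added in this patch; use_cache=True
# exercises prepare_past_key_values with Gemma's (num_kv_heads, head_dim) shapes.
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```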