Gemma ONNX export & ORT support (#1714)
* gemma onnx export

* fix tests

* fix model

* fix
fxmarty authored Feb 26, 2024
1 parent 4a5d97b commit e0cbf7d
Showing 11 changed files with 85 additions and 17 deletions.
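The end-to-end flow this commit enables is an ONNX export of Gemma through the newly registered config. As a quick orientation before the per-file diff, here is a minimal sketch using main_export, the same entry point the updated tests call below; the checkpoint id is the tiny test model registered in this PR's test files:

from optimum.exporters.onnx import main_export

# Minimal sketch: export Gemma to ONNX via the new GemmaOnnxConfig.
# task="text-generation-with-past" also wires up the KV-cache inputs/outputs.
main_export(
    "fxmarty/tiny-random-GemmaForCausalLM",  # tiny test checkpoint from this PR
    output="gemma_onnx",
    task="text-generation-with-past",
)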
@@ -51,6 +51,7 @@
 from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments
 
 
+
 """ Fine-tuning a 🤗 Transformers model for image classification"""
 
 logger = logging.getLogger(__name__)
optimum/exporters/onnx/model_configs.py (7 additions, 0 deletions)
@@ -37,6 +37,7 @@
     DummyVisionEncoderDecoderPastKeyValuesGenerator,
     DummyVisionInputGenerator,
     FalconDummyPastKeyValuesGenerator,
+    GemmaDummyPastKeyValuesGenerator,
     GPTBigCodeDummyPastKeyValuesGenerator,
     MistralDummyPastKeyValuesGenerator,
     NormalizedConfig,
@@ -240,6 +241,12 @@ class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
+class GemmaOnnxConfig(LlamaOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
+    pass
+
+
 class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
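GemmaOnnxConfig subclasses LlamaOnnxConfig on purpose: Gemma reuses Llama's decoder export behavior (including position_ids support) and only swaps in its own dummy past-key-values generator, needed because of the head_dim handling shown in modeling_decoder.py below. A sketch of verifying the registration; get_exporter_config_constructor and its keyword names are assumed from the optimum API of this period rather than taken from this diff:

from optimum.exporters.tasks import TasksManager

# Resolve the ONNX export config registered for Gemma (a sketch).
constructor = TasksManager.get_exporter_config_constructor(
    exporter="onnx",
    model_type="gemma",
    task="text-generation-with-past",
    library_name="transformers",
)
print(constructor)  # expected to wrap GemmaOnnxConfig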
optimum/exporters/onnx/utils.py (1 addition, 0 deletions)
@@ -69,6 +69,7 @@
 MODEL_TYPES_REQUIRING_POSITION_IDS = {
     "codegen",
     "falcon",
+    "gemma",
     "gpt2",
     "gpt-bigcode",
     "gpt-neo",
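Listing gemma in MODEL_TYPES_REQUIRING_POSITION_IDS makes position_ids an explicit input of the exported decoder, so code driving the ONNX session directly must provide it. A hedged sketch of the standard derivation from the attention mask for decoder-only models (not code from this PR):

import torch

# position_ids derived from the attention mask, robust to left padding.
attention_mask = torch.tensor([[0, 1, 1, 1],
                               [1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)  # tensor([[1, 0, 1, 2], [0, 1, 2, 3]])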
optimum/exporters/tasks.py (8 additions, 0 deletions)
@@ -592,6 +592,14 @@ class TasksManager:
             onnx="FlaubertOnnxConfig",
             tflite="FlaubertTFLiteConfig",
         ),
+        "gemma": supported_tasks_mapping(
+            "feature-extraction",
+            "feature-extraction-with-past",
+            "text-generation",
+            "text-generation-with-past",
+            "text-classification",
+            onnx="GemmaOnnxConfig",
+        ),
         "glpn": supported_tasks_mapping(
             "feature-extraction",
             "depth-estimation",
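With this mapping in place, the five Gemma tasks become discoverable through TasksManager. A quick sketch; the library_name argument is the same one the updated tests below start passing:

from optimum.exporters.tasks import TasksManager

tasks = TasksManager.get_supported_tasks_for_model_type(
    "gemma", exporter="onnx", library_name="transformers"
)
print(sorted(tasks))
# ['feature-extraction', 'feature-extraction-with-past',
#  'text-classification', 'text-generation', 'text-generation-with-past']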
optimum/onnxruntime/modeling_decoder.py (5 additions, 2 deletions)
@@ -334,11 +334,14 @@ def prepare_past_key_values(
         # Generate dummy past for the first forward if uses a merged decoder
         if past_key_values is None:
             batch_size = input_ids.shape[0]
-            if self.model_type in {"mistral", "llama"}:
+            embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
+            if self.model_type == "gemma":
+                num_attention_heads = self.normalized_config.num_key_value_heads
+                embed_size_per_head = self.normalized_config.head_dim
+            elif self.model_type in {"gemma", "mistral", "llama"}:
                 num_attention_heads = self.normalized_config.num_key_value_heads
             else:
                 num_attention_heads = self.normalized_config.num_attention_heads
-            embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
 
             dtype = constructor.float16 if self.use_fp16 else constructor.float32
 
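The dedicated gemma branch exists because Gemma defines head_dim as an independent config field, not as hidden_size // num_attention_heads, so the generic formula would size the KV cache wrongly. A sketch with config values in the style of google/gemma-7b (hidden_size=3072, 16 heads, head_dim=256, quoted from memory, so treat the numbers as illustrative):

# Illustrative shape arithmetic for the dummy KV cache built above.
hidden_size, num_attention_heads = 3072, 16
num_key_value_heads, head_dim = 16, 256

generic = hidden_size // num_attention_heads  # 192: wrong for Gemma
assert generic != head_dim                    # 256 is what the ONNX graph expects

batch_size = 2
past_sequence_length = 0  # empty dummy past for a merged decoder's first forward
shape = (batch_size, num_key_value_heads, past_sequence_length, head_dim)
print(shape)  # (2, 16, 0, 256)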
optimum/utils/__init__.py (1 addition, 0 deletions)
@@ -63,6 +63,7 @@
     DummyVisionEncoderDecoderPastKeyValuesGenerator,
     DummyVisionInputGenerator,
     FalconDummyPastKeyValuesGenerator,
+    GemmaDummyPastKeyValuesGenerator,
     GPTBigCodeDummyPastKeyValuesGenerator,
     MistralDummyPastKeyValuesGenerator,
     MultiQueryPastKeyValuesGenerator,
optimum/utils/input_generators.py (38 additions, 0 deletions)
@@ -1079,6 +1079,44 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
         ]
 
 
+class GemmaDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        random_batch_size_range: Optional[Tuple[int, int]] = None,
+        random_sequence_length_range: Optional[Tuple[int, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            task=task,
+            normalized_config=normalized_config,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            random_batch_size_range=random_batch_size_range,
+            random_sequence_length_range=random_sequence_length_range,
+        )
+        self.num_key_value_heads = normalized_config.num_key_value_heads
+        self.head_dim = normalized_config.head_dim
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        shape = (
+            self.batch_size,
+            self.num_key_value_heads,
+            self.sequence_length,
+            self.head_dim,
+        )
+        return [
+            (
+                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
+            )
+            for _ in range(self.num_layers)
+        ]
+
+
 class DummySpeechT5InputGenerator(DummyInputGenerator):
     SUPPORTED_INPUT_NAMES = ("output_sequence", "speaker_embeddings", "spectrogram")
 
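The new generator mirrors DummyPastKeyValuesGenerator but reads num_key_value_heads and head_dim directly from the config instead of deriving them. A usage sketch, assuming the tiny checkpoint registered in the tests and the NormalizedTextConfigWithGQA wiring added in normalized_config.py below:

from transformers import AutoConfig

from optimum.utils.input_generators import GemmaDummyPastKeyValuesGenerator
from optimum.utils.normalized_config import NormalizedTextConfigWithGQA

config = AutoConfig.from_pretrained("fxmarty/tiny-random-GemmaForCausalLM")
generator = GemmaDummyPastKeyValuesGenerator(
    task="text-generation-with-past",
    normalized_config=NormalizedTextConfigWithGQA(config),
    batch_size=2,
    sequence_length=8,
)
past = generator.generate("past_key_values", framework="pt")
print(len(past), past[0][0].shape)
# num_layers  torch.Size([2, num_key_value_heads, 8, head_dim])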
optimum/utils/normalized_config.py (4 additions, 3 deletions)
@@ -228,18 +228,20 @@ class NormalizedConfigManager:
         "donut-swin": NormalizedVisionConfig,
         "electra": NormalizedTextConfig,
         "encoder-decoder": NormalizedEncoderDecoderConfig,
+        "gemma": NormalizedTextConfigWithGQA,
         "gpt2": GPT2LikeNormalizedTextConfig,
         "gpt-bigcode": GPTBigCodeNormalizedTextConfig,
         "gpt-neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
         "gpt-neox": NormalizedTextConfig,
-        "llama": NormalizedTextConfigWithGQA,
         "gptj": GPT2LikeNormalizedTextConfig,
         "imagegpt": GPT2LikeNormalizedTextConfig,
+        "llama": NormalizedTextConfigWithGQA,
         "longt5": T5LikeNormalizedTextConfig,
         "marian": BartLikeNormalizedTextConfig,
         "mbart": BartLikeNormalizedTextConfig,
         "mistral": NormalizedTextConfigWithGQA,
         "mixtral": NormalizedTextConfigWithGQA,
+        "mpt": MPTNormalizedTextConfig,
         "mt5": T5LikeNormalizedTextConfig,
         "m2m-100": BartLikeNormalizedTextConfig,
         "nystromformer": NormalizedTextConfig,
@@ -255,12 +257,11 @@ class NormalizedConfigManager:
"splinter": NormalizedTextConfig,
"t5": T5LikeNormalizedTextConfig,
"trocr": TrOCRLikeNormalizedTextConfig,
"whisper": WhisperLikeNormalizedTextConfig,
"vision-encoder-decoder": NormalizedEncoderDecoderConfig,
"vit": NormalizedVisionConfig,
"whisper": WhisperLikeNormalizedTextConfig,
"xlm-roberta": NormalizedTextConfig,
"yolos": NormalizedVisionConfig,
"mpt": MPTNormalizedTextConfig,
}

@classmethod
Expand Down
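NormalizedTextConfigWithGQA exposes num_key_value_heads, which is exactly what prepare_past_key_values relies on in modeling_decoder.py above. A sketch of the lookup path:

from transformers import AutoConfig

from optimum.utils.normalized_config import NormalizedConfigManager

cls = NormalizedConfigManager.get_normalized_config_class("gemma")
normalized = cls(AutoConfig.from_pretrained("fxmarty/tiny-random-GemmaForCausalLM"))
print(normalized.num_key_value_heads, normalized.head_dim)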
tests/exporters/exporters_utils.py (2 additions, 0 deletions)
@@ -92,6 +92,7 @@
         "fxmarty/tiny-testing-falcon-alibi": ["text-generation", "text-generation-with-past"],
     },
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
+    "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
     "glpn": "hf-internal-testing/tiny-random-GLPNModel",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
     "gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
@@ -217,6 +218,7 @@
"electra": "google/electra-base-generator",
"encoder-decoder": "patrickvonplaten/bert2bert_cnn_daily_mail",
"flaubert": "hf-internal-testing/tiny-random-flaubert", # TODO
"gemma": "google/gemma-2b",
"gpt2": "gpt2",
"gpt-neo": "EleutherAI/gpt-neo-125M",
"gpt-neox": "EleutherAI/gpt-neox-20b",
Expand Down
tests/onnxruntime/test_modeling.py (8 additions, 5 deletions)
@@ -2236,6 +2236,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
         "bloom",
         "codegen",
         "falcon",
+        "gemma",
         "gpt2",
         "gpt_bigcode",
         "gpt_neo",
@@ -2300,7 +2301,9 @@ def test_merge_from_onnx_and_save(self, model_arch):
         model_id = MODEL_NAMES[model_arch]
         task = "text-generation-with-past"
 
-        if task not in TasksManager.get_supported_tasks_for_model_type(model_arch.replace("_", "-"), exporter="onnx"):
+        if task not in TasksManager.get_supported_tasks_for_model_type(
+            model_arch.replace("_", "-"), exporter="onnx", library_name="transformers"
+        ):
             self.skipTest("Unsupported export case")
 
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -3626,7 +3629,7 @@ def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str):
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_merge_from_transformers_and_save(self, model_arch):
         if "text2text-generation-with-past" not in TasksManager.get_supported_tasks_for_model_type(
-            model_arch.replace("_", "-"), exporter="onnx"
+            model_arch.replace("_", "-"), exporter="onnx", library_name="transformers"
         ):
             self.skipTest("Unsupported -with-past export case")
 
@@ -3656,7 +3659,7 @@ def test_merge_from_onnx_and_save(self, model_arch):
task = "text2text-generation-with-past"

if task not in TasksManager.get_supported_tasks_for_model_type(model_arch.replace("_", "-"), exporter="onnx"):
self.skipTest("Unsupported export case")
self.skipTest("Unsupported export case", library_name="transformers")

model_ids = self._get_model_ids(model_arch)
for model_id in model_ids:
Expand Down Expand Up @@ -4192,7 +4195,7 @@ def _generate_random_audio_data(self):
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_merge_from_transformers_and_save(self, model_arch):
         if "automatic-speech-recognition-with-past" not in TasksManager.get_supported_tasks_for_model_type(
-            model_arch.replace("_", "-"), exporter="onnx"
+            model_arch.replace("_", "-"), exporter="onnx", library_name="transformers"
        ):
             self.skipTest("Unsupported -with-past export case")
 
@@ -4214,7 +4217,7 @@ def test_merge_from_onnx_and_save(self, model_arch):
task = "automatic-speech-recognition-with-past"

if task not in TasksManager.get_supported_tasks_for_model_type(model_arch.replace("_", "-"), exporter="onnx"):
self.skipTest("Unsupported export case")
self.skipTest("Unsupported export case", library_name="transformers")

with tempfile.TemporaryDirectory() as tmpdir:
main_export(model_id, tmpdir, task=task)
Expand Down
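The user-facing surface these tests exercise is ORTModelForCausalLM, where export=True triggers the ONNX export on the fly. A sketch on the tiny checkpoint:

from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM

model_id = "fxmarty/tiny-random-GemmaForCausalLM"
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))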
tests/onnxruntime/utils_onnxruntime_tests.py (10 additions, 7 deletions)
@@ -97,6 +97,7 @@
     },
     "falcon": "fxmarty/really-tiny-falcon-testing",
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
+    "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
     "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
     "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
@@ -176,22 +177,24 @@ def _setup(self, model_args: Dict):
         model_arch = model_args["model_arch"]
         model_arch_and_params = model_args["test_name"]
 
+        model_ids = MODEL_NAMES[model_arch]
+        if isinstance(model_ids, dict):
+            model_ids = list(model_ids.keys())
+        else:
+            model_ids = [model_ids]
+
         # TODO: this should actually be checked in ORTModel!
         task = self.TASK
         if "use_cache" in model_args and model_args["use_cache"] is True:
             task = task + "-with-past"
 
+        library_name = TasksManager.infer_library_from_model(model_ids[0])
+
         if "use_cache" in model_args and task not in TasksManager.get_supported_tasks_for_model_type(
-            model_arch.replace("_", "-"), exporter="onnx"
+            model_arch.replace("_", "-"), exporter="onnx", library_name=library_name
         ):
             self.skipTest("Unsupported export case")
 
-        model_ids = MODEL_NAMES[model_arch]
-        if isinstance(model_ids, dict):
-            model_ids = list(model_ids.keys())
-        else:
-            model_ids = [model_ids]
-
         if model_arch_and_params not in self.onnx_model_dirs:
             self.onnx_model_dirs[model_arch_and_params] = {}
 
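The reordering in this hunk matters because the library is inferred from a concrete checkpoint id, so model_ids has to be resolved before the skip check runs. A sketch of the call it enables:

from optimum.exporters.tasks import TasksManager

library_name = TasksManager.infer_library_from_model("fxmarty/tiny-random-GemmaForCausalLM")
print(library_name)  # "transformers"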
