Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Onnx granite #2043

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,11 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
pass


class GraniteOnnxConfig(LlamaOnnxConfig):
MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")
MIN_TORCH_VERSION = version.parse("2.5.0")


class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # Phi now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
Expand Down
1 change: 1 addition & 0 deletions optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"phi",
"phi3",
"qwen2",
"granite",
}


Expand Down
7 changes: 7 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,13 @@ class TasksManager:
"text-classification",
onnx="LlamaOnnxConfig",
),
"granite": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
onnx="GraniteOnnxConfig",
),
"pegasus": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def prepare_past_key_values(
if self.model_type == "gemma":
num_attention_heads = self.normalized_config.num_key_value_heads
embed_size_per_head = self.normalized_config.head_dim
elif self.model_type in {"mistral", "llama", "qwen2"}:
elif self.model_type in {"mistral", "llama", "qwen2", "granite"}:
num_attention_heads = self.normalized_config.num_key_value_heads
else:
num_attention_heads = self.normalized_config.num_attention_heads
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class ORTConfigManager:
"gpt-neo": "gpt2",
"gpt-neox": "gpt2",
"gptj": "gpt2",
"granite": "gpt2",
# longt5 with O4 results in segmentation fault
"longt5": "bert",
"llama": "gpt2",
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ class NormalizedConfigManager:
"xlm-roberta": NormalizedTextConfig,
"yolos": NormalizedVisionConfig,
"qwen2": NormalizedTextConfig,
"granite": NormalizedTextConfigWithGQA,
}

@classmethod
Expand Down
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
"gpt-neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt-neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
"granite": "hf-internal-testing/tiny-random-GraniteForCausalLM",
"groupvit": "hf-internal-testing/tiny-random-groupvit",
"ibert": "hf-internal-testing/tiny-random-IBertModel",
"imagegpt": "hf-internal-testing/tiny-random-ImageGPTModel",
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2311,6 +2311,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
"gpt_neo",
"gpt_neox",
"gptj",
"granite",
"llama",
"mistral",
"mpt",
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJForCausalLM",
"granite": "hf-internal-testing/tiny-random-GraniteForCausalLM",
"groupvit": "hf-internal-testing/tiny-random-groupvit",
"hubert": "hf-internal-testing/tiny-random-HubertModel",
"ibert": "hf-internal-testing/tiny-random-IBertModel",
Expand Down