diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index 2d28abd6a3..18ea9ef772 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -42,6 +42,7 @@ Supported architectures: - Electra - Flaubert - GPT-2 +- GPT-BigCode - GPT-J - GPT-Neo - GPT-NeoX diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index f5fd3dcafc..cb6abf1d28 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -214,6 +214,11 @@ class LlamaOnnxConfig(TextDecoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig +class GPTBigCodeOnnxConfig(TextDecoderOnnxConfig): + DEFAULT_ONNX_OPSET = 13 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): def generate(self, input_name: str, framework: str = "pt"): past_key_shape = ( diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index d6e6676de0..b1239c7d9f 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -492,6 +492,15 @@ class TasksManager: "token-classification", onnx="GPT2OnnxConfig", ), + "gpt-bigcode": supported_tasks_mapping( + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text-classification", + "token-classification", + onnx="GPTBigCodeOnnxConfig", + ), "gptj": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 8635ee8ee7..a584ddf9ea 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -109,6 +109,7 @@ class ORTConfigManager: "distilbert": "bert", "electra": "bert", "gpt2": "gpt2", + "gpt-bigcode": "gpt2", "gpt_neo": "gpt2", "gpt_neox": "gpt2", "gptj": "gpt2", diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 24c51d3335..3ed834d4bc 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -200,6 +200,7 @@ class NormalizedConfigManager: "donut-swin": NormalizedVisionConfig, "electra": NormalizedTextConfig, "gpt2": GPT2LikeNormalizedTextConfig, + "gpt-bigcode": GPT2LikeNormalizedTextConfig, "gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"), "gpt_neox": NormalizedTextConfig, "llama": NormalizedTextConfig,