Skip to content

Commit

Permalink
add gpt-bigcode
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed May 8, 2023
1 parent d29e582 commit 862c81d
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Supported architectures:
- Electra
- Flaubert
- GPT-2
- GPT-BigCode
- GPT-J
- GPT-Neo
- GPT-NeoX
Expand Down
5 changes: 5 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,11 @@ class LlamaOnnxConfig(TextDecoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class GPTBigCodeOnnxConfig(TextDecoderOnnxConfig):
DEFAULT_ONNX_OPSET = 13
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt"):
past_key_shape = (
Expand Down
9 changes: 9 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,15 @@ class TasksManager:
"token-classification",
onnx="GPT2OnnxConfig",
),
"gpt-bigcode": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
"token-classification",
onnx="GPTBigCodeOnnxConfig",
),
"gptj": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class ORTConfigManager:
"distilbert": "bert",
"electra": "bert",
"gpt2": "gpt2",
"gpt-bigcode": "gpt2",
"gpt_neo": "gpt2",
"gpt_neox": "gpt2",
"gptj": "gpt2",
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ class NormalizedConfigManager:
"donut-swin": NormalizedVisionConfig,
"electra": NormalizedTextConfig,
"gpt2": GPT2LikeNormalizedTextConfig,
"gpt-bigcode": GPT2LikeNormalizedTextConfig,
"gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt_neox": NormalizedTextConfig,
"llama": NormalizedTextConfig,
Expand Down

0 comments on commit 862c81d

Please sign in to comment.