Skip to content

Commit

Permalink
Add cohere ONNX export support
Browse files · Browse the repository at this point in the history
  • Loading branch information
xenova committed Jun 12, 2024
1 parent db51410 commit a98a15d
Show file tree
Hide file tree
Showing 10 changed files with 21 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Camembert
- CLIP
- CodeGen
- Cohere
- ConvBert
- ConvNext
- ConvNextV2
Expand Down
4 changes: 4 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ class Qwen2OnnxConfig(LlamaOnnxConfig):
pass


class CohereOnnxConfig(LlamaOnnxConfig):
    """ONNX export configuration for Cohere models.

    Cohere's decoder architecture is export-compatible with Llama, so this
    class inherits everything from ``LlamaOnnxConfig`` unchanged (the same
    pattern used by ``Qwen2OnnxConfig`` just above).
    """

    pass


class GemmaOnnxConfig(LlamaOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
Expand Down
1 change: 1 addition & 0 deletions optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@

MODEL_TYPES_REQUIRING_POSITION_IDS = {
"codegen",
"cohere",
"falcon",
"gemma",
"gpt2",
Expand Down
8 changes: 8 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,14 @@ class TasksManager:
"text-generation-with-past",
onnx="CodeGenOnnxConfig",
),
"cohere": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
onnx="CohereOnnxConfig",
),
"convbert": supported_tasks_mapping(
"feature-extraction",
"fill-mask",
Expand Down
2 changes: 1 addition & 1 deletion optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def prepare_past_key_values(
if self.model_type == "gemma":
num_attention_heads = self.normalized_config.num_key_value_heads
embed_size_per_head = self.normalized_config.head_dim
elif self.model_type in {"mistral", "llama", "qwen2"}:
elif self.model_type in {"mistral", "llama", "cohere", "qwen2"}:
num_attention_heads = self.normalized_config.num_key_value_heads
else:
num_attention_heads = self.normalized_config.num_attention_heads
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class ORTConfigManager:
"bloom": "gpt2",
"camembert": "bert",
"codegen": "gpt2",
"cohere": "gpt2",
"deberta": "bert",
"deberta-v2": "bert",
"distilbert": "bert",
Expand Down
1 change: 1 addition & 0 deletions optimum/utils/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"blenderbot",
"blenderbot-small",
"bloom",
"cohere",
"llama",
"mistral",
"mpt",
Expand Down
3 changes: 2 additions & 1 deletion tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
"bloom": "hf-internal-testing/tiny-random-BloomModel",
"camembert": "hf-internal-testing/tiny-random-camembert",
"clip": "hf-internal-testing/tiny-random-CLIPModel",
"convbert": "hf-internal-testing/tiny-random-ConvBertModel",
"cohere": "hf-internal-testing/tiny-random-CohereModel",
"convnext": "hf-internal-testing/tiny-random-convnext",
"convnextv2": "hf-internal-testing/tiny-random-ConvNextV2Model",
"codegen": "hf-internal-testing/tiny-random-CodeGenModel",
Expand Down
Expand Up @@ -210,6 +210,7 @@
"bloom": "hf-internal-testing/tiny-random-BloomModel", # Not using bigscience/bloom-560m because it goes OOM.
"camembert": "camembert-base",
"clip": "openai/clip-vit-base-patch32",
"cohere": "hf-internal-testing/tiny-random-CohereModel", # Not using CohereForAI/c4ai-command-r-plus because it is gated and/or goes OOM.
"convbert": "YituTech/conv-bert-base",
"convnext": "facebook/convnext-tiny-224",
"codegen": "hf-internal-testing/tiny-random-CodeGenModel", # Not using Salesforce/codegen-350M-multi because it takes too much time for testing.
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2248,6 +2248,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES = [
"bloom",
"codegen",
"cohere",
"falcon",
"gemma",
"gpt2",
Expand Down
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
"bloom": "hf-internal-testing/tiny-random-BloomModel",
"camembert": "hf-internal-testing/tiny-random-camembert",
"clip": "hf-internal-testing/tiny-random-CLIPModel",
"cohere": "hf-internal-testing/tiny-random-CohereModel",
"convbert": "hf-internal-testing/tiny-random-ConvBertModel",
"convnext": "hf-internal-testing/tiny-random-convnext",
"convnextv2": "hf-internal-testing/tiny-random-ConvNextV2Model",
Expand Down

0 comments on commit a98a15d

Please sign in to comment.