diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx
index 747e1396fb..8207fda8e5 100644
--- a/docs/source/exporters/onnx/overview.mdx
+++ b/docs/source/exporters/onnx/overview.mdx
@@ -28,6 +28,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
 - Camembert
 - CLIP
 - CodeGen
+- Cohere
 - ConvBert
 - ConvNext
 - ConvNextV2
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index e23716d4b7..787e918f26 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -281,6 +281,10 @@ class Qwen2OnnxConfig(LlamaOnnxConfig):
     pass
 
 
+class CohereOnnxConfig(LlamaOnnxConfig):
+    pass
+
+
 class GemmaOnnxConfig(LlamaOnnxConfig):
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
     DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py
index 8ecba9231f..4cfee11c3f 100644
--- a/optimum/exporters/onnx/utils.py
+++ b/optimum/exporters/onnx/utils.py
@@ -73,6 +73,7 @@
 
 MODEL_TYPES_REQUIRING_POSITION_IDS = {
     "codegen",
+    "cohere",
     "falcon",
     "gemma",
     "gpt2",
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 608b3df0d7..49d1abef4f 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -453,6 +453,14 @@ class TasksManager:
             "text-generation-with-past",
             onnx="CodeGenOnnxConfig",
         ),
+        "cohere": supported_tasks_mapping(
+            "feature-extraction",
+            "feature-extraction-with-past",
+            "text-generation",
+            "text-generation-with-past",
+            "text-classification",
+            onnx="CohereOnnxConfig",
+        ),
         "convbert": supported_tasks_mapping(
             "feature-extraction",
             "fill-mask",
diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index 2d9be2d757..882954e43d 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -339,7 +339,7 @@ def prepare_past_key_values(
         if self.model_type == "gemma":
             num_attention_heads = self.normalized_config.num_key_value_heads
             embed_size_per_head = self.normalized_config.head_dim
-        elif self.model_type in {"mistral", "llama", "qwen2"}:
+        elif self.model_type in {"mistral", "llama", "cohere", "qwen2"}:
             num_attention_heads = self.normalized_config.num_key_value_heads
         else:
             num_attention_heads = self.normalized_config.num_attention_heads
diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py
index 37d0feefcc..047fb4ceed 100644
--- a/optimum/onnxruntime/utils.py
+++ b/optimum/onnxruntime/utils.py
@@ -107,6 +107,7 @@ class ORTConfigManager:
         "bloom": "gpt2",
         "camembert": "bert",
         "codegen": "gpt2",
+        "cohere": "gpt2",
         "deberta": "bert",
         "deberta-v2": "bert",
         "distilbert": "bert",
diff --git a/optimum/utils/modeling_utils.py b/optimum/utils/modeling_utils.py
index dae5b5d633..a2eb802a20 100644
--- a/optimum/utils/modeling_utils.py
+++ b/optimum/utils/modeling_utils.py
@@ -20,6 +20,7 @@
     "blenderbot",
     "blenderbot-small",
     "bloom",
+    "cohere",
     "llama",
     "mistral",
     "mpt",
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index 0c52754ff6..4967c2e94d 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -56,7 +56,8 @@
     "bloom": "hf-internal-testing/tiny-random-BloomModel",
     "camembert": "hf-internal-testing/tiny-random-camembert",
     "clip": "hf-internal-testing/tiny-random-CLIPModel",
+    "cohere": "hf-internal-testing/tiny-random-CohereModel",
     "convbert": "hf-internal-testing/tiny-random-ConvBertModel",
     "convnext": "hf-internal-testing/tiny-random-convnext",
     "convnextv2": "hf-internal-testing/tiny-random-ConvNextV2Model",
     "codegen": "hf-internal-testing/tiny-random-CodeGenModel",
@@ -210,6 +210,7 @@
     "bloom": "hf-internal-testing/tiny-random-BloomModel",  # Not using bigscience/bloom-560m because it goes OOM.
     "camembert": "camembert-base",
     "clip": "openai/clip-vit-base-patch32",
+    "cohere": "hf-internal-testing/tiny-random-CohereModel",  # Not using CohereForAI/c4ai-command-r-plus because it is gated and/or goes OOM.
     "convbert": "YituTech/conv-bert-base",
     "convnext": "facebook/convnext-tiny-224",
     "codegen": "hf-internal-testing/tiny-random-CodeGenModel",  # Not using Salesforce/codegen-350M-multi because it takes too much time for testing.
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index 3fe2c5e14d..6fd2d3cc4f 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -2248,6 +2248,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
     SUPPORTED_ARCHITECTURES = [
         "bloom",
         "codegen",
+        "cohere",
         "falcon",
         "gemma",
         "gpt2",
diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py
index 6529826578..3848dbb530 100644
--- a/tests/onnxruntime/utils_onnxruntime_tests.py
+++ b/tests/onnxruntime/utils_onnxruntime_tests.py
@@ -38,6 +38,7 @@
     "bloom": "hf-internal-testing/tiny-random-BloomModel",
     "camembert": "hf-internal-testing/tiny-random-camembert",
     "clip": "hf-internal-testing/tiny-random-CLIPModel",
+    "cohere": "hf-internal-testing/tiny-random-CohereModel",
     "convbert": "hf-internal-testing/tiny-random-ConvBertModel",
     "convnext": "hf-internal-testing/tiny-random-convnext",
     "convnextv2": "hf-internal-testing/tiny-random-ConvNextV2Model",