diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx
index fe1e52a050..e70e8afa84 100644
--- a/docs/source/exporters/onnx/overview.mdx
+++ b/docs/source/exporters/onnx/overview.mdx
@@ -42,6 +42,7 @@ Supported architectures:
 - Electra
 - Flaubert
 - GPT-2
+- GPT-BigCode
 - GPT-J
 - GPT-Neo
 - GPT-NeoX
diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py
index d93c0b29f6..3c50389726 100644
--- a/optimum/exporters/onnx/base.py
+++ b/optimum/exporters/onnx/base.py
@@ -596,7 +596,9 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
             and self.use_cache_branch is not False
             and "attention_mask" in dummy_inputs
         ):
-            past_length = dummy_inputs["past_key_values"][0][0].shape[2]
+            # Obtain the past sequence length from the value instead of the key, whose layout differs for Bloom.
+            past_length = dummy_inputs["past_key_values"][0][1].shape[-2]
+
             dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
                 dummy_inputs["attention_mask"],
                 desired_length=past_length + 1,
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 1192b25bba..4b3483dbd4 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -39,6 +39,7 @@
     NormalizedVisionConfig,
     logging,
 )
+from ...utils.normalized_config import NormalizedConfigManager
 from .base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
 from .config import (
     AudioOnnxConfig,
@@ -268,6 +269,45 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
         }
 
 
+class GPTBigCodeDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def generate(self, input_name: str, framework: str = "pt"):
+        past_key_value_shape = (
+            self.batch_size,
+            self.sequence_length,
+            self.hidden_size // self.num_attention_heads * 2,
+        )
+        return [self.random_float_tensor(past_key_value_shape, framework=framework) for _ in range(self.num_layers)]
+
+
+class GPTBigCodeOnnxConfig(TextDecoderOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        GPTBigCodeDummyPastKeyValuesGenerator,
+    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
+    DUMMY_PKV_GENERATOR_CLASS = GPTBigCodeDummyPastKeyValuesGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("gpt_bigcode")
+
+    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
+        if direction not in ["inputs", "outputs"]:
+            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
+
+        if direction == "inputs":
+            decoder_sequence_name = "past_sequence_length"
+            name = "past_key_values"
+        else:
+            decoder_sequence_name = "past_sequence_length + 1"
+            name = "present"
+
+        for i in range(self._normalized_config.num_layers):
+            # No dim for `n_head` when using multi-query attention
+            inputs_or_outputs[f"{name}.{i}.key_value"] = {
+                0: "batch_size",
+                1: decoder_sequence_name,
+            }
+
+    def flatten_past_key_values(self, flattened_output, name, idx, t):
+        flattened_output[f"{name}.{idx}.key_value"] = t
+
+
 class T5DummySeq2SeqPastKeyValuesGenerator(DummySeq2SeqPastKeyValuesGenerator):
     def generate(self, input_name: str, framework: str = "pt"):
         encoder_shape = (
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 3afc31873c..54a36f06e9 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -516,6 +516,15 @@ class TasksManager:
             "token-classification",
             onnx="GPT2OnnxConfig",
         ),
+        "gpt-bigcode": supported_tasks_mapping(
+            "feature-extraction",
+            "feature-extraction-with-past",
+            "text-generation",
+            "text-generation-with-past",
+            "text-classification",
+            "token-classification",
+            onnx="GPTBigCodeOnnxConfig",
+        ),
         "gptj": supported_tasks_mapping(
             "feature-extraction",
             "feature-extraction-with-past",
diff --git a/optimum/onnxruntime/base.py b/optimum/onnxruntime/base.py
index 59a21f944d..f5e2776981 100644
--- a/optimum/onnxruntime/base.py
+++ b/optimum/onnxruntime/base.py
@@ -24,7 +24,7 @@
 
 from ..utils import NormalizedConfigManager
 from ..utils.logging import warn_once
-from .utils import get_ordered_input_names, logging
+from .utils import MULTI_QUERY_ATTN_MODELS, get_ordered_input_names, logging
 
 
 logger = logging.get_logger(__name__)
@@ -161,7 +161,15 @@ def __init__(
         self.expected_key_symbolic_shape = None
         self.expected_value_symbolic_shape = None
         for output in self.session.get_outputs():
-            if ".key" in output.name:
+            # To handle the case of multi-query attention, where key and value are concatenated
+            if ".key_value" in output.name:
+                expected_key_value_symbolic_shape = output.shape
+                self.expected_key_symbolic_shape = (
+                    self.expected_value_symbolic_shape
+                ) = expected_key_value_symbolic_shape[:-1] + [
+                    expected_key_value_symbolic_shape[-1] // 2,
+                ]
+            elif ".key" in output.name:
                 self.expected_key_symbolic_shape = output.shape
             elif ".value" in output.name:
                 self.expected_value_symbolic_shape = output.shape
@@ -227,6 +235,14 @@ def prepare_inputs_for_merged(
                 past_key_values = tuple(
                     key_or_value for _ in range(len(self.key_value_input_names) // 2) for key_or_value in [key, value]
                 )
+            elif self.parent_model.config.model_type in MULTI_QUERY_ATTN_MODELS:
+                shape_key_and_value = (batch_size, 1, embed_size_per_head * 2)
+                key_and_value = constructor.zeros(shape_key_and_value, dtype=dtype)
+
+                if use_torch is True:
+                    key_and_value = key_and_value.to(self.device)
+
+                past_key_values = tuple(key_and_value for _ in range(len(self.key_value_input_names)))
             else:
                 shape = (batch_size, num_attention_heads, 1, embed_size_per_head)
                 key_or_value = constructor.zeros(shape, dtype=dtype)
@@ -288,6 +304,24 @@ def compute_past_key_values_output_shapes(
 
         return {name: key_shape if "key" in name else value_shape for name in self.key_value_output_names}
 
+    def compute_past_key_values_output_shapes_mqa(
+        self,
+        input_ids: torch.Tensor,
+        use_cache_branch: Optional[bool],
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+    ) -> Dict[str, List[int]]:
+        batch_size = input_ids.size(0)
+        num_attention_heads = self.normalized_config.num_attention_heads
+        embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads
+
+        sequence_length = input_ids.size(1)
+        if past_key_values is not None and use_cache_branch is not False:
+            sequence_length += past_key_values[0].size(-2)
+
+        key_and_value_shape = (batch_size, sequence_length, embed_size_per_head * 2)
+
+        return {name: key_and_value_shape for name in self.key_value_output_names}
+
     def forward(
         self,
         input_ids: torch.LongTensor,
@@ -300,8 +334,8 @@ def forward(
         use_torch = isinstance(input_ids, torch.Tensor)
         self.parent_model.raise_on_numpy_input_io_binding(use_torch)
 
-        # Flatten the past_key_values
-        if past_key_values is not None:
+        # Flatten the past_key_values (no need to flatten for models using multi-query attention)
+        if past_key_values is not None and (self.parent_model.config.model_type not in MULTI_QUERY_ATTN_MODELS):
             past_key_values = tuple(
                 past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
             )
@@ -312,7 +346,11 @@ def forward(
         )
 
         if self.parent_model.use_io_binding:
-            known_output_shapes = self.compute_past_key_values_output_shapes(
+            if self.parent_model.config.model_type in MULTI_QUERY_ATTN_MODELS:
+                compute_past_key_values_output_shapes_func = self.compute_past_key_values_output_shapes_mqa
+            else:
+                compute_past_key_values_output_shapes_func = self.compute_past_key_values_output_shapes
+            known_output_shapes = compute_past_key_values_output_shapes_func(
                 input_ids,
                 use_cache_branch=use_cache_branch_tensor.item() if use_cache_branch_tensor is not None else None,
                 past_key_values=past_key_values,
@@ -357,8 +395,11 @@ def forward(
                     past_key_values += (output_buffers[name].view(output_shapes[name]),)
 
             # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (self-attention key and value per decoder layer)
-            num_pkv = 2
-            past_key_values = tuple(past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv))
+            if self.parent_model.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                num_pkv = 2
+                past_key_values = tuple(
+                    past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv)
+                )
 
             logits = output_buffers["logits"].view(output_shapes["logits"])
 
@@ -410,8 +451,12 @@ def forward(
 
             # Tuple of tuple of length `n_layers`, with each tuple of length equal to the number of self-attention and
             # per decoder layer
-            num_pkv = 2
-            past_key_values = tuple(past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv))
+            if self.parent_model.config.model_type not in MULTI_QUERY_ATTN_MODELS:
+                num_pkv = 2
+                past_key_values = tuple(
+                    past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv)
+                )
+
             logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device)
 
         loss = None
diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py
index f8f6acbbdb..9a170fa3a0 100644
--- a/optimum/onnxruntime/utils.py
+++ b/optimum/onnxruntime/utils.py
@@ -53,6 +53,8 @@
     "tensor(double)": np.float64,
 }
 
+MULTI_QUERY_ATTN_MODELS = {"gpt_bigcode"}
+
 
 def _is_gpu_available():
     """
@@ -109,6 +111,7 @@ class ORTConfigManager:
         "distilbert": "bert",
         "electra": "bert",
         "gpt2": "gpt2",
+        "gpt_bigcode": "gpt2",
         "gpt_neo": "gpt2",
         "gpt_neox": "gpt2",
         "gptj": "gpt2",
diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py
index a5642f4b98..6da01ff8de 100644
--- a/optimum/utils/normalized_config.py
+++ b/optimum/utils/normalized_config.py
@@ -216,6 +216,7 @@ class NormalizedConfigManager:
         "donut-swin": NormalizedVisionConfig,
         "electra": NormalizedTextConfig,
         "gpt2": GPT2LikeNormalizedTextConfig,
+        "gpt_bigcode": GPT2LikeNormalizedTextConfig,
         "gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
         "gpt_neox": NormalizedTextConfig,
         "llama": NormalizedTextConfig,
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index 423875ca28..ab4ce97b75 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -57,6 +57,7 @@
     "electra": "hf-internal-testing/tiny-random-ElectraModel",
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
+    "gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
     "gpt-neo": "hf-internal-testing/tiny-random-GPTNeoModel",
     "gpt-neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
     "gptj": "hf-internal-testing/tiny-random-GPTJModel",
diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py
index 6ffbbb7732..ab39351319 100644
--- a/tests/onnxruntime/test_modeling.py
+++ b/tests/onnxruntime/test_modeling.py
@@ -1946,6 +1946,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
         "bloom",
         "codegen",
         "gpt2",
+        "gpt_bigcode",
         "gpt_neo",
         "gpt_neox",
         "gptj",
diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py
index f83acd91e6..09ada4e369 100644
--- a/tests/onnxruntime/utils_onnxruntime_tests.py
+++ b/tests/onnxruntime/utils_onnxruntime_tests.py
@@ -52,6 +52,7 @@
    "electra": "hf-internal-testing/tiny-random-ElectraModel",
    "flaubert": "hf-internal-testing/tiny-random-flaubert",
    "gpt2": "hf-internal-testing/tiny-random-gpt2",
+    "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
    "gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
    "gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
    "gptj": "hf-internal-testing/tiny-random-GPTJModel",
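The change that threads through the whole patch is the KV-cache layout. With multi-query attention, GPT-BigCode shares a single key/value head across all query heads, so each layer caches one tensor with key and value concatenated on the last axis instead of a per-head (key, value) pair. The sketch below is purely illustrative (plain PyTorch, made-up dimensions), not code from the patch:

```python
import torch

batch_size, num_heads, past_len, head_dim = 2, 12, 9, 64

# Standard multi-head cache (e.g. GPT-2): one (key, value) pair per layer,
# each of shape (batch_size, num_heads, past_sequence_length, head_dim).
mha_key = torch.zeros(batch_size, num_heads, past_len, head_dim)
mha_value = torch.zeros(batch_size, num_heads, past_len, head_dim)

# GPT-BigCode multi-query cache: a single tensor per layer, with no heads
# dimension and key/value concatenated on the last axis:
# (batch_size, past_sequence_length, 2 * head_dim).
mqa_key_value = torch.zeros(batch_size, past_len, 2 * head_dim)

# Splitting the last axis recovers the key and value halves. This layout is why
# the ONNX config above exposes a single `past_key_values.{i}.key_value` input
# with dynamic axes only on batch_size and past_sequence_length, and why the
# ORT decoder skips the usual per-layer (key, value) re-grouping for models in
# MULTI_QUERY_ATTN_MODELS.
key, value = mqa_key_value.split(head_dim, dim=-1)
assert key.shape == value.shape == (batch_size, past_len, head_dim)
```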
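With these pieces in place, a GPT-BigCode checkpoint can be exported and run like any other supported decoder. The snippet below is a minimal sketch rather than part of the patch: it uses the tiny test checkpoint referenced in the tests above, and assumes an Optimum version where `export=True` triggers the ONNX export in `from_pretrained` (older releases used `from_transformers=True`) and that the checkpoint ships a tokenizer.

```python
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM

model_id = "hf-internal-testing/tiny-random-GPTBigCodeModel"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the PyTorch weights to ONNX on the fly, going through the
# GPTBigCodeOnnxConfig registered in tasks.py (text-generation-with-past when use_cache=True).
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```

The CLI route should be equivalent: `optimum-cli export onnx --model <checkpoint> --task text-generation-with-past <output_dir>` produces an ONNX decoder that `ORTModelForCausalLM` can load directly.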