From bd187ae97856910a623c305385316dd8c316b190 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Tue, 9 May 2023 22:53:46 +0000 Subject: [PATCH] add bigcode specific dummy generator --- optimum/exporters/onnx/base.py | 3 +- optimum/exporters/onnx/model_configs.py | 44 ++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 60a9b3cce3..9338835787 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -583,7 +583,8 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): and self.use_cache_branch is not False and "attention_mask" in dummy_inputs ): - past_length = dummy_inputs["past_key_values"][0][0].shape[2] + past_length = dummy_inputs["past_key_values"][0][0].shape[-2] + dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim( dummy_inputs["attention_mask"], desired_length=past_length + 1, diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index cb6abf1d28..e8b0a827de 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -214,11 +214,6 @@ class LlamaOnnxConfig(TextDecoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedTextConfig -class GPTBigCodeOnnxConfig(TextDecoderOnnxConfig): - DEFAULT_ONNX_OPSET = 13 - NORMALIZED_CONFIG_CLASS = NormalizedTextConfig - - class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): def generate(self, input_name: str, framework: str = "pt"): past_key_shape = ( @@ -272,6 +267,45 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire } +class GPTBigCodeDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + def generate(self, input_name: str, framework: str = "pt"): + past_key_value_shape = ( + self.batch_size, + self.sequence_length, + self.hidden_size // self.num_attention_heads * 2, + ) + return [self.random_float_tensor(past_key_value_shape, framework=framework) for _ in range(self.num_layers)] + + +class GPTBigCodeOnnxConfig(TextDecoderOnnxConfig): + DUMMY_INPUT_GENERATOR_CLASSES = ( + GPTBigCodeDummyPastKeyValuesGenerator, + ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES + DUMMY_PKV_GENERATOR_CLASS = GPTBigCodeDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head") + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + """ + Refer to OnnxConfigWithPast in base.py + """ + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_sequence_length + 1" + name = "present" + + for i in range(self._normalized_config.num_layers): + # No dim for `n_head` when using multi-query attention + inputs_or_outputs[f"{name}.{i}.key_value"] = { + 0: "batch_size", + 1: decoder_sequence_name, + } + + class T5DummySeq2SeqPastKeyValuesGenerator(DummySeq2SeqPastKeyValuesGenerator): def generate(self, input_name: str, framework: str = "pt"): encoder_shape = (