From 17a4a0868e40f1f1889e8001cc51b2035b78133e Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 27 Jul 2023 18:06:52 +0000 Subject: [PATCH] Update `GPTBigCodeOnnxConfig` --- optimum/exporters/onnx/model_configs.py | 43 +++---------------------- 1 file changed, 4 insertions(+), 39 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index eb5ff8fec8..1afc3272c4 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -195,6 +195,10 @@ class ImageGPTOnnxConfig(GPT2OnnxConfig): pass +class GPTBigCodeOnnxConfig(GPT2OnnxConfig): + pass + + class GPTNeoOnnxConfig(TextDecoderOnnxConfig): DEFAULT_ONNX_OPSET = 13 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads") @@ -268,45 +272,6 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire } -class GPTBigCodeDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): - def generate(self, input_name: str, framework: str = "pt"): - past_key_value_shape = ( - self.batch_size, - self.sequence_length, - self.hidden_size // self.num_attention_heads * 2, - ) - return [self.random_float_tensor(past_key_value_shape, framework=framework) for _ in range(self.num_layers)] - - -class GPTBigCodeOnnxConfig(TextDecoderOnnxConfig): - DUMMY_INPUT_GENERATOR_CLASSES = ( - GPTBigCodeDummyPastKeyValuesGenerator, - ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES - DUMMY_PKV_GENERATOR_CLASS = GPTBigCodeDummyPastKeyValuesGenerator - NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head") - - def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): - """ - Refer to OnnxConfigWithPast in base.py - """ - if direction not in ["inputs", "outputs"]: - raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') - - if direction == "inputs": - decoder_sequence_name = "past_sequence_length" - name = "past_key_values" - else: - decoder_sequence_name = "past_sequence_length + 1" - name = "present" - - for i in range(self._normalized_config.num_layers): - # No dim for `n_head` when using multi-query attention - inputs_or_outputs[f"{name}.{i}.key_value"] = { - 0: "batch_size", - 1: decoder_sequence_name, - } - - class T5DummySeq2SeqPastKeyValuesGenerator(DummySeq2SeqPastKeyValuesGenerator): def generate(self, input_name: str, framework: str = "pt"): encoder_shape = (