Update GPTBigCodeOnnxConfig
xenova committed Jul 27, 2023
1 parent b03c369 commit 17a4a08
Showing 1 changed file with 4 additions and 39 deletions.
43 changes: 4 additions & 39 deletions in optimum/exporters/onnx/model_configs.py
@@ -195,6 +195,10 @@ class ImageGPTOnnxConfig(GPT2OnnxConfig):
     pass
 
 
+class GPTBigCodeOnnxConfig(GPT2OnnxConfig):
+    pass
+
+
 class GPTNeoOnnxConfig(TextDecoderOnnxConfig):
     DEFAULT_ONNX_OPSET = 13
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads")
@@ -268,45 +272,6 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
             }
 
 
-class GPTBigCodeDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
-    def generate(self, input_name: str, framework: str = "pt"):
-        past_key_value_shape = (
-            self.batch_size,
-            self.sequence_length,
-            self.hidden_size // self.num_attention_heads * 2,
-        )
-        return [self.random_float_tensor(past_key_value_shape, framework=framework) for _ in range(self.num_layers)]
-
-
-class GPTBigCodeOnnxConfig(TextDecoderOnnxConfig):
-    DUMMY_INPUT_GENERATOR_CLASSES = (
-        GPTBigCodeDummyPastKeyValuesGenerator,
-    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
-    DUMMY_PKV_GENERATOR_CLASS = GPTBigCodeDummyPastKeyValuesGenerator
-    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")
-
-    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
-        """
-        Refer to OnnxConfigWithPast in base.py
-        """
-        if direction not in ["inputs", "outputs"]:
-            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
-
-        if direction == "inputs":
-            decoder_sequence_name = "past_sequence_length"
-            name = "past_key_values"
-        else:
-            decoder_sequence_name = "past_sequence_length + 1"
-            name = "present"
-
-        for i in range(self._normalized_config.num_layers):
-            # No dim for `n_head` when using multi-query attention
-            inputs_or_outputs[f"{name}.{i}.key_value"] = {
-                0: "batch_size",
-                1: decoder_sequence_name,
-            }
-
-
 class T5DummySeq2SeqPastKeyValuesGenerator(DummySeq2SeqPastKeyValuesGenerator):
     def generate(self, input_name: str, framework: str = "pt"):
         encoder_shape = (
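For reference, the generator deleted above produced the fused key/value cache layout that GPTBigCode's multi-query attention uses: one tensor per layer of shape (batch_size, sequence_length, 2 * head_dim), with no separate num_heads dimension because all query heads share a single key/value head (see the "No dim for `n_head`" comment in the removed code). A minimal sketch of that shape arithmetic, using illustrative sizes that are not taken from the commit:

import torch

# Illustrative values only; not from the commit or any particular checkpoint.
batch_size, sequence_length = 2, 16
hidden_size, num_attention_heads = 768, 12
head_dim = hidden_size // num_attention_heads  # 64

# One fused tensor per layer: key and value concatenated on the last axis,
# matching the (batch, seq, 2 * head_dim) shape built by the removed generator.
past_key_value = torch.rand(batch_size, sequence_length, 2 * head_dim)
key, value = past_key_value.split(head_dim, dim=-1)
print(past_key_value.shape)  # torch.Size([2, 16, 128])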
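With this change, the ONNX export config for gpt_bigcode is simply inherited from GPT2OnnxConfig. A minimal usage sketch of how the config is typically exercised, assuming optimum's standard export entry point and an example checkpoint id ("bigcode/starcoderbase-1b") that is not part of this commit:

from optimum.exporters.onnx import main_export

# Export a gpt_bigcode model to ONNX with past key values (decoder-with-cache variant);
# the exporter resolves the ONNX config registered for the "gpt_bigcode" model type.
main_export(
    model_name_or_path="bigcode/starcoderbase-1b",  # example checkpoint, not from the commit
    output="gpt_bigcode_onnx",
    task="text-generation-with-past",
)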

0 comments on commit 17a4a08
