Skip to content

Commit

Permalink
add bigcode specific dummy generator
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed May 9, 2023
1 parent 862c81d commit bd187ae
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 6 deletions.
3 changes: 2 additions & 1 deletion optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,8 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
and self.use_cache_branch is not False
and "attention_mask" in dummy_inputs
):
past_length = dummy_inputs["past_key_values"][0][0].shape[2]
past_length = dummy_inputs["past_key_values"][0][0].shape[-2]

dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
dummy_inputs["attention_mask"],
desired_length=past_length + 1,
Expand Down
44 changes: 39 additions & 5 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,6 @@ class LlamaOnnxConfig(TextDecoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


class GPTBigCodeOnnxConfig(TextDecoderOnnxConfig):
    """ONNX export configuration for GPT BigCode models.

    Relies entirely on the behaviour inherited from ``TextDecoderOnnxConfig``;
    only the two class-level constants below are set here.
    """

    # NOTE(review): this appears to be the pre-change definition shown in the
    # diff hunk; a reworked class with the same name follows later in the
    # file -- confirm which one is current before relying on it.

    # Normalized-config wrapper used to read the model configuration.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    # ONNX opset used by default when exporting this architecture.
    DEFAULT_ONNX_OPSET = 13


class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt"):
past_key_shape = (
Expand Down Expand Up @@ -272,6 +267,45 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
}


class GPTBigCodeDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
    """Dummy ``past_key_values`` generator for GPT BigCode models.

    Each layer gets a single 3-D tensor rather than a (key, value) pair of
    4-D tensors.
    """

    def generate(self, input_name: str, framework: str = "pt"):
        """Create one random float tensor per layer.

        Args:
            input_name: Name of the input being generated (not used here).
            framework: Framework identifier forwarded to
                ``random_float_tensor``.

        Returns:
            A list with ``self.num_layers`` tensors, each of shape
            ``(batch_size, sequence_length, 2 * head_dim)`` where
            ``head_dim = hidden_size // num_attention_heads``.
        """
        head_dim = self.hidden_size // self.num_attention_heads
        # The last axis is twice the head size -- presumably key and value
        # are stored fused along it (TODO confirm against the model code).
        fused_shape = (self.batch_size, self.sequence_length, head_dim * 2)

        dummy_cache = []
        for _ in range(self.num_layers):
            dummy_cache.append(self.random_float_tensor(fused_shape, framework=framework))
        return dummy_cache


class GPTBigCodeOnnxConfig(TextDecoderOnnxConfig):
    """ONNX export configuration for GPT BigCode models.

    Uses ``GPTBigCodeDummyPastKeyValuesGenerator`` for the past key/values
    dummy inputs and declares one ``key_value`` entry per layer.
    """

    # The model-specific generator is listed ahead of the inherited ones.
    DUMMY_INPUT_GENERATOR_CLASSES = (
        GPTBigCodeDummyPastKeyValuesGenerator,
    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
    DUMMY_PKV_GENERATOR_CLASS = GPTBigCodeDummyPastKeyValuesGenerator
    # Map the normalized attribute names onto the ``n_layer`` / ``n_head``
    # names passed below.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")

    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
        """Register the dynamic axes of the per-layer key/value tensors.

        Adds one ``"<name>.<layer>.key_value"`` entry per layer to
        ``inputs_or_outputs`` (mutated in place). See ``OnnxConfigWithPast``
        in ``base.py`` for the generic variant.

        Args:
            inputs_or_outputs: Mapping from tensor name to its dynamic axes,
                updated in place.
            direction: ``"inputs"`` to declare ``past_key_values`` entries,
                ``"outputs"`` to declare ``present`` entries.

        Raises:
            ValueError: If ``direction`` is neither ``"inputs"`` nor
                ``"outputs"``.
        """
        if direction == "inputs":
            prefix = "past_key_values"
            sequence_axis = "past_sequence_length"
        elif direction == "outputs":
            prefix = "present"
            sequence_axis = "past_sequence_length + 1"
        else:
            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

        layer_count = self._normalized_config.num_layers
        for layer_idx in range(layer_count):
            # Multi-query attention: there is no `n_head` dimension, so only
            # the batch and sequence axes are declared.
            dynamic_axes = {0: "batch_size", 1: sequence_axis}
            inputs_or_outputs[f"{prefix}.{layer_idx}.key_value"] = dynamic_axes


class T5DummySeq2SeqPastKeyValuesGenerator(DummySeq2SeqPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt"):
encoder_shape = (
Expand Down

0 comments on commit bd187ae

Please sign in to comment.