Add ONNX / ONNXRuntime support for StarCoder (#1042)
* add gpt-bigcode

* add bigcode specific dummy generator

* normalize config

* test past key value flattened

* revert past key value separation

* fix

* add test to exporter

* add ort modeling support

* add modeling test

* fix bloom

* fix typo

---------

Co-authored-by: JingyaHuang <[email protected]>
JingyaHuang authored Aug 1, 2023
1 parent 26c7549 commit f59af6c
Showing 10 changed files with 114 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
@@ -42,6 +42,7 @@ Supported architectures:
- Electra
- Flaubert
- GPT-2
- GPT-BigCode
- GPT-J
- GPT-Neo
- GPT-NeoX
4 changes: 3 additions & 1 deletion optimum/exporters/onnx/base.py
@@ -596,7 +596,9 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
and self.use_cache_branch is not False
and "attention_mask" in dummy_inputs
):
past_length = dummy_inputs["past_key_values"][0][0].shape[2]
# Obtain the past sequence length from the value instead of the key (Bloom).
past_length = dummy_inputs["past_key_values"][0][1].shape[-2]

dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
dummy_inputs["attention_mask"],
desired_length=past_length + 1,
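
The new lookup reads the past length from the value tensor's dim -2 instead of the key's dim 2. A minimal sketch with made-up sizes (not part of the commit) of why dim -2 of the value generalizes across cache layouts:

import torch

batch, num_heads, head_dim, past_len = 2, 4, 8, 5
# Bloom-style cache: key is (batch * num_heads, head_dim, past_len), value is (batch * num_heads, past_len, head_dim).
bloom_key = torch.zeros(batch * num_heads, head_dim, past_len)
bloom_value = torch.zeros(batch * num_heads, past_len, head_dim)
# Standard decoder cache: (batch, num_heads, past_len, head_dim) for both key and value.
standard_value = torch.zeros(batch, num_heads, past_len, head_dim)

assert bloom_value.shape[-2] == standard_value.shape[-2] == past_len  # dim -2 of the value is always the past length
assert bloom_key.shape[-2] == head_dim  # dim -2 of the key would be wrong for Bloom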
40 changes: 40 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -39,6 +39,7 @@
NormalizedVisionConfig,
logging,
)
from ...utils.normalized_config import NormalizedConfigManager
from .base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from .config import (
AudioOnnxConfig,
@@ -268,6 +269,45 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
}


class GPTBigCodeDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt"):
past_key_value_shape = (
self.batch_size,
self.sequence_length,
self.hidden_size // self.num_attention_heads * 2,
)
return [self.random_float_tensor(past_key_value_shape, framework=framework) for _ in range(self.num_layers)]


class GPTBigCodeOnnxConfig(TextDecoderOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
GPTBigCodeDummyPastKeyValuesGenerator,
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
DUMMY_PKV_GENERATOR_CLASS = GPTBigCodeDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("gpt_bigcode")

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
name = "present"

for i in range(self._normalized_config.num_layers):
# No dim for `n_head` when using multi-query attention
inputs_or_outputs[f"{name}.{i}.key_value"] = {
0: "batch_size",
1: decoder_sequence_name,
}

def flatten_past_key_values(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.key_value"] = t


class T5DummySeq2SeqPastKeyValuesGenerator(DummySeq2SeqPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt"):
encoder_shape = (
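
GPT-BigCode uses multi-query attention, so the config above declares a single fused cache tensor per layer ("past_key_values.<i>.key_value") with no `n_head` dimension. A rough sketch with illustrative sizes (not from the commit) of that layout:

import torch

batch, past_len, head_dim = 2, 5, 64
key_value = torch.randn(batch, past_len, 2 * head_dim)  # one fused tensor per layer: (batch, past_len, 2 * head_dim)
key, value = key_value.split(head_dim, dim=-1)  # the key and value halves share a single key/value head
print(key.shape, value.shape)  # torch.Size([2, 5, 64]) for both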
9 changes: 9 additions & 0 deletions optimum/exporters/tasks.py
@@ -516,6 +516,15 @@ class TasksManager:
"token-classification",
onnx="GPT2OnnxConfig",
),
"gpt-bigcode": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
"token-classification",
onnx="GPTBigCodeOnnxConfig",
),
"gptj": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
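
With the task mapping registered, the ONNX export can be run programmatically. A hedged sketch using the tiny test checkpoint referenced later in this commit; the output directory and task choice are illustrative:

from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="hf-internal-testing/tiny-random-GPTBigCodeModel",  # any gpt_bigcode checkpoint
    output="gpt_bigcode_onnx/",
    task="text-generation-with-past",
)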
63 changes: 54 additions & 9 deletions optimum/onnxruntime/base.py
@@ -24,7 +24,7 @@

from ..utils import NormalizedConfigManager
from ..utils.logging import warn_once
from .utils import get_ordered_input_names, logging
from .utils import MULTI_QUERY_ATTN_MODELS, get_ordered_input_names, logging


logger = logging.get_logger(__name__)
@@ -161,7 +161,15 @@ def __init__(
self.expected_key_symbolic_shape = None
self.expected_value_symbolic_shape = None
for output in self.session.get_outputs():
if ".key" in output.name:
# To handle the case of multi-query attn where key and value are concatenated
if ".key_value" in output.name:
expected_key_value_symbolic_shape = output.shape
self.expected_key_symbolic_shape = (
self.expected_value_symbolic_shape
) = expected_key_value_symbolic_shape[:-1] + [
expected_key_value_symbolic_shape[-1] // 2,
]
elif ".key" in output.name:
self.expected_key_symbolic_shape = output.shape
elif ".value" in output.name:
self.expected_value_symbolic_shape = output.shape
@@ -227,6 +235,14 @@ def prepare_inputs_for_merged(
past_key_values = tuple(
key_or_value for _ in range(len(self.key_value_input_names) // 2) for key_or_value in [key, value]
)
elif self.parent_model.config.model_type in MULTI_QUERY_ATTN_MODELS:
shape_key_and_value = (batch_size, 1, embed_size_per_head * 2)
key_and_value = constructor.zeros(shape_key_and_value, dtype=dtype)

if use_torch is True:
key_and_value = key_and_value.to(self.device)

past_key_values = tuple(key_and_value for _ in range(len(self.key_value_input_names)))
else:
shape = (batch_size, num_attention_heads, 1, embed_size_per_head)
key_or_value = constructor.zeros(shape, dtype=dtype)
@@ -288,6 +304,24 @@ def compute_past_key_values_output_shapes(

return {name: key_shape if "key" in name else value_shape for name in self.key_value_output_names}

def compute_past_key_values_output_shapes_mqa(
self,
input_ids: torch.Tensor,
use_cache_branch: Optional[bool],
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
) -> Dict[str, List[int]]:
batch_size = input_ids.size(0)
num_attention_heads = self.normalized_config.num_attention_heads
embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads

sequence_length = input_ids.size(1)
if past_key_values is not None and use_cache_branch is not False:
sequence_length += past_key_values[0].size(-2)

key_and_value_shape = (batch_size, sequence_length, embed_size_per_head * 2)

return {name: key_and_value_shape for name in self.key_value_output_names}

def forward(
self,
input_ids: torch.LongTensor,
@@ -300,8 +334,8 @@ def forward(
use_torch = isinstance(input_ids, torch.Tensor)
self.parent_model.raise_on_numpy_input_io_binding(use_torch)

# Flatten the past_key_values
if past_key_values is not None:
# Flatten the past_key_values (no need to flatten for models using multi-query attn)
if past_key_values is not None and (self.parent_model.config.model_type not in MULTI_QUERY_ATTN_MODELS):
past_key_values = tuple(
past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
)
@@ -312,7 +346,11 @@
)

if self.parent_model.use_io_binding:
known_output_shapes = self.compute_past_key_values_output_shapes(
if self.parent_model.config.model_type in MULTI_QUERY_ATTN_MODELS:
compute_past_key_values_output_shapes_func = self.compute_past_key_values_output_shapes_mqa
else:
compute_past_key_values_output_shapes_func = self.compute_past_key_values_output_shapes
known_output_shapes = compute_past_key_values_output_shapes_func(
input_ids,
use_cache_branch=use_cache_branch_tensor.item() if use_cache_branch_tensor is not None else None,
past_key_values=past_key_values,
@@ -357,8 +395,11 @@ def forward(
past_key_values += (output_buffers[name].view(output_shapes[name]),)

# Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (self-attention key and value per decoder layer)
num_pkv = 2
past_key_values = tuple(past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv))
if self.parent_model.config.model_type not in MULTI_QUERY_ATTN_MODELS:
num_pkv = 2
past_key_values = tuple(
past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv)
)

logits = output_buffers["logits"].view(output_shapes["logits"])

@@ -410,8 +451,12 @@ def forward(

# Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (self-attention key and value per decoder layer)
num_pkv = 2
past_key_values = tuple(past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv))
if self.parent_model.config.model_type not in MULTI_QUERY_ATTN_MODELS:
num_pkv = 2
past_key_values = tuple(
past_key_values[i : i + num_pkv] for i in range(0, len(past_key_values), num_pkv)
)

logits = torch.from_numpy(outputs[self.output_names["logits"]]).to(self.device)

loss = None
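
A hedged end-to-end sketch of the decoder path exercised above, using the tiny test checkpoint from this commit; the prompt and generation arguments are illustrative only:

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

model_id = "hf-internal-testing/tiny-random-GPTBigCodeModel"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)  # export with the fused MQA cache

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0]))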
3 changes: 3 additions & 0 deletions optimum/onnxruntime/utils.py
@@ -53,6 +53,8 @@
"tensor(double)": np.float64,
}

MULTI_QUERY_ATTN_MODELS = {"gpt_bigcode"}


def _is_gpu_available():
"""
@@ -109,6 +111,7 @@ class ORTConfigManager:
"distilbert": "bert",
"electra": "bert",
"gpt2": "gpt2",
"gpt_bigcode": "gpt2",
"gpt_neo": "gpt2",
"gpt_neox": "gpt2",
"gptj": "gpt2",
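
The new MULTI_QUERY_ATTN_MODELS set gates the cache handling in optimum/onnxruntime/base.py above, while the ORTConfigManager entry reuses ONNX Runtime's GPT-2 graph optimizations for GPT-BigCode. A hedged sketch of that optimization path (checkpoint, optimization level, and save directory are placeholders, not part of the commit):

from optimum.onnxruntime import ORTModelForCausalLM, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

model = ORTModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel", export=True)
optimizer = ORTOptimizer.from_pretrained(model)
optimizer.optimize(optimization_config=OptimizationConfig(optimization_level=2), save_dir="gpt_bigcode_optimized/")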
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
@@ -216,6 +216,7 @@ class NormalizedConfigManager:
"donut-swin": NormalizedVisionConfig,
"electra": NormalizedTextConfig,
"gpt2": GPT2LikeNormalizedTextConfig,
"gpt-bigcode": GPT2LikeNormalizedTextConfig,
"gpt_neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt_neox": NormalizedTextConfig,
"llama": NormalizedTextConfig,
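
GPT-BigCode keeps GPT-2-style attribute names (n_head, n_embd, n_layer), which is why it maps to GPT2LikeNormalizedTextConfig above. A hedged illustration with arbitrary config values:

from transformers import GPTBigCodeConfig
from optimum.utils.normalized_config import GPT2LikeNormalizedTextConfig

config = GPTBigCodeConfig(n_embd=256, n_head=4, n_layer=2)
normalized = GPT2LikeNormalizedTextConfig(config)
print(normalized.hidden_size, normalized.num_attention_heads, normalized.num_layers)  # 256 4 2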
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
@@ -57,6 +57,7 @@
"electra": "hf-internal-testing/tiny-random-ElectraModel",
"flaubert": "hf-internal-testing/tiny-random-flaubert",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
"gpt-neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt-neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
1 change: 1 addition & 0 deletions tests/onnxruntime/test_modeling.py
@@ -1946,6 +1946,7 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
"bloom",
"codegen",
"gpt2",
"gpt_bigcode",
"gpt_neo",
"gpt_neox",
"gptj",
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
@@ -52,6 +52,7 @@
"electra": "hf-internal-testing/tiny-random-ElectraModel",
"flaubert": "hf-internal-testing/tiny-random-flaubert",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
