fix issue with obtaining the decoder layer number when converting the T5 model. (#17185)

### Description
Fix an issue with obtaining the decoder layer number when converting the T5 model: the conversion scripts used `config.num_layers` (the encoder layer count) where the decoder layer count `config.num_decoder_layers` is required, which breaks models whose encoder and decoder depths differ.

### Motivation and Context
Fixes issue #17072.

Tested with the [byt5-small](https://huggingface.co/google/byt5-small/tree/main) model, which has 12 encoder layers and 4 decoder layers. Here is the log:

![image](https://github.com/microsoft/onnxruntime/assets/3481539/ff1b69c5-f485-4301-a333-9ee2a984df07)
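
As a quick sanity check (a minimal sketch using the Hugging Face `transformers` config API, not part of this change), the two different layer counts of byt5-small can be read straight from its config:

```python
from transformers import AutoConfig

# byt5-small has an asymmetric encoder/decoder, which is what exposed the bug.
config = AutoConfig.from_pretrained("google/byt5-small")
print(config.num_layers)          # 12 -> encoder layers
print(config.num_decoder_layers)  # 4  -> decoder layers
```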
LitLeo authored Aug 18, 2023
1 parent 6ee4be7 commit 78b3565
Showing 2 changed files with 18 additions and 12 deletions.
20 changes: 12 additions & 8 deletions onnxruntime/python/tools/transformers/models/t5/t5_decoder.py
@@ -100,7 +100,8 @@ def __init__(self, decoder, lm_head, config):
         )

     def forward(self, decoder_input_ids, encoder_attention_mask, *past):
-        past_key_values = PastKeyValuesHelper.group_by_layer(past, self.config.num_layers)
+        num_decoder_layers = self.config.num_decoder_layers
+        past_key_values = PastKeyValuesHelper.group_by_layer(past, num_decoder_layers)

         # This is a hack since only the third dimension of encoder_hidden_states is used here
         dummy_encoder_hidden_states = encoder_attention_mask.unsqueeze(2)
@@ -162,7 +163,7 @@ def create_dummy(
             T5DecoderInputs: dummy inputs for decoder
         """
         num_attention_heads: int = config.num_heads
-        num_layers: int = config.num_layers
+        num_layers: int = config.num_decoder_layers
         vocab_size: int = config.vocab_size

         # Do not use head_size = hidden_size / num_attention_heads here.
@@ -263,9 +264,11 @@ def export_onnx(
         )
         input_list = inputs.to_list()

-        past_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=False)
-        present_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=True)
-        present_self_names = present_names[: 2 * decoder.config.num_layers]
+        num_decoder_layers = decoder.config.num_decoder_layers
+
+        past_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=False)
+        present_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=True)
+        present_self_names = present_names[: 2 * num_decoder_layers]

         input_past_names = past_names if isinstance(decoder, T5Decoder) else []
         output_present_names = present_self_names if isinstance(decoder, T5Decoder) else present_names
@@ -407,20 +410,21 @@ def verify_onnx(
         torch_outputs = model(*input_list)

         ort_outputs = T5DecoderHelper.onnxruntime_inference(ort_session, inputs)
+        num_decoder_layers = model.config.num_decoder_layers

         max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
         max_diff_all = max_diff
         logger.debug(f"logits max_diff={max_diff}")

-        for i in range(2 * model.config.num_layers):
+        for i in range(2 * num_decoder_layers):
             max_diff = numpy.amax(numpy.abs(torch_outputs[1][i].cpu().numpy() - ort_outputs[1 + i]))
             logger.debug(f"self attention past state {i} max_diff={max_diff}")
             max_diff_all = max(max_diff_all, max_diff)

         if isinstance(model, T5DecoderInit):
-            for i in range(2 * model.config.num_layers):
+            for i in range(2 * num_decoder_layers):
                 max_diff = numpy.amax(
-                    numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * model.config.num_layers + i])
+                    numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * num_decoder_layers + i])
                 )
                 logger.debug(f"cross attention past state {i} max_diff={max_diff}")
                 max_diff_all = max(max_diff_all, max_diff)
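
For context, the decoder's flat `past` input carries key/value tensors for both self- and cross-attention of every decoder layer, so it must be grouped with the decoder layer count. Below is a hypothetical sketch of such grouping (not the actual `PastKeyValuesHelper.group_by_layer` source), assuming the self-attention block comes first, which is consistent with the `present_names[: 2 * num_decoder_layers]` slicing above:

```python
def group_by_layer(past, num_layers):
    # Flat layout assumed here: all self-attention K/V pairs first,
    # then all cross-attention K/V pairs (hypothetical illustration).
    assert len(past) == 4 * num_layers, "expect K and V for self- and cross-attention per layer"
    return tuple(
        (
            past[2 * i],                       # self-attention key
            past[2 * i + 1],                   # self-attention value
            past[2 * num_layers + 2 * i],      # cross-attention key
            past[2 * num_layers + 2 * i + 1],  # cross-attention value
        )
        for i in range(num_layers)
    )

# byt5-small: 4 decoder layers -> 16 past tensors.
flat_past = [f"tensor_{i}" for i in range(16)]
groups = group_by_layer(flat_past, 4)          # works with num_decoder_layers
# group_by_layer(flat_past, 12) would fail the assertion: passing the encoder
# layer count (config.num_layers) is exactly the mistake this commit fixes.
```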
Second changed file:
@@ -125,7 +125,7 @@ def export_onnx(
         )
         input_list = inputs.to_list()

-        present_names = PastKeyValuesHelper.get_past_names(model.config.num_layers, present=True)
+        present_names = PastKeyValuesHelper.get_past_names(model.config.num_decoder_layers, present=True)

         output_names = ["logits", "encoder_hidden_states", *present_names]

@@ -271,6 +271,8 @@ def verify_onnx(
         input_list = inputs.to_list()
         torch_outputs = model(*input_list)

+        num_decoder_layers = model.config.num_decoder_layers
+
         assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape
         max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
         logger.debug(f"logits max_diff={max_diff}")
@@ -281,13 +283,13 @@
         logger.debug(f"encoder_hidden_states max_diff={max_diff}")
         max_diff_all = max(max_diff_all, max_diff)

-        for i in range(2 * model.config.num_layers):
+        for i in range(2 * num_decoder_layers):
             max_diff = numpy.amax(numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[2 + i]))
             logger.debug(f"self attention past state {i} max_diff={max_diff}")

-        for i in range(2 * model.config.num_layers):
+        for i in range(2 * num_decoder_layers):
             max_diff = numpy.amax(
-                numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * model.config.num_layers + i])
+                numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * num_decoder_layers + i])
             )
             logger.debug(f"cross attention past state {i} max_diff={max_diff}")
             max_diff_all = max(max_diff_all, max_diff)
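
Relatedly, the past/present input and output names sized by the loops above come in key/value pairs per layer for self- and cross-attention. A hypothetical naming layout consistent with the slicing in the first file (not the actual `PastKeyValuesHelper.get_past_names` source) could look like this:

```python
def get_past_names(num_layers, present=False):
    # Hypothetical layout: self-attention names first, then cross-attention names,
    # so the first 2 * num_layers entries form the self-attention block.
    prefix = "present" if present else "past"
    self_names, cross_names = [], []
    for i in range(num_layers):
        self_names.extend([f"{prefix}_key_self_{i}", f"{prefix}_value_self_{i}"])
        cross_names.extend([f"{prefix}_key_cross_{i}", f"{prefix}_value_cross_{i}"])
    return self_names + cross_names

names = get_past_names(4, present=True)   # 4 decoder layers, as in byt5-small
assert len(names) == 16
assert names[: 2 * 4] == [
    "present_key_self_0", "present_value_self_0",
    "present_key_self_1", "present_value_self_1",
    "present_key_self_2", "present_value_self_2",
    "present_key_self_3", "present_value_self_3",
]
```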
