Skip to content

Commit

Permalink
Add --mask_type option to generate different formats of attention mask…
Browse files Browse the repository at this point in the history
… in bert_perf_test.py (#16976)

### Description
Add an option to generate different formats of attention_mask for
testing transformers models:
1 - 1D mask index, actual sequence length excluding padding
2 - 2D attention mask. Value 0 means padding, 1 otherwise.
3 - 1D, key lengths and cumulative sequence lengths of query and key

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
tianleiwu authored Aug 3, 2023
1 parent bda012a commit a25d0d2
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 16 deletions.
13 changes: 12 additions & 1 deletion onnxruntime/python/tools/transformers/bert_perf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class ModelSetting:
opt_level: int
input_tuning_results: Optional[str]
output_tuning_results: Optional[str]
mask_type: int


def create_session(
Expand Down Expand Up @@ -369,6 +370,7 @@ def run_performance(model_setting, test_setting, perf_results):
input_mask,
test_setting.average_sequence_length,
test_setting.random_sequence_length,
mask_type=model_setting.mask_type,
)

run_perf_tests(model_setting, test_setting, perf_results, all_inputs)
Expand Down Expand Up @@ -524,6 +526,14 @@ def parse_arguments():
)
parser.set_defaults(random_sequence_length=False)

parser.add_argument(
"--mask_type",
required=False,
type=int,
default=2,
help="mask type: (1: mask index or sequence length, 2: raw 2D mask, 3: key len, cumulated lengths of query and key)",
)

args = parser.parse_args()
return args

Expand All @@ -541,7 +551,7 @@ def main():
perf_results = manager.dict()

batch_size_set = set(args.batch_size)
if not min(batch_size_set) >= 1 and max(batch_size_set) <= 128:
if not (min(batch_size_set) >= 1 and max(batch_size_set) <= 128):
raise Exception("batch_size not in range [1, 128]")

model_setting = ModelSetting(
Expand All @@ -552,6 +562,7 @@ def main():
args.opt_level,
args.input_tuning_results,
args.output_tuning_results,
args.mask_type,
)

for batch_size in batch_size_set:
Expand Down
78 changes: 64 additions & 14 deletions onnxruntime/python/tools/transformers/bert_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,23 @@ def fake_segment_ids_data(segment_ids: TensorProto, batch_size: int, sequence_le
return data


def get_random_length(max_sequence_length: int, average_sequence_length: int) -> int:
    """Draw a random sequence length whose expected value is average_sequence_length.

    Samples from a uniform distribution over an interval chosen so that its
    midpoint equals average_sequence_length while staying within
    [1, max_sequence_length].

    Args:
        max_sequence_length (int): upper bound for the returned length.
        average_sequence_length (int): desired mean of the returned lengths;
            must satisfy 1 <= average_sequence_length <= max_sequence_length.

    Returns:
        int: a random length in [1, max_sequence_length].

    Raises:
        ValueError: if average_sequence_length is outside [1, max_sequence_length].
    """
    # Validate explicitly instead of using `assert`, which is stripped under -O.
    if not 1 <= average_sequence_length <= max_sequence_length:
        raise ValueError(
            "average_sequence_length must be in [1, max_sequence_length]"
        )

    # Pick the uniform interval whose midpoint is the requested average:
    #   when avg > max/2: [2*avg - max, max] (midpoint avg)
    #   otherwise:        [1, 2*avg - 1]     (midpoint avg)
    if 2 * average_sequence_length > max_sequence_length:
        return random.randint(2 * average_sequence_length - max_sequence_length, max_sequence_length)
    return random.randint(1, 2 * average_sequence_length - 1)


def fake_input_mask_data(
input_mask: TensorProto,
batch_size: int,
sequence_length: int,
average_sequence_length: int,
random_sequence_length: bool,
mask_type: int = 2,
) -> np.ndarray:
"""Create input tensor based on the graph input of segment_ids.
Expand All @@ -89,6 +100,9 @@ def fake_input_mask_data(
sequence_length (int): sequence length
average_sequence_length (int): average sequence length excluding paddings
random_sequence_length (bool): whether use uniform random number for sequence length
mask_type (int): mask type - 1: mask index (sequence length excluding paddings). Shape is (batch_size).
2: 2D attention mask. Shape is (batch_size, sequence_length).
3: key len, cumulated lengths of query and key. Shape is (3 * batch_size + 2).
Returns:
np.ndarray: the input tensor created
Expand All @@ -100,20 +114,40 @@ def fake_input_mask_data(
TensorProto.INT64,
]

data = np.zeros((batch_size, sequence_length), dtype=np.int32)
if random_sequence_length:
for i in range(batch_size):
# We use uniform distribution, so we find proper minimal and maximal so that the average is in the middle.
if 2 * average_sequence_length > sequence_length:
actual_seq_len = random.randint(2 * average_sequence_length - sequence_length, sequence_length)
else:
actual_seq_len = random.randint(1, 2 * average_sequence_length - 1)

for j in range(actual_seq_len):
data[i, j] = 1
if mask_type == 1: # sequence length excluding paddings
data = np.ones((batch_size), dtype=np.int32)
if random_sequence_length:
for i in range(batch_size):
data[i] = get_random_length(sequence_length, average_sequence_length)
else:
for i in range(batch_size):
data[i] = average_sequence_length
elif mask_type == 2: # 2D attention mask
data = np.zeros((batch_size, sequence_length), dtype=np.int32)
if random_sequence_length:
for i in range(batch_size):
actual_seq_len = get_random_length(sequence_length, average_sequence_length)
for j in range(actual_seq_len):
data[i, j] = 1
else:
temp = np.ones((batch_size, average_sequence_length), dtype=np.int32)
data[: temp.shape[0], : temp.shape[1]] = temp
else:
temp = np.ones((batch_size, average_sequence_length), dtype=np.int32)
data[: temp.shape[0], : temp.shape[1]] = temp
assert mask_type == 3
data = np.zeros((batch_size * 3 + 2), dtype=np.int32)
if random_sequence_length:
for i in range(batch_size):
data[i] = get_random_length(sequence_length, average_sequence_length)

for i in range(batch_size + 1):
data[batch_size + i] = data[batch_size + i - 1] + data[i - 1] if i > 0 else 0
data[2 * batch_size + 1 + i] = data[batch_size + i - 1] + data[i - 1] if i > 0 else 0
else:
for i in range(batch_size):
data[i] = average_sequence_length
for i in range(batch_size + 1):
data[batch_size + i] = i * average_sequence_length
data[2 * batch_size + 1 + i] = i * average_sequence_length

if input_mask.type.tensor_type.elem_type == TensorProto.FLOAT:
data = np.float32(data)
Expand Down Expand Up @@ -160,6 +194,7 @@ def fake_test_data(
input_mask: TensorProto,
average_sequence_length: int,
random_sequence_length: bool,
mask_type: int,
):
"""Create given number of input data for testing
Expand All @@ -175,6 +210,7 @@ def fake_test_data(
input_mask (TensorProto): graph input of attention mask
average_sequence_length (int): average sequence length excluding paddings
random_sequence_length (bool): whether use uniform random number for sequence length
mask_type (int): mask type 1 is mask index; 2 is 2D mask; 3 is key len, cumulated lengths of query and key
Returns:
List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
Expand All @@ -195,7 +231,7 @@ def fake_test_data(

if input_mask:
inputs[input_mask.name] = fake_input_mask_data(
input_mask, batch_size, sequence_length, average_sequence_length, random_sequence_length
input_mask, batch_size, sequence_length, average_sequence_length, random_sequence_length, mask_type
)

if verbose and len(all_inputs) == 0:
Expand All @@ -215,6 +251,7 @@ def generate_test_data(
input_mask: TensorProto,
average_sequence_length: int,
random_sequence_length: bool,
mask_type: int,
):
"""Create given number of input data for testing
Expand All @@ -229,6 +266,7 @@ def generate_test_data(
input_mask (TensorProto): graph input of attention mask
average_sequence_length (int): average sequence length excluding paddings
random_sequence_length (bool): whether use uniform random number for sequence length
mask_type (int): mask type 1 is mask index; 2 is 2D mask; 3 is key len, cumulated lengths of query and key
Returns:
List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
Expand All @@ -247,6 +285,7 @@ def generate_test_data(
input_mask,
average_sequence_length,
random_sequence_length,
mask_type,
)
if len(all_inputs) != test_cases:
print("Failed to create test data for test.")
Expand Down Expand Up @@ -474,6 +513,14 @@ def parse_arguments():
)
parser.set_defaults(random_sequence_length=False)

parser.add_argument(
"--mask_type",
required=False,
type=int,
default=2,
help="mask type: (1: mask index, 2: raw 2D mask, 3: key lengths, cumulated lengths of query and key)",
)

args = parser.parse_args()
return args

Expand All @@ -492,6 +539,7 @@ def create_and_save_test_data(
only_input_tensors: bool,
average_sequence_length: int,
random_sequence_length: bool,
mask_type: int,
):
"""Create test data for a model, and save test data to a directory.
Expand All @@ -509,6 +557,7 @@ def create_and_save_test_data(
only_input_tensors (bool): only save input tensors,
average_sequence_length (int): average sequence length excluding paddings
random_sequence_length (bool): whether use uniform random number for sequence length
mask_type(int): mask type
"""
input_ids, segment_ids, input_mask = get_bert_inputs(model, input_ids_name, segment_ids_name, input_mask_name)

Expand All @@ -523,6 +572,7 @@ def create_and_save_test_data(
input_mask,
average_sequence_length,
random_sequence_length,
mask_type,
)

for i, inputs in enumerate(all_inputs):
Expand Down
15 changes: 14 additions & 1 deletion onnxruntime/python/tools/transformers/compare_bert_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,15 @@ def run_test(
input_ids_name,
segment_ids_name,
input_mask_name,
mask_type,
):
# Try deduce input names from optimized model.
input_ids, segment_ids, input_mask = get_bert_inputs(
optimized_model, input_ids_name, segment_ids_name, input_mask_name
)

# Use random mask length for accuracy test. It might introduce slight inflation in latency reported in this script.
average_sequence_length = int(sequence_length / 2) if sequence_length >= 2 else sequence_length
all_inputs = generate_test_data(
batch_size,
sequence_length,
Expand All @@ -105,7 +107,9 @@ def run_test(
input_ids,
segment_ids,
input_mask,
random_mask_length=True,
average_sequence_length,
True, # random sequence length
mask_type,
)

baseline_results, baseline_latency, output_names = run_model(
Expand Down Expand Up @@ -208,6 +212,14 @@ def parse_arguments():
help="input name for attention mask",
)

parser.add_argument(
"--mask_type",
required=False,
type=int,
default=2,
help="mask type: (1: mask index or sequence length, 2: raw 2D mask, 3: key len, cumulated lengths of query and key)",
)

args = parser.parse_args()
return args

Expand Down Expand Up @@ -235,6 +247,7 @@ def main():
args.input_ids,
args.segment_ids,
args.input_mask,
args.mask_type,
)


Expand Down

0 comments on commit a25d0d2

Please sign in to comment.