From 76aff63f375be7e295968c3ea0eca932973a4ae0 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 2 Aug 2023 10:28:39 -0700 Subject: [PATCH] Update bert_perf_test to test inputs with different padding ratio (#16963) Add --average_sequence_length and --random_sequence_length so that we can test the performance of model on different padding ratio. --- .../tools/transformers/bert_perf_test.py | 30 +++++++- .../tools/transformers/bert_test_data.py | 75 +++++++++++++++---- .../models/longformer/generate_test_data.py | 38 +++++++++- 3 files changed, 122 insertions(+), 21 deletions(-) diff --git a/onnxruntime/python/tools/transformers/bert_perf_test.py b/onnxruntime/python/tools/transformers/bert_perf_test.py index ae0eb3f47fda3..c843831be6779 100644 --- a/onnxruntime/python/tools/transformers/bert_perf_test.py +++ b/onnxruntime/python/tools/transformers/bert_perf_test.py @@ -44,6 +44,8 @@ class TestSetting: seed: int verbose: bool log_severity: int + average_sequence_length: int + random_sequence_length: bool @dataclass @@ -365,7 +367,8 @@ def run_performance(model_setting, test_setting, perf_results): input_ids, segment_ids, input_mask, - random_mask_length=False, + test_setting.average_sequence_length, + test_setting.random_sequence_length, ) run_perf_tests(model_setting, test_setting, perf_results, all_inputs) @@ -473,6 +476,7 @@ def parse_arguments(): default=None, help="input name for input ids", ) + parser.add_argument( "--segment_ids_name", required=False, @@ -480,6 +484,7 @@ def parse_arguments(): default=None, help="input name for segment ids", ) + parser.add_argument( "--input_mask_name", required=False, @@ -494,6 +499,7 @@ def parse_arguments(): type=str, help="tuning results (json) to be loaded before benchmark", ) + parser.add_argument( "--output_tuning_results", default=None, @@ -501,6 +507,23 @@ def parse_arguments(): help="tuning results (json) to be saved after benchmark", ) + parser.add_argument( + "-a", + "--average_sequence_length", + default=-1, + 
type=int, + help="average sequence length excluding padding", + ) + + parser.add_argument( + "-r", + "--random_sequence_length", + required=False, + action="store_true", + help="use uniform random instead of fixed sequence length", + ) + parser.set_defaults(random_sequence_length=False) + args = parser.parse_args() return args @@ -511,6 +534,9 @@ def main(): if args.test_times == 0: args.test_times = max(1, int(1000 / args.samples)) + if args.average_sequence_length <= 0: + args.average_sequence_length = args.sequence_length + manager = multiprocessing.Manager() perf_results = manager.dict() @@ -541,6 +567,8 @@ def main(): args.seed, args.verbose, args.log_severity, + args.average_sequence_length, + args.random_sequence_length, ) print("test setting", test_setting) diff --git a/onnxruntime/python/tools/transformers/bert_test_data.py b/onnxruntime/python/tools/transformers/bert_test_data.py index 2d87b2d3ad3fd..bed9eb4dbc1f1 100644 --- a/onnxruntime/python/tools/transformers/bert_test_data.py +++ b/onnxruntime/python/tools/transformers/bert_test_data.py @@ -78,7 +78,8 @@ def fake_input_mask_data( input_mask: TensorProto, batch_size: int, sequence_length: int, - random_mask_length: bool, + average_sequence_length: int, + random_sequence_length: bool, ) -> np.ndarray: """Create input tensor based on the graph input of segment_ids. 
@@ -86,7 +87,8 @@ def fake_input_mask_data( input_mask (TensorProto): graph input of the attention mask input tensor batch_size (int): batch size sequence_length (int): sequence length - random_mask_length (bool): whether mask according to random padding length + average_sequence_length (int): average sequence length excluding paddings + random_sequence_length (bool): whether use uniform random number for sequence length Returns: np.ndarray: the input tensor created @@ -98,13 +100,20 @@ def fake_input_mask_data( TensorProto.INT64, ] - if random_mask_length: - actual_seq_len = random.randint(int(sequence_length * 2 / 3), sequence_length) - data = np.zeros((batch_size, sequence_length), dtype=np.int32) - temp = np.ones((batch_size, actual_seq_len), dtype=np.int32) - data[: temp.shape[0], : temp.shape[1]] = temp + data = np.zeros((batch_size, sequence_length), dtype=np.int32) + if random_sequence_length: + for i in range(batch_size): + # We use uniform distribution, so we find proper minimal and maximal so that the average is in the middle. 
+ if 2 * average_sequence_length > sequence_length: + actual_seq_len = random.randint(2 * average_sequence_length - sequence_length, sequence_length) + else: + actual_seq_len = random.randint(1, 2 * average_sequence_length - 1) + + for j in range(actual_seq_len): + data[i, j] = 1 else: - data = np.ones((batch_size, sequence_length), dtype=np.int32) + temp = np.ones((batch_size, average_sequence_length), dtype=np.int32) + data[: temp.shape[0], : temp.shape[1]] = temp if input_mask.type.tensor_type.elem_type == TensorProto.FLOAT: data = np.float32(data) @@ -149,7 +158,8 @@ def fake_test_data( input_ids: TensorProto, segment_ids: TensorProto, input_mask: TensorProto, - random_mask_length: bool, + average_sequence_length: int, + random_sequence_length: bool, ): """Create given number of input data for testing @@ -163,7 +173,8 @@ def fake_test_data( input_ids (TensorProto): graph input of input IDs segment_ids (TensorProto): graph input of token type IDs input_mask (TensorProto): graph input of attention mask - random_mask_length (bool): whether mask random number of words at the end + average_sequence_length (int): average sequence length excluding paddings + random_sequence_length (bool): whether use uniform random number for sequence length Returns: List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary @@ -183,7 +194,9 @@ def fake_test_data( inputs[segment_ids.name] = fake_segment_ids_data(segment_ids, batch_size, sequence_length) if input_mask: - inputs[input_mask.name] = fake_input_mask_data(input_mask, batch_size, sequence_length, random_mask_length) + inputs[input_mask.name] = fake_input_mask_data( + input_mask, batch_size, sequence_length, average_sequence_length, random_sequence_length + ) if verbose and len(all_inputs) == 0: print("Example inputs", inputs) @@ -200,7 +213,8 @@ def generate_test_data( input_ids: TensorProto, segment_ids: TensorProto, input_mask: TensorProto, - random_mask_length: bool, + average_sequence_length: 
int, + random_sequence_length: bool, ): """Create given number of input data for testing @@ -213,7 +227,8 @@ def generate_test_data( input_ids (TensorProto): graph input of input IDs segment_ids (TensorProto): graph input of token type IDs input_mask (TensorProto): graph input of attention mask - random_mask_length (bool): whether mask random number of words at the end + average_sequence_length (int): average sequence length excluding paddings + random_sequence_length (bool): whether use uniform random number for sequence length Returns: List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary @@ -230,7 +245,8 @@ def generate_test_data( input_ids, segment_ids, input_mask, - random_mask_length, + average_sequence_length, + random_sequence_length, ) if len(all_inputs) != test_cases: print("Failed to create test data for test.") @@ -441,6 +457,23 @@ def parse_arguments(): ) parser.set_defaults(only_input_tensors=False) + parser.add_argument( + "-a", + "--average_sequence_length", + default=-1, + type=int, + help="average sequence length excluding padding", + ) + + parser.add_argument( + "-r", + "--random_sequence_length", + required=False, + action="store_true", + help="use uniform random instead of fixed sequence length", + ) + parser.set_defaults(random_sequence_length=False) + args = parser.parse_args() return args @@ -457,6 +490,8 @@ def create_and_save_test_data( segment_ids_name: Optional[str], input_mask_name: Optional[str], only_input_tensors: bool, + average_sequence_length: int, + random_sequence_length: bool, ): """Create test data for a model, and save test data to a directory. 
@@ -471,7 +506,9 @@ def create_and_save_test_data( input_ids_name (str): graph input name of input_ids segment_ids_name (str): graph input name of segment_ids input_mask_name (str): graph input name of input_mask - only_input_tensors (bool): only save input tensors + only_input_tensors (bool): only save input tensors, + average_sequence_length (int): average sequence length excluding paddings + random_sequence_length (bool): whether use uniform random number for sequence length """ input_ids, segment_ids, input_mask = get_bert_inputs(model, input_ids_name, segment_ids_name, input_mask_name) @@ -484,7 +521,8 @@ def create_and_save_test_data( input_ids, segment_ids, input_mask, - random_mask_length=False, + average_sequence_length, + random_sequence_length, ) for i, inputs in enumerate(all_inputs): @@ -511,6 +549,9 @@ def create_and_save_test_data( def main(): args = parse_arguments() + if args.average_sequence_length <= 0: + args.average_sequence_length = args.sequence_length + output_dir = args.output_dir if output_dir is None: # Default output directory is a sub-directory under the directory of model. 
@@ -536,6 +577,8 @@ def main(): args.segment_ids_name, args.input_mask_name, args.only_input_tensors, + args.average_sequence_length, + args.random_sequence_length, ) print("Test data is saved to directory:", output_dir) diff --git a/onnxruntime/python/tools/transformers/models/longformer/generate_test_data.py b/onnxruntime/python/tools/transformers/models/longformer/generate_test_data.py index 735d2d4899041..6ba4fac1b7551 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/generate_test_data.py +++ b/onnxruntime/python/tools/transformers/models/longformer/generate_test_data.py @@ -42,6 +42,23 @@ def parse_arguments(): help="maximum sequence length of input", ) + parser.add_argument( + "-a", + "--average_sequence_length", + default=-1, + type=int, + help="average sequence length excluding padding", + ) + + parser.add_argument( + "-r", + "--random_sequence_length", + required=False, + action="store_true", + help="use uniform random instead of fixed sequence length", + ) + parser.set_defaults(random_sequence_length=False) + parser.add_argument( "--global_tokens", required=False, @@ -190,7 +207,8 @@ def fake_test_data( input_mask, global_mask, num_global_tokens, - random_mask_length=False, + average_sequence_length, + random_sequence_length, ): """ Generate fake input data for test. 
@@ -206,7 +224,9 @@
     inputs = {input_ids.name: input_1}
 
     if input_mask:
-        inputs[input_mask.name] = fake_input_mask_data(input_mask, batch_size, sequence_length, random_mask_length)
+        inputs[input_mask.name] = fake_input_mask_data(
+            input_mask, batch_size, sequence_length, average_sequence_length, random_sequence_length
+        )
 
     if global_mask:
         inputs[global_mask.name] = fake_global_mask_data(
@@ -230,7 +250,8 @@
     input_mask,
     global_mask,
     num_global_tokens,
-    random_mask_length=False,
+    average_sequence_length,
+    random_sequence_length,
 ):
     dictionary_size = 10000
     all_inputs = fake_test_data(
@@ -244,7 +265,8 @@
         input_mask,
         global_mask,
         num_global_tokens,
-        random_mask_length,
+        average_sequence_length,
+        random_sequence_length,
     )
     if len(all_inputs) != test_cases:
         print("Failed to create test data for test.")
@@ -263,6 +285,8 @@
     input_mask_name,
     global_mask_name,
     num_global_tokens,
+    average_sequence_length,
+    random_sequence_length,
 ):
     input_ids, input_mask, global_mask = get_longformer_inputs(model, input_ids_name, input_mask_name, global_mask_name)
     all_inputs = generate_test_data(
@@ -275,6 +299,8 @@
         input_mask,
         global_mask,
         num_global_tokens,
+        average_sequence_length,
+        random_sequence_length,
     )
 
     for i, inputs in enumerate(all_inputs):
@@ -299,6 +325,9 @@
     else:
         print("Directory existed. test data files will be overwritten.")
 
+    if args.average_sequence_length <= 0:
+        args.average_sequence_length = args.sequence_length
+
     create_longformer_test_data(
         args.model,
         output_dir,
@@ -311,6 +340,8 @@
         args.input_mask_name,
         args.global_mask_name,
         args.global_tokens,
+        args.average_sequence_length,
+        args.random_sequence_length,
     )
 
     print("Test data is saved to directory:", output_dir)