From 9da92ff639fd4a111ed07f881065e15107391705 Mon Sep 17 00:00:00 2001 From: kiddyjinjin Date: Mon, 4 Nov 2024 07:38:24 +0000 Subject: [PATCH] fix repeat_interleave benchmark bug --- benchmark/conftest.py | 5 +++-- benchmark/test_tensor_concat_perf.py | 16 ++++++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/benchmark/conftest.py b/benchmark/conftest.py index 7c6776dd..15a3b187 100644 --- a/benchmark/conftest.py +++ b/benchmark/conftest.py @@ -177,11 +177,12 @@ def setup_once(request): print("\nThis is query mode; all benchmark functions will be skipped.") else: note_info = ( - f"\n\nNote: The 'size' field below is for backward compatibility with previous versions of the benchmark. " - f"\nThis field will be removed in a future release." + f"\n\nNote: The 'size' field below is for backward compatibility with previous versions of the benchmark. " + f"\nThis field will be removed in a future release." ) print(note_info) + @pytest.fixture() def extract_and_log_op_attributes(request): print("") diff --git a/benchmark/test_tensor_concat_perf.py b/benchmark/test_tensor_concat_perf.py index 8ffb45e0..9e6077e4 100644 --- a/benchmark/test_tensor_concat_perf.py +++ b/benchmark/test_tensor_concat_perf.py @@ -124,8 +124,20 @@ def repeat_input_fn(shape, cur_dtype, device): def repeat_interleave_self_input_fn(shape, dtype, device): inp = generate_tensor_input(shape, dtype, device) - repeats = 3 - yield inp, repeats + repeats = torch.randint( + low=0, + high=0x2F, + size=[ + shape[0], + ], + device=device, + ) + dim = 0 + # repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor + yield inp, repeats, dim + if Config.bench_level == BenchLevel.COMPREHENSIVE: + # repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor + yield inp, 3 @pytest.mark.parametrize(