diff --git a/test/inductor/test_pad_mm.py b/test/inductor/test_pad_mm.py
index 0b34e174cc668..dbe1170994fdb 100644
--- a/test/inductor/test_pad_mm.py
+++ b/test/inductor/test_pad_mm.py
@@ -9,7 +9,6 @@
     get_pad_cache,
     get_padded_length,
     should_pad_common,
-    should_pad_mm_bf16,
 )
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code
@@ -450,40 +449,6 @@ def mm(inps, b):
             repr(get_pad_cache().get_local_cache())
         )
 
-    @fresh_inductor_cache()
-    @inductor_config.patch(
-        post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}}
-    )
-    def test_pad_mm_bf16(self):
-        m = 2
-        n = 13
-        k = 15691904
-        mat1 = torch.ones((m, k), device="cuda", dtype=torch.bfloat16)
-        mat2 = torch.ones((k, n), device="cuda", dtype=torch.bfloat16)
-        expected_alignment = get_alignment_size(mat1)
-
-        assert expected_alignment == 8, "Alignment for bfloat16 should be 8"
-        assert should_pad_common(
-            mat1, mat2
-        ), "This should pass the common padding criteria"
-        if torch.cuda.get_device_capability() < (9, 0):
-            assert should_pad_mm_bf16(
-                mat1.dtype, m, n, k
-            ), "This should pass the should_pad_mm_bf16 padding criteria"
-
-        @torch.compile()
-        def mm(mat1, mat2):
-            return torch.mm(mat1, mat2)
-
-        res2, (code,) = run_and_get_code(mm, mat1, mat2)
-        mm_expected_result = torch.mm(mat1, mat2)
-        # in call code, expect to see a single pad per input, and then we should see padded allocation for output
-        FileCheck().check("del async_compile").check_count(
-            ".run(", 2, exactly=True
-        ).check("empty_strided_cuda((8, 16)").run(code)
-
-        assert torch.allclose(res2, mm_expected_result), "MM results are not identical"
-
 
 if __name__ == "__main__":
     if HAS_CUDA:
diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py
index a00bad3974791..87450b34e7ee2 100644
--- a/torch/_inductor/fx_passes/pad_mm.py
+++ b/torch/_inductor/fx_passes/pad_mm.py
@@ -364,23 +364,6 @@ def should_pad(key: str, ori_time, pad_time) -> bool:
     return should_pad
 
 
-def should_pad_mm_bf16(dtype, M, N, K):
-    # always force pad for mm with bf16 when the following are satisfied to avoid perf regression
-    large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[
-        "pad_aten_mm_pass"
-    ].get("k_threshold_to_pad", 8388608)
-    if (
-        dtype is torch.bfloat16
-        and K > M
-        and K > N
-        and N % 2 == 1
-        and K >= large_k_threshold_to_pad
-        and torch.cuda.get_device_capability() < (9, 0)
-    ):  # doesnt repro on h100s:
-        return True
-    return False
-
-
 def should_pad_bench(
     match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
 ) -> bool:
@@ -427,12 +410,6 @@ def realize_symbols(ds):
     if torch._inductor.config.force_shape_pad:
         return True
 
-    if (
-        "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options
-        and should_pad_mm_bf16(mat1.dtype, m, n, k)
-    ):
-        return True
-
     if not has_triton():
         return False
 
diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py
index 194d1d6dbaa79..f850ecf6008c9 100644
--- a/torch/_inductor/fx_passes/split_cat.py
+++ b/torch/_inductor/fx_passes/split_cat.py
@@ -65,7 +65,6 @@
     "decompose_mm_pass",
     "unbind_stack_aten_pass",
     "shape_padding_multiplier",
-    "pad_aten_mm_pass",
 ]
 
 for pass_name in pre_grad_pass_names: