Revert "Reland D62220158 (pytorch#136213)"
This reverts commit 083c914.

Reverted pytorch#136213 on behalf of https://github.com/jeanschmidt because it seems to have introduced regressions in ROCm signals ([comment](pytorch#136213 (comment)))
pytorchmergebot committed Sep 19, 2024
1 parent bce52d0 commit 4ea741d
Showing 3 changed files with 0 additions and 59 deletions.
35 changes: 0 additions & 35 deletions test/inductor/test_pad_mm.py
@@ -9,7 +9,6 @@
get_pad_cache,
get_padded_length,
should_pad_common,
should_pad_mm_bf16,
)
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code
@@ -450,40 +449,6 @@ def mm(inps, b):
repr(get_pad_cache().get_local_cache())
)

    @fresh_inductor_cache()
    @inductor_config.patch(
        post_grad_fusion_options={"pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}}
    )
    def test_pad_mm_bf16(self):
        m = 2
        n = 13
        k = 15691904
        mat1 = torch.ones((m, k), device="cuda", dtype=torch.bfloat16)
        mat2 = torch.ones((k, n), device="cuda", dtype=torch.bfloat16)
        expected_alignment = get_alignment_size(mat1)

        assert expected_alignment == 8, "Alignment for bfloat16 should be 8"
        assert should_pad_common(
            mat1, mat2
        ), "This should pass the common padding criteria"
        if torch.cuda.get_device_capability() < (9, 0):
            assert should_pad_mm_bf16(
                mat1.dtype, m, n, k
            ), "This should pass the should_pad_mm_bf16 padding criteria"

        @torch.compile()
        def mm(mat1, mat2):
            return torch.mm(mat1, mat2)

        res2, (code,) = run_and_get_code(mm, mat1, mat2)
        mm_expected_result = torch.mm(mat1, mat2)
        # in call code, expect to see a single pad per input, and then we should see padded allocation for output
        FileCheck().check("del async_compile").check_count(
            ".run(", 2, exactly=True
        ).check("empty_strided_cuda((8, 16)").run(code)

        assert torch.allclose(res2, mm_expected_result), "MM results are not identical"


if __name__ == "__main__":
    if HAS_CUDA:
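Why the removed test checks for an (8, 16) allocation: with an alignment size of 8 for bfloat16, m = 2 rounds up to 8 and n = 13 rounds up to 16, so the padded output buffer has shape (8, 16). A minimal sketch of that round-up arithmetic (`pad_to_alignment` is a hypothetical helper for illustration, not the inductor's actual `get_padded_length`):

```python
# Minimal sketch of the round-up arithmetic behind the removed test's
# FileCheck expectation; `pad_to_alignment` is a hypothetical helper.
def pad_to_alignment(size: int, alignment: int) -> int:
    # Round `size` up to the next multiple of `alignment`.
    return ((size + alignment - 1) // alignment) * alignment

m, n = 2, 13
alignment = 8  # expected alignment for bfloat16 in the removed test

print(pad_to_alignment(m, alignment), pad_to_alignment(n, alignment))  # 8 16
```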
23 changes: 0 additions & 23 deletions torch/_inductor/fx_passes/pad_mm.py
@@ -364,23 +364,6 @@ def should_pad(key: str, ori_time, pad_time) -> bool:
    return should_pad


def should_pad_mm_bf16(dtype, M, N, K):
    # always force pad for mm with bf16 when the following are satisfied to avoid perf regression
    large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[
        "pad_aten_mm_pass"
    ].get("k_threshold_to_pad", 8388608)
    if (
        dtype is torch.bfloat16
        and K > M
        and K > N
        and N % 2 == 1
        and K >= large_k_threshold_to_pad
        and torch.cuda.get_device_capability() < (9, 0)
    ):  # doesn't repro on H100s
        return True
    return False


def should_pad_bench(
    match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
) -> bool:
@@ -427,12 +410,6 @@ def realize_symbols(ds):
    if torch._inductor.config.force_shape_pad:
        return True

    if (
        "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options
        and should_pad_mm_bf16(mat1.dtype, m, n, k)
    ):
        return True

    if not has_triton():
        return False

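For reference, the removed `should_pad_mm_bf16` heuristic above would have forced padding for the shapes used by the removed test. A standalone sketch of the same decision logic, with the device-capability check replaced by an explicit `pre_hopper` flag so it runs without a GPU (the flag and function name below are illustrative, not inductor APIs):

```python
import torch

# Standalone sketch of the removed should_pad_mm_bf16 heuristic. The CUDA
# capability check is replaced by an explicit `pre_hopper` flag so the
# sketch runs without a GPU; the flag and this function name are
# illustrative only.
def should_pad_mm_bf16_sketch(dtype, M, N, K, k_threshold=8388608, pre_hopper=True):
    return (
        dtype is torch.bfloat16
        and K > M
        and K > N
        and N % 2 == 1          # padding fixes up the odd N dimension
        and K >= k_threshold
        and pre_hopper          # original: torch.cuda.get_device_capability() < (9, 0)
    )

# Shapes from the removed test_pad_mm_bf16: m=2, n=13, k=15691904.
print(should_pad_mm_bf16_sketch(torch.bfloat16, 2, 13, 15691904))  # True
```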
1 change: 0 additions & 1 deletion torch/_inductor/fx_passes/split_cat.py
@@ -65,7 +65,6 @@
"decompose_mm_pass",
"unbind_stack_aten_pass",
"shape_padding_multiplier",
"pad_aten_mm_pass",
]

for pass_name in pre_grad_pass_names:
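This hunk removes `pad_aten_mm_pass` from the list of recognized post-grad pass names, so configuring that option no longer enables the pass. A hedged sketch of how the pass was opted into before the revert, mirroring the config patch in the removed test (direct assignment is shown for illustration; the test used `inductor_config.patch`):

```python
import torch._inductor.config as inductor_config

# Mirrors the removed test's config patch; after this revert the
# "pad_aten_mm_pass" name is no longer listed among the post-grad passes.
inductor_config.post_grad_fusion_options = {
    "pad_aten_mm_pass": {"k_threshold_to_pad": 8388608}
}
```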
