[AMD] Enable mixed precision matmul test (#5177)
This commit enables the mixed precision matmul
test for the AMD backend. For FP8 E4M3, we test
`fp8e4m3fnuz`, since that is the variant natively
supported on the MI300 series.
makslevental authored Nov 16, 2024
1 parent 9aa114a commit 6958807
Showing 1 changed file with 11 additions and 2 deletions.

python/test/regression/test_cast_matmul.py
@@ -11,13 +11,22 @@

 import triton
 import triton.language as tl
+from triton._internal_testing import is_hip_mi300, is_cuda

 input_dtypes = ["float16", "float32", "float64"]
-if triton.runtime.driver.active.get_current_target().backend == "cuda":
+if is_cuda():
     input_dtypes += ["int8", "float8_e5m2"]
     cc = torch.cuda.get_device_capability(0)
     if cc >= (8, 9):
         input_dtypes += ["float8_e4m3fn"]
+elif is_hip_mi300():
+    input_dtypes += [
+        "int8",
+        "float8_e5m2",
+        # natively supported on mi300 (see CDNA3 ISA, section 7.2)
+        "float8_e4m3fnuz",
+    ]

 out_dtypes = ["float16", "float32"]


@@ -85,7 +94,7 @@ def test_cast_matmul(M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype):
     def init_tensor(dtype, shape):
         if dtype == torch.int8:
             return torch.randint(0, 2, shape, device=device, dtype=dtype)
-        elif dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
+        elif dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2):
             return torch.randn(shape, device=device, dtype=torch.float16).to(dtype)
         else:
             return torch.randn(shape, device=device, dtype=dtype)
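
As context for the `fp8e4m3fnuz` cases added above (not part of this commit): a minimal sketch, assuming PyTorch 2.1+ where both FP8 E4M3 dtypes are available, showing how the AMD `fnuz` variant differs numerically from the OCP `float8_e4m3fn` used on CUDA sm_89+:

import torch

# Sketch only, assuming PyTorch >= 2.1 with both FP8 dtypes available.
# float8_e4m3fn (OCP) keeps the all-ones exponent for finite values, so
# its largest finite value is 448; float8_e4m3fnuz (AMD CDNA3 / MI300)
# shifts the exponent bias and has no negative zero or infinities,
# topping out at 240.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0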
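A hedged usage note: on an MI300-series GPU with this commit applied, the new cases can be selected with pytest's standard name filter (assuming the parametrize IDs embed the dtype names, consistent with the list above):

pytest python/test/regression/test_cast_matmul.py -k fnuz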
