
Commit 77f0f4d

Merge branch 'fix_rdna-arch' of github.com:Cunxiao2002/BitBLAS into fix_rdna-arch

Cunxiao2002 committed Oct 22, 2024 (2 parents: 6d9f15d + 5f2bddc)
Showing 2 changed files with 11 additions and 4 deletions.
3rdparty/tvm: 1 addition & 1 deletion (submodule pointer update)
testing/python/tilelang/test_tilelang_dyanmic_symbolic.py: 10 additions & 3 deletions
@@ -372,13 +372,18 @@ def assert_tl_matmul_block_all_dynamic_correctness(
     )
     mod, params = TL.lower(program)

-    A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
-    B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
+    if trans_A:
+        A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
+    else:
+        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
+    if trans_B:
+        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
+    else:
+        B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype))
     C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

     mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
     mod(A, B, C)
-    print(mod.mod.imported_modules[0].get_source())

     def ref_program(A, B):
         import torch
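
The hunk above makes the test allocate A and B according to its transpose flags: previously A was always (M, K) and B always (N, K), a layout that only matches trans_A=False with trans_B=True. The existing calls in the second hunk use trans_B=False with N = K = 128, where the mis-shaped B is still shape-compatible, which likely hid the bug. A minimal self-contained sketch of the shape convention the fix restores (ref_matmul is a hypothetical illustration, not the repository's ref_program):

import torch

def ref_matmul(A, B, trans_A, trans_B, out_dtype="float16"):
    # Normalize layouts: A becomes (M, K), B becomes (K, N).
    A2 = A.T if trans_A else A
    B2 = B.T if trans_B else B
    # Assumption: accumulate in float32, as float16 references commonly do.
    C = torch.matmul(A2.to(torch.float32), B2.to(torch.float32))
    return C.to(getattr(torch, out_dtype))

With trans_B=True, B is allocated as (N, K) and B.T restores (K, N), so the result is always (M, N), matching the C buffer above.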
@@ -419,6 +424,8 @@ def test_assert_tl_matmul_block_all_dynamic():
"float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16",
"float16", 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(36, 115, 103, False, False, "float16", "float16",
"float16", 64, 64, 32)


if __name__ == "__main__":
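The new case (M=36, N=115, K=103) is the interesting one: unlike 128, neither 115 nor 103 is a multiple of the 64/64/32 block shape, so the dynamically shaped kernel must handle partial edge tiles, and with N != K the old mis-shaped B would no longer be shape-compatible. A small sketch of the ceiling-division arithmetic such tiled kernels use for grid sizing (ceildiv is an illustrative helper, not BitBLAS/TileLang API):

def ceildiv(a: int, b: int) -> int:
    # Round a/b up; equivalent to math.ceil(a / b) for positive ints.
    return -(-a // b)

# For the new test case with block_M = block_N = 64:
# grid_x = ceildiv(115, 64) = 2, grid_y = ceildiv(36, 64) = 1,
# so the N-edge tile covers only 115 - 64 = 51 columns.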
