Add support for dynamic parallel dims in GEMMs #274

Open · wants to merge 1 commit into base: main
15 changes: 14 additions & 1 deletion iree/turbine/kernel/wave/codegen.py
@@ -417,6 +417,16 @@ def _get_const(val):
_enforce_non_rational(lhs, term)
res = arith_d.andi(*_broadcast(lhs, rhs))
stack.append(res)
case sympy.Max():
rhs = stack.pop()
lhs = stack.pop()
_enforce_non_rational(rhs, term)
_enforce_non_rational(lhs, term)
if _is_integer_like_type(rhs.type):
res = arith_d.maxsi(*_broadcast(lhs, rhs))
else:
res = arith_d.maximumf(*_broadcast(lhs, rhs))
stack.append(res)
case sympy.logic.boolalg.BooleanFalse():
res = arith_d.constant(IntegerType.get_signless(1), 0)
stack.append(res)
@@ -1062,7 +1072,10 @@ def handle_reduction(emitter: WaveEmitter, node: fx.Node):

# For now, we assume that dimensions that have tiling constraints on them
# do not have any other constraints.
end = arith_d.constant(IndexType.get(), int(node.count))
if isinstance(node.count, sympy.Expr):
end = gen_sympy_index(add_emitter_subs(emitter), node.count)
else:
end = arith_d.constant(IndexType.get(), int(node.count))

# Since we divide the end by the tile size, we need to make sure that the
# step is 1.
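The new sympy.Max case above follows the same stack-based dispatch the emitter already uses for other sympy nodes. Below is a minimal standalone sketch of that pattern, assuming a postorder walk with operands kept on a stack; eval_postorder and bindings are invented names for illustration, plain Python ints stand in for MLIR SSA values, and max() stands in for the arith.maxsi / arith.maximumf choice made on operand type.

import sympy

def eval_postorder(expr: sympy.Expr, bindings: dict[sympy.Symbol, int]) -> int:
    # Visit children before parents so operands are already on the stack
    # when their parent node is dispatched.
    stack: list[int] = []
    for term in sympy.postorder_traversal(expr):
        match term:
            case sympy.Symbol():
                stack.append(bindings[term])
            case sympy.Integer():
                stack.append(int(term))
            case sympy.Add():
                # Binary-only here; enough for this sketch.
                rhs, lhs = stack.pop(), stack.pop()
                stack.append(lhs + rhs)
            case sympy.Max():
                rhs, lhs = stack.pop(), stack.pop()
                # Integer operands would lower to arith.maxsi, floats to arith.maximumf.
                stack.append(max(lhs, rhs))
            case _:
                pass
    return stack.pop()

K = sympy.Symbol("K")
count = sympy.Max(0, K - 2)            # shape of a pipelined dynamic trip count
print(eval_postorder(count, {K: 32}))  # 30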
27 changes: 20 additions & 7 deletions iree/turbine/kernel/wave/scheduling/schedule.py
@@ -15,7 +15,7 @@
from ..utils import graph_copy, erase_graph, get_tiling_constraint, subs_idxc
import torch.fx as fx
from ....support.logging import get_logger
import math
import sympy

logger = get_logger("turbine.wave.scheduling.schedule")

@@ -92,12 +92,20 @@ def schedule_reduction(
# to have at least N iterations of the loop where N > num_stages - 1 (because
# we will be peeling off num_stages iterations from the loop).
tiling_constraint = get_tiling_constraint(reduction, constraints)
max_induction_variable = int(
subs_idxc(tiling_constraint.dim) // subs_idxc(tiling_constraint.tile_size)
max_induction_variable = subs_idxc(tiling_constraint.dim) // subs_idxc(
tiling_constraint.tile_size
)
if max_induction_variable <= scheduler.num_stages - 1:
logger.warn("Not enough iterations to pipeline the loop. Skipping pipelining.")
return {}

ivar_is_number = max_induction_variable.is_number
if ivar_is_number:
# We can only do a compile-time check if the induction variable
# is not dynamic.
max_induction_variable = int(max_induction_variable)
if max_induction_variable <= scheduler.num_stages - 1:
logger.warn(
"Not enough iterations to pipeline the loop. Skipping pipelining."
)
return {}

new_reduction = construct_pipelined_loop(
trace,
@@ -112,7 +120,12 @@
)

# Update new reduction count.
new_reduction.count = max_induction_variable - (scheduler.num_stages - 1)
if ivar_is_number:
new_reduction.count = max_induction_variable - (scheduler.num_stages - 1)
else:
new_reduction.count = sympy.Max(
0, max_induction_variable - (scheduler.num_stages - 1)
)


def schedule_graph(
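An illustrative-only sympy snippet of the count logic above, assuming made-up local symbols K, BLOCK_K, and num_stages rather than the kernel's actual hyperparameters. It shows why the dynamic path needs sympy.Max: with a symbolic trip count the "enough iterations" check cannot run at compile time, so the peeled count is clamped at zero and the comparison is deferred to runtime (where codegen can lower the Max, e.g. to arith.maxsi).

import sympy

K, BLOCK_K = sympy.symbols("K BLOCK_K", positive=True, integer=True)
num_stages = 3

# Same shape as max_induction_variable above: dim // tile_size.
max_induction_variable = K // BLOCK_K

if max_induction_variable.is_number:
    # Static shapes: the compile-time iteration check (omitted here) guards
    # this path, and plain integer arithmetic suffices.
    count = int(max_induction_variable) - (num_stages - 1)
else:
    # Dynamic shapes: the subtraction could go negative for small runtime K,
    # so clamp with Max and let codegen resolve it at runtime.
    count = sympy.Max(0, max_induction_variable - (num_stages - 1))

print(count)                              # -> Max(0, floor(K/BLOCK_K) - 2)
print(count.subs({K: 256, BLOCK_K: 32}))  # -> 6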
10 changes: 5 additions & 5 deletions iree/turbine/kernel/wave/wave.py
@@ -256,11 +256,6 @@ def _trace_and_get_kernel_signature(
# Partition strided operators.
partition_strided_operators(graph, self.constraints)

# Align sizes to WG/Tile sizes
# This pass changes indexing keys, which can interfere with other passes,
# so it should be called close to the end of the pipeline.
align_index_sizes(graph, self.constraints)

# Decompose reduce Ops.
decompose_reduce_ops(graph, self.constraints, idxc.subs)

@@ -278,6 +273,11 @@
use_scheduling_barriers = kwargs.get("use_scheduling_barriers", False)
schedule_graph(graph, self.constraints, use_scheduling_barriers)

# Align sizes to WG/Tile sizes
# This pass changes indexing keys, which can interfere with other passes,
# so it should be called close to the end of the pipeline.
align_index_sizes(graph, self.constraints)

# Add shared memory barriers.
add_shared_memory_barriers(graph)

96 changes: 96 additions & 0 deletions lit_tests/kernel/wave/codegen.py
@@ -1184,6 +1184,102 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK-COUNT-8: amdgpu.mfma


@run_test
def test_dynamic_gemm_pipelined():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]

constraints += [
tkw.HardwareConstraint(
threads_per_wave=64,
waves_per_block=(2, 2, 1),
mma_type=tkw.MMAType.F32_16x16x16_F16,
)
]

@tkw.wave(constraints)
def dynamic_gemm_pipelined(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

@tkw.reduction(K, init_args=[c_reg])
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
acc = tkw.mma(a_reg, b_reg, acc)
return acc

tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)

with tk.gen.TestLaunchContext(
{
BLOCK_M: 64,
BLOCK_N: 64,
BLOCK_K: 32,
LOAD_ELEMS_PER_THREAD: 4,
STORE_ELEMS_PER_THREAD: 4,
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
READ_SHARED_DELAY: 1,
WRITE_SHARED_DELAY: 1,
READ_GLOBAL_DELAY: 2,
WRITE_GLOBAL_DELAY: 2,
MMA_DELAY: 1,
VALU_DELAY: 1,
SHUFFLE_DELAY: 1,
SHARED_MEMORY_UNITS: 4,
GLOBAL_MEMORY_UNITS: 4,
MMA_UNITS: 4,
VALU_UNITS: 8,
SHUFFLE_UNITS: 8,
},
canonicalize=True,
schedule=True,
use_scheduling_barriers=True,
dynamic_symbols=(M, N, K),
dynamic_symbols_map={M: 64, N: 128, K: 32},
):
a = torch.randn(64, 32, dtype=torch.float16)
b = torch.randn(128, 32, dtype=torch.float16)
c = torch.zeros(64, 128, dtype=torch.float32)
print(dynamic_gemm_pipelined(a, b, c).module_op)

# CHECK: func.func @dynamic_gemm_pipelined
# CHECK-COUNT-2: vector.maskedload
# CHECK-COUNT-2: vector.store
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-4: vector.load
# CHECK-COUNT-2: vector.maskedload
# CHECK-COUNT-4: vector.load
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-2: vector.store
# CHECK-COUNT-1: arith.maxsi
# CHECK-COUNT-1: scf.for
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-4: vector.load
# CHECK-COUNT-2: vector.maskedload
# CHECK-COUNT-3: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier"
# CHECK-COUNT-4: vector.load
# CHECK-COUNT-1: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier"
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-2: vector.store
# CHECK-COUNT-2: llvm.call_intrinsic "llvm.amdgcn.sched.group.barrier"
# CHECK-COUNT-1: scf.yield
# CHECK-COUNT-4: amdgpu.mfma
# CHECK-COUNT-1: amdgpu.lds_barrier
# CHECK-COUNT-8: vector.load
# CHECK-COUNT-8: amdgpu.mfma


# This test is used to check three things
# 1. Reduction with multiple different types (MMA, ReduceOp) of iterArg works
# 2. ReduceOp lowering works using constraints from MMA (not just vector_shape).
23 changes: 22 additions & 1 deletion tests/kernel/wave/wave_gemm_test.py
@@ -22,6 +22,7 @@
import os
import json
from torch.testing import assert_close
from enum import Enum

_run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0))
require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled")
@@ -60,6 +61,7 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]:
@require_e2e
@pytest.mark.parametrize("shape", get_test_shapes("test_gemm"))
@pytest.mark.parametrize("enable_scheduling", [False, True])
@pytest.mark.parametrize("dynamic_dims", [False, True])
@pytest.mark.parametrize(
"mfma_variant",
[
@@ -68,7 +70,11 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]:
],
)
def testGemm(
shape: tuple[int], enable_scheduling: bool, mfma_variant: MMAType, request
shape: tuple[int],
enable_scheduling: bool,
dynamic_dims: bool,
mfma_variant: MMAType,
request,
):
run_bench = request.config.getoption("--runperf")
dump_perf = request.config.getoption("--dump-perf-files-path")
Expand Down Expand Up @@ -161,6 +167,19 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
dump_perf, "tk_" + perf_filename
)

dynamic_symbols = []
dynamic_symbols_map = {}
if dynamic_dims:
dynamic_symbols_map[M] = hyperparams[M]
dynamic_symbols_map[N] = hyperparams[N]
dynamic_symbols_map[K] = hyperparams[K]
dynamic_symbols.append(M)
dynamic_symbols.append(N)
dynamic_symbols.append(K)
del hyperparams[M]
del hyperparams[N]
del hyperparams[K]

with tk.gen.TestLaunchContext(
hyperparams,
canonicalize=True,
@@ -169,6 +188,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
run_config=config,
schedule=enable_scheduling,
use_scheduling_barriers=enable_scheduling_barriers,
dynamic_symbols=dynamic_symbols,
dynamic_symbols_map=dynamic_symbols_map,
):
a = torch.randn(shape[0], shape[2], dtype=torch.float16)
b = torch.randn(shape[1], shape[2], dtype=torch.float16)