Dynamic offsets API
Signed-off-by: Ivan Butygin <[email protected]>
Hardcode84 committed Nov 14, 2024
1 parent 01cf5b7 commit 37a48fb
Showing 4 changed files with 93 additions and 3 deletions.
8 changes: 7 additions & 1 deletion iree/turbine/kernel/lang/wave_types.py
@@ -199,7 +199,9 @@ def __init__(
self.iteration_shape = iter_shape
self.input_mapping = inputs
self.output_mapping = outputs
- if not isinstance(dynamic_val_mappings, (list, tuple)):
+ if dynamic_val_mappings is None:
+     dynamic_val_mappings = ()
+ elif not isinstance(dynamic_val_mappings, (list, tuple)):
+     dynamic_val_mappings = (dynamic_val_mappings,)

self.dynamic_val_mappings = tuple(dynamic_val_mappings)
@@ -210,6 +212,10 @@ def __init__(
def num_iterators(self) -> int:
return len(self.iters)

@property
def num_dynamic_vals(self) -> int:
return len(self.dynamic_vals)

def substitute(self, subs: Iterable[tuple[IndexExpr, IndexExpr]]) -> Self:
new_inputs = {
key: _subs_expr(val, subs) for key, val in self.input_mapping.items()
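Taken together, the constructor now accepts dynamic_val_mappings as None, a single mapping, or a list/tuple of mappings, and normalizes all of them to a tuple. Below is a minimal sketch of that behavior, reusing the symbols from the e2e test added further down and assuming the usual tkl/tkw aliases from the test file; the final assertion also assumes dynamic_vals is populated from the dynamic_val(...) symbols referenced by the mapping, which is how the new num_dynamic_vals property reads.

import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw

M = tkl.sym.M
N = tkl.sym.N
i = tkw.IndexMapping.iterator(0)
j = tkw.IndexMapping.iterator(1)
k = tkw.IndexMapping.dynamic_val(0)

# A single mapping dict for the one dynamic value...
m1 = tkw.IndexMapping(
    num_iterators=2,
    inputs={M: (i + k) % M, N: j},
    outputs={M: i, N: j},
    dynamic_val_mappings={M: i, N: j},
)
# ...and an explicit one-element tuple normalize to the same thing.
m2 = tkw.IndexMapping(
    num_iterators=2,
    inputs={M: (i + k) % M, N: j},
    outputs={M: i, N: j},
    dynamic_val_mappings=({M: i, N: j},),
)
assert m1.dynamic_val_mappings == m2.dynamic_val_mappings
assert m1.num_dynamic_vals == 1  # assumes k is collected as the only dynamic value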
4 changes: 4 additions & 0 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -69,6 +69,7 @@ def read(
memory: "Memory",
elements_per_thread: Optional[IndexExpr | int] = None,
mapping: Optional[IndexMapping] = None,
mapping_dynamic_vals: Optional["Register" | tuple["Register", ...]] = None,
) -> "Register":
...

@@ -92,6 +93,7 @@ def write(
memory: "Memory",
elements_per_thread: Optional[IndexExpr | int] = None,
mapping: Optional[IndexMapping] = None,
mapping_dynamic_vals: Optional["Register" | tuple["Register", ...]] = None,
):
...

@@ -946,6 +948,7 @@ class Read(CustomOp):
memory: fx.Proxy
elements_per_thread: Optional[Any] = None
mapping: Optional[IndexMapping] = None
mapping_dynamic_vals: tuple["Register", ...] = ()
_write_dependency: Optional[list[fx.Node]] = None

@property
@@ -1120,6 +1123,7 @@ class Write(CustomOp):
memory: fx.Proxy
elements_per_thread: Optional[Any]
mapping: Optional[IndexMapping] = None
mapping_dynamic_vals: tuple["Register", ...] = ()

@property
def indexing_dims(self) -> list[IndexSymbol]:
8 changes: 6 additions & 2 deletions iree/turbine/kernel/wave/codegen.py
@@ -682,10 +682,12 @@ def _construct_gather_scatter_indices(
def handle_read(emitter: WaveEmitter, node: fx.Node):
# This is similar to tkl.store with fixed start indices for now.
try:
- memory, elements_per_thread, mapping, _ = node.args
+ memory, elements_per_thread, mapping, dyn_vals, _ = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e

assert len(dyn_vals) == 0, "Dynamic vals are not implemented yet"

vector_shape = cast_py_literal(emitter, (elements_per_thread,))
# memory has no IR node yet.
kb_src, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, memory)
@@ -737,10 +739,12 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
@handle_op(write)
def handle_write(emitter: WaveEmitter, node: fx.Node):
try:
- register, memory, elements_per_thread, mapping = node.args
+ register, memory, elements_per_thread, mapping, dyn_vals = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e

assert len(dyn_vals) == 0, "Dynamic vals are not implemented yet"

# memory has no IR node yet.
kb_dest, kb_ir_type, kb_py_type = cast_kernel_buffer(emitter, memory)
insert_vector = cast_vector(emitter, register, element_type=kb_ir_type.element_type)
76 changes: 76 additions & 0 deletions tests/kernel/wave/wave_e2e_test.py
@@ -286,6 +286,82 @@ def test(
assert_allclose(a.T, b)


@require_e2e
@pytest.mark.parametrize("shape", get_test_shapes("test_copy"))
def test_offset_read(shape, request):
run_bench = request.config.getoption("--runperf")
M = tkl.sym.M
N = tkl.sym.N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

# Each workgroup works on a single row of the input data, and rows are
# further split into blocks of up to 256 elements. There is a single wave
# per workgroup, and with the default wave size of 64, each thread
# operates on up to 4 elements.
wave_size = 64
BLOCK_M = 1
# Tile size cannot be dynamic, so we use a fixed value here.
BLOCK_N = sympy.Max(sympy.Min(shape[1], 256), wave_size)
ELEMS_PER_THREAD = BLOCK_N / wave_size
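# For example (illustrative shapes, not necessarily in get_test_shapes):
# shape[1] = 1024 gives BLOCK_N = max(min(1024, 256), 64) = 256 and
# ELEMS_PER_THREAD = 256 / 64 = 4, while shape[1] = 64 gives BLOCK_N = 64
# and ELEMS_PER_THREAD = 1.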

constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=wave_size,
waves_per_block=(1, 1, 1),
vector_shapes={M: BLOCK_M, N: BLOCK_N},
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)]
constraints += [tkw.WaveConstraint(M, BLOCK_M)]
constraints += [tkw.WaveConstraint(N, BLOCK_N)]

i = tkw.IndexMapping.iterator(0)
j = tkw.IndexMapping.iterator(1)
k = tkw.IndexMapping.dynamic_val(0)
mapping = tkw.IndexMapping(
num_iterators=2,
inputs={M: (i + k) % M, N: j},
outputs={M: i, N: j},
dynamic_val_mappings={M: i, N: j},
)

@tkw.wave(constraints)
def test(
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
off: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32],
b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
):
offset = tkw.read(off, elements_per_thread=ELEMS_PER_THREAD)
res = tkw.read(
a,
mapping=mapping,
mapping_dynamic_vals=offset,
elements_per_thread=ELEMS_PER_THREAD,
)
tkw.write(res, b, elements_per_thread=ELEMS_PER_THREAD)

config = get_default_run_config()

a = torch.randn(shape, dtype=torch.float16)
off = torch.randint(10, shape, dtype=torch.int32)
b = torch.zeros(shape, dtype=torch.float16)
with tk.gen.TestLaunchContext(
{
M: shape[0],
N: shape[1],
ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
},
canonicalize=True,
run=True,
run_bench=run_bench,
run_config=config,
):
test(a, off, b)
# TODO: check against the offset-gathered reference once dynamic vals are
# lowered in codegen.
assert_allclose(a, b)
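For reference, the gather that the mapping above describes — output element (i, j) reads a[(i + off[i, j]) % M, j] — can be written directly in PyTorch. This is only a sketch of the intended semantics (the helper name is made up); the lowering in codegen.py still rejects non-empty dynamic vals, so the check above remains a placeholder.

import torch

def offset_read_reference(a: torch.Tensor, off: torch.Tensor) -> torch.Tensor:
    # Row indices are offset element-wise by `off` and wrapped with modulo M,
    # mirroring inputs={M: (i + k) % M, N: j} with k = off[i, j].
    m, n = a.shape
    rows = torch.arange(m).unsqueeze(1)  # iterator i, shape (m, 1)
    cols = torch.arange(n).unsqueeze(0)  # iterator j, shape (1, n)
    return a[(rows + off.long()) % m, cols]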


@require_e2e
@pytest.mark.parametrize("shape", get_test_shapes("test_reduce_sum"))
def test_reduce_sum(shape, request):
