Add partition for GPR num offset
Signed-off-by: Stanley Winata <[email protected]>
raikonenfnu committed Nov 14, 2024
1 parent 421a3f6 commit efabeae
Showing 4 changed files with 113 additions and 100 deletions.
10 changes: 10 additions & 0 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -1253,6 +1253,16 @@ class ExtractSlice(CustomOp):
def type(self) -> "Register":
return get_custom(self.register_).type

@property
def rank(self) -> int:
offset_rank = len(self.offset)
size_rank = len(self.size)
stride_rank = len(self.stride)
assert (
offset_rank == size_rank == stride_rank
), "Expected offset, size, and stride to have same rank."
return size_rank


@define_op("broadcast")
@dataclass
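The new rank property only checks that offset, size, and stride agree in length. A minimal standalone sketch of the same invariant, using a stand-in dataclass rather than the real ExtractSlice op:

# Standalone sketch of the rank check added to ExtractSlice above.
# "Slice" is a stand-in class, just to show the invariant being asserted.
from dataclasses import dataclass

@dataclass
class Slice:
    offset: list[int]
    size: list[int]
    stride: list[int]

    @property
    def rank(self) -> int:
        assert len(self.offset) == len(self.size) == len(self.stride), (
            "Expected offset, size, and stride to have same rank."
        )
        return len(self.size)

# A 1-D slice such as ExtractSlice(reg, [2], [4], [1]) has rank 1.
assert Slice(offset=[2], size=[4], stride=[1]).rank == 1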
173 changes: 73 additions & 100 deletions iree/turbine/kernel/wave/index_sequence_analysis.py
@@ -63,13 +63,9 @@ def has_strided_access(node: fx.Node) -> bool:
"""
custom = get_custom(node)
if isinstance(custom, Write):
strides = [
simplify_index(custom.register_index[dim]).stride
for dim in custom.register_index
]
strides = [simplify_index(custom.index[dim]).stride for dim in custom.index]
elements_per_thread = [
simplify_index(custom.register_index[dim]).size
for dim in custom.register_index
simplify_index(custom.index[dim]).size for dim in custom.index
]
strides = [x for x, y in zip(strides, elements_per_thread) if y > 1]
num_strided_accesses = sum(1 for stride in strides if stride > 1)
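For reference, the check above pairs each dimension's stride with its per-thread element count, keeps only dimensions that read more than one element, and flags the write when exactly one of them is strided. A minimal sketch with plain (stride, size) tuples standing in for the simplified index values:

# (stride, elements_per_thread) per dimension of a 2d write's index.
index = {"M": (16, 4), "N": (1, 1)}

strides = [stride for stride, ept in index.values() if ept > 1]
num_strided_accesses = sum(1 for stride in strides if stride > 1)
assert num_strided_accesses == 1  # exactly one strided dimension: partition this write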
@@ -85,7 +81,7 @@ def has_strided_access(node: fx.Node) -> bool:
for operator in strided_operators:
custom = get_custom(operator)
simplified_index = {
dim: simplify_index(custom.register_index.get(dim, custom.index[dim]))
dim: simplify_index(custom.index.get(dim, custom.index[dim]))
for dim in custom.index
}

@@ -101,19 +97,10 @@ def has_strided_access(node: fx.Node) -> bool:
for i in range(elements_per_thread):
# Non-contiguous access patterns can have varying offsets. We
# handle that here.
gpr_offset = [
expr
for expr in simplified_index[max_stride_dim].start.args
if expr.has(GPR_NUM)
]
if not gpr_offset:
gpr_offset = i
else:
gpr_offset = sympy.Add(*gpr_offset).subs({GPR_NUM: i})
extract = ExtractSlice(custom.register_, [i], [1], [1]).add_to_graph(
custom.graph
)
offset = np.unravel_index(int(gpr_offset * max_stride), shape)
offset = np.unravel_index(int(i * max_stride), shape)
write = Write(
extract,
custom.memory,
@@ -130,9 +117,7 @@ def has_strided_access(node: fx.Node) -> bool:
custom.graph.erase_node(operator)


def decompose_reads_with_gpr_offsets(
trace: CapturedTrace, constraints: list[Constraint]
):
def partition_ops_with_gpr_offsets(trace: CapturedTrace, constraints: list[Constraint]):
"""
This function analyzes the index sequence of operators in the graph
that are writes on 2d tensors. If the operator has an access pattern where
@@ -141,98 +126,86 @@ def has_strided_access(node: fx.Node) -> bool:
each individual element.
"""

def has_strided_access(node: fx.Node) -> bool:
def has_gpr_offsets(node: fx.Node) -> bool:
"""
Checks for writes on 2d tensors with strided access on a single dimension that
read more than a single element.
"""
custom = get_custom(node)
if isinstance(custom, Read):
strides = [custom.index[dim].stride for dim in custom.index]
elements_per_thread = [custom.index[dim].size for dim in custom.index]
strides = [x for x, y in zip(strides, elements_per_thread) if y > 1]
num_strided_accesses = sum(1 for stride in strides if stride > 1)
if num_strided_accesses > 1:
raise NotImplementedError(
"Support for strided accesses on more than one dimension not implemented yet!"
)
return num_strided_accesses == 1
elif isinstance(custom, Write):
strides = [
simplify_index(custom.register_index[dim]).stride
for dim in custom.register_index
]
elements_per_thread = [
simplify_index(custom.register_index[dim]).size
for dim in custom.register_index
]
strides = [x for x, y in zip(strides, elements_per_thread) if y > 1]
num_strided_accesses = sum(1 for stride in strides if stride > 1)
if num_strided_accesses > 1:
raise NotImplementedError(
"Support for strided accesses on more than one dimension not implemented yet!"
)
return num_strided_accesses == 1
return False
if not isinstance(custom, (Read, Write)):
return False
dims_with_gpr_offset = [
v.start for k, v in custom.index.items() if v.start.has(GPR_NUM)
]
if not dims_with_gpr_offset:
return False
num_dims_with_gpr_offsets = len(dims_with_gpr_offset)
if num_dims_with_gpr_offsets > 1:
raise NotImplementedError("Currently only handle 1 dim with gpr offset.")
return True

strided_operators = trace.walk(has_strided_access)
hw_constraint = [c for c in constraints if isinstance(c, HardwareConstraint)][0]
strided_operators = trace.walk(has_gpr_offsets)
for operator in strided_operators:
custom = get_custom(operator)
if isinstance(custom, Read):
simplified_index = {
dim: simplify_index(custom.index.get(dim, custom.index[dim]))
for dim in custom.index
}
elif isinstance(custom, Write):
simplified_index = {
dim: simplify_index(custom.register_index.get(dim, custom.index[dim]))
for dim in custom.index
}
else:
raise NotImplementedError(
"Expected strided operator to only be read or write."
)

shape = get_vector_shape(
custom.vector_shapes, custom.register_type.symbolic_shape
)
simplified_index = {
dim: simplify_index(custom.register_index.get(dim, custom.index[dim]))
for dim in custom.index
}
elements_per_thread = subs_idxc(custom.elements_per_thread)
max_stride_dim, max_stride = max(
[(dim, seq.stride) for dim, seq in simplified_index.items()],
key=lambda item: item[1],
)
gpr_offsets = [
v.start for k, v in simplified_index.items() if v.start.has(GPR_NUM)
]
assert len(gpr_offsets) == 1, "Expected only 1-Dim has gpr offsets"
gpr_offset_expr = gpr_offsets[0]
gpr_cur_base_offset = gpr_offset_expr.subs({GPR_NUM: 0})
cur_elem_id = 0
with custom.graph.inserting_before(operator):
for i in range(elements_per_thread):
# Non-contiguous access patterns can have varying offsets. We
# handle that here.
gpr_offset = [
expr
for expr in simplified_index[max_stride_dim].start.args
if expr.has(GPR_NUM)
]
if not gpr_offset:
gpr_offset = i
else:
gpr_offset = sympy.Add(*gpr_offset).subs({GPR_NUM: i})
extract = ExtractSlice(custom.register_, [i], [1], [1]).add_to_graph(
custom.graph
)
offset = np.unravel_index(int(gpr_offset * max_stride), shape)
write = Write(
extract,
custom.memory,
mapping=custom.mapping,
elements_per_thread=1,
).add_to_graph(custom.graph)
write.index = {
dim: IndexSequence(
simplified_index[dim].start.subs({GPR_NUM: 0}) + offset[j], 1, 1
# Break apart Reads/Writes that has non-contiguous GPR Read/Writes.
next_gpr_offset = gpr_offset_expr.subs({GPR_NUM: i + 1})
cur_gpr_offset = gpr_offset_expr.subs({GPR_NUM: i})
gpr_offset_step = next_gpr_offset - cur_gpr_offset
if not isinstance(gpr_offset_step, sympy.Integer):
raise NotImplementedError(
"Only constant integer GPR offset steps supported."
)
for j, dim in enumerate(custom.register_type.symbolic_shape)
}

custom.graph.erase_node(operator)
gpr_offset_step = int(gpr_offset_step)

# Create new write when there is a jump in GPR offset
# or at the end of the loop.
if gpr_offset_step > 1 or i == elements_per_thread - 1:
# Get VGPR number of elements.
gpr_size = (cur_gpr_offset - gpr_cur_base_offset) + 1
assert isinstance(
gpr_size, sympy.Integer
), "Expected gpr_size to be int."
gpr_size = int(gpr_size)

# Generate new Read/Write that has contiguous VGPR elements.
extract = ExtractSlice(
custom.register_, [cur_elem_id], [gpr_size], [1]
).add_to_graph(custom.graph)
write = Write(
extract,
custom.memory,
mapping=custom.mapping,
elements_per_thread=gpr_size,
).add_to_graph(custom.graph)
write.index = {
dim: IndexSequence(
simplified_index[dim].start.subs({GPR_NUM: cur_elem_id}),
gpr_size,
simplified_index[dim].stride,
)
for dim in simplified_index
}
write.vector_shapes = custom.vector_shapes

# Set new current base GPR offset
gpr_cur_base_offset = next_gpr_offset
cur_elem_id = i + 1
if isinstance(custom, Write):
custom.graph.erase_node(operator)


def preprocess_nodes(
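The core of the new partition_ops_with_gpr_offsets loop substitutes GPR_NUM = 0, 1, ... into the offset expression and cuts a new contiguous slice whenever consecutive offsets jump by more than one, or at the last element. A minimal sympy sketch of that grouping, using a hypothetical contiguous_chunks helper and an assumed offset expression; the real pass additionally builds the ExtractSlice/Write nodes and rewrites their index:

import sympy

GPR_NUM = sympy.Symbol("GPR_NUM")

def contiguous_chunks(gpr_offset_expr, elements_per_thread):
    """Return (start_element, size) chunks whose GPR offsets are contiguous."""
    chunks = []
    base = gpr_offset_expr.subs({GPR_NUM: 0})
    cur_elem_id = 0
    for i in range(elements_per_thread):
        cur = gpr_offset_expr.subs({GPR_NUM: i})
        nxt = gpr_offset_expr.subs({GPR_NUM: i + 1})
        step = nxt - cur
        assert isinstance(step, sympy.Integer), "only constant steps handled here"
        # Cut when there is a jump in GPR offset or at the end of the loop.
        if int(step) > 1 or i == elements_per_thread - 1:
            chunks.append((cur_elem_id, int(cur - base) + 1))
            base = nxt
            cur_elem_id = i + 1
    return chunks

# Assumed example layout: every 2 consecutive elements sit 8 registers apart.
expr = 8 * sympy.floor(GPR_NUM / 2) + sympy.Mod(GPR_NUM, 2)
assert expr.has(GPR_NUM)  # the same check has_gpr_offsets uses to spot such an index
print(contiguous_chunks(expr, 8))  # [(0, 2), (2, 2), (4, 2), (6, 2)]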
26 changes: 26 additions & 0 deletions iree/turbine/kernel/wave/utils.py
@@ -24,6 +24,7 @@
CustomOp,
Reduction,
GetResult,
ExtractSlice,
IterArg,
Reshape,
)
@@ -168,6 +169,31 @@ def is_chained_getresult(node: fx.Node) -> bool:
get_custom(node).graph.erase_node(node)


def remove_chained_extractslice(trace: CapturedTrace):
def is_chained_extractslice(node: fx.Node) -> bool:
custom = get_custom(node)
if not isinstance(custom, ExtractSlice):
return False
register = get_custom(custom.register_)
if not isinstance(register, ExtractSlice):
return False
return custom.rank == register.rank

while removable_nodes := trace.walk(is_chained_extractslice):
for node in removable_nodes:
dst_extract = get_custom(node)
src_extract = get_custom(dst_extract.register_)
dst_extract.register_ = src_extract.register_
new_offset = [
dst_i + src_i
for dst_i, src_i in zip(dst_extract.offset, src_extract.offset)
]
dst_extract.update_arg("register_", src_extract.register_)
dst_extract.update_arg("offset", new_offset)
if len(src_extract.fx_node.users) == 0:
get_custom(node).graph.erase_node(src_extract.fx_node)


def delinearize_index(index: IndexExpr, shape: list[int]) -> list[IndexExpr]:
"""
Delinearizes a 1D index into a multi-dimensional index
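remove_chained_extractslice folds an extract-of-extract into a single slice by re-pointing the outer op at the inner op's source and summing the per-dimension offsets. A small sketch of just that offset arithmetic, with plain lists standing in for fx nodes:

# Folding ExtractSlice(ExtractSlice(x, [3], [8], [1]), [2], [1], [1])
# into ExtractSlice(x, [5], [1], [1]): offsets add element-wise.
src_offset = [3]   # inner slice starts at element 3 of x
dst_offset = [2]   # outer slice takes element 2 of the inner result
new_offset = [dst_i + src_i for dst_i, src_i in zip(dst_offset, src_offset)]
assert new_offset == [5]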
4 changes: 4 additions & 0 deletions iree/turbine/kernel/wave/wave.py
@@ -28,6 +28,7 @@
compile_and_invoke,
safe_subs,
remove_chained_getresult,
remove_chained_extractslice,
subs_idxc,
)
from .minimize_global_loads import minimize_global_loads
@@ -38,6 +39,7 @@
from ..ops import wave_ops
from ..ops.wave_ops import Reduction, CustomOp, get_custom
from .index_sequence_analysis import (
partition_ops_with_gpr_offsets,
partition_strided_operators,
set_node_indices,
set_post_expansion_indices,
@@ -254,7 +256,9 @@ def _trace_and_get_kernel_signature(
apply_shared_memory_indexing_corrections(graph, self.constraints)

# Partition strided operators.
partition_ops_with_gpr_offsets(graph, self.constraints)
partition_strided_operators(graph, self.constraints)
remove_chained_extractslice(graph)

# Align sizes to WG/Tile sizes
# This pass changes indexing keys, which can interfere with other passes,
