Add type inference (#252)
This PR adds a type inference pass to wave. Previously, an operator's type was
inferred lazily by looking up the types of its neighbors, which made type
inference inefficient.

Instead, we now introduce a pass that infers the types of all operators in the
graph; the inferred type is then stored on the node. New nodes constructed in
downstream passes are responsible for annotating the types of the new operators
(see the sketch below).
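
For illustration, here is a minimal, self-contained sketch of the pattern. The `Node` and `BinaryOp` classes below are stand-ins, not real wave ops; the actual implementation lives in `wave_ops.py` and `wave/type_inference.py`.

```python
# Simplified stand-ins for fx.Node-backed CustomOps; illustrative only.

class Node:
    """Caches its type at construction, either explicitly or via infer_type()."""

    def __init__(self, *args, type=None):
        self.args = args
        self.type = type
        if self.type is None:
            self.infer_type()

    def infer_type(self):
        """Each operator overrides this to derive its type from its arguments."""
        pass


class BinaryOp(Node):
    def infer_type(self):
        lhs, rhs = self.args
        # Result follows the operand with the larger (broadcasted) shape,
        # mirroring BinaryPyOp.infer_type in this PR.
        self.type = lhs.type if len(lhs.type) >= len(rhs.type) else rhs.type


a = Node(type=("M", "N"))   # type known up front (like passing `type` to add_to_graph)
b = Node(type=("N",))
c = BinaryOp(a, b)          # type inferred once when the node is created
assert c.type == ("M", "N")
```

Downstream passes that build new nodes either let `infer_type()` run at construction time, or pass the type explicitly when inference is not yet possible (see the `GetResult` case in `expansion.py` below).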

---------

Signed-off-by: Harsh Menon <[email protected]>
harsh-nod authored Nov 6, 2024
1 parent 9970573 commit db1ec57
Showing 13 changed files with 311 additions and 43 deletions.
94 changes: 53 additions & 41 deletions iree/turbine/kernel/ops/wave_ops.py
@@ -338,7 +338,7 @@ def custom_string(self, value_map: dict[str, str]) -> str:
vars_str = ", ".join(vars_list)
return f"{self.tkw_op_name}({vars_str})"

def add_to_graph(self, region_graph: RegionGraph) -> fx.Node:
def add_to_graph(self, region_graph: RegionGraph, type: Any = None) -> fx.Node:
arg_list = tuple([value for _, value in vars(self).items()])
self.graph = region_graph
self.fx_node = region_graph.create_node(
@@ -350,6 +350,10 @@ def add_to_graph(self, region_graph: RegionGraph) -> fx.Node:
self.fx_node.tkw_op = self.__class__
self.fx_node.tkw_op_name = self.tkw_op_name
self.fx_node.index = None
if type is None:
get_custom(self.fx_node).infer_type()
else:
self.fx_node.type = type
return self.fx_node

def _add_proxy_to_graph(self, region_graph: RegionGraph):
@@ -556,6 +560,23 @@ def vector_shapes(self) -> dict[IndexSymbol, int]:
def vector_shapes(self, value: dict[IndexSymbol, int]):
self.fx_node.vector_shapes = value

@property
def type(self) -> Any:
if hasattr(self.fx_node, "type"):
return self.fx_node.type
return None

@type.setter
def type(self, value: Any):
self.fx_node.type = value

def infer_type(self):
"""
Infer the type of this operator using the types
of its arguments.
"""
pass

def align_index(self, constraints: list["Constraint"]) -> None:
"""
Align index to WG/Tile sizes.
@@ -602,21 +623,21 @@ def indexing_dims(self) -> list[IndexSymbol]:
def py_operator(self) -> str:
return self.tkw_op_name

@property
def type(self) -> Memory:
def infer_type(self):
lhs_type = get_custom(self.lhs).type
rhs_type = get_custom(self.rhs).type
has_same_type = has_same_custom_type(lhs_type, rhs_type)
if has_same_type:
return lhs_type
self.type = lhs_type
return
lhs_dim_set = set(lhs_type.symbolic_shape)
rhs_dim_set = set(rhs_type.symbolic_shape)
if lhs_dim_set.isdisjoint(rhs_dim_set):
raise ValueError(
"BinaryPyOp requires lhs and rhs shape to be at least broadcastable."
)
broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type
return broadcasted_type
self.type = broadcasted_type


@define_interface_op("exp2")
@@ -637,10 +658,9 @@ def indexing_dims(self) -> list[IndexSymbol]:
def py_operator(self) -> str:
return self.tkw_op_name

@property
def type(self) -> Memory:
def infer_type(self):
src_type = get_custom(self.arg).type
return src_type
self.type = src_type


@final
@@ -868,9 +888,8 @@ def rhs_type(self) -> Memory:
def acc_type(self) -> Memory:
return get_custom(self.acc).type

@property
def type(self) -> Memory:
return self.acc_type
def infer_type(self):
self.type = self.acc_type

def operand_index(
self, operand_map: dict[IndexSymbol, int], shape: list[IndexExpr]
@@ -925,6 +944,7 @@ def reduction_dim(self, value: IndexSymbol):
@define_op("read")
@dataclass
class Read(CustomOp):

memory: fx.Proxy
elements_per_thread: Optional[Any] = None
mapping: Optional[IndexMapping] = None
@@ -937,10 +957,9 @@ def indexing_dims(self) -> list[IndexSymbol]:
# TODO: This could contain ints.
return list(self.memory_type.symbolic_shape)

@property
def type(self) -> "Register":
def infer_type(self):
dtype = self.memory_type.dtype
return Register[*self.indexing_dims, dtype]
self.type = Register[*self.indexing_dims, dtype]

@property
def memory_type(self) -> "Memory":
@@ -1052,12 +1071,11 @@ def captured_vars(self, graph: fx.Graph) -> list[fx.Node]:
captured_vars.append(nested_node)
return captured_vars

@property
def type(self) -> Memory | Register | list[Memory | Register]:
def infer_type(self):
res_types = [get_custom(x).type for x in self.init_args]
if len(res_types) == 1:
res_types = res_types[0]
return res_types
self.type = res_types

def outputs(self, graph: fx.Graph) -> list[fx.Node]:
for node in graph.nodes:
@@ -1110,11 +1128,12 @@ def indexing_dims(self) -> list[IndexSymbol]:
if self.mapping is not None:
return list(self.mapping.input_shape)
# TODO: This could contain ints.
return list(self.type.symbolic_shape)
return list(self.memory_type.symbolic_shape)

@property
def type(self) -> "Memory":
return get_custom(self.memory).type
def infer_type(self):
address_space = self.memory_type.address_space
dtype = self.memory_type.dtype
self.type = Memory[*self.indexing_dims, address_space, dtype]

@property
def memory_type(self) -> "Memory":
@@ -1144,13 +1163,12 @@ class GetResult(CustomOp):
value: fx.Node
res_idx: int

@property
def type(self) -> "Memory":
def infer_type(self):
src_type = get_custom(self.value).type
if isinstance(src_type, list):
return src_type[self.res_idx]
self.type = src_type[self.res_idx]
else:
return src_type
self.type = src_type

@property
def indexing_dims(self) -> list[IndexExpr]:
@@ -1200,14 +1218,14 @@ class Extract(CustomOp):
register_: fx.Proxy
offset: IndexExpr | int

@property
def type(self) -> "Register":
def infer_type(self):
# Intuition here is we are trying to extract an element
# from fastest dim => we reduce the fastest dim.
src_type = get_custom(self.register_).type
# Return itself if just 0-D/1-D symbolic.
if len(src_type.symbolic_shape) <= 1:
return src_type
self.type = src_type
return

# Typically fastest dim is the last dimension,
# If non-unit dim exists => non-unit dim is fastest dim.
@@ -1220,7 +1238,7 @@ def type(self) -> "Register":
dim_to_remove = dst_shape[-1] if not non_unit_dim else non_unit_dim[0]
dst_shape.remove(dim_to_remove)
dst_type = Register[*dst_shape, src_type.dtype]
return dst_type
self.type = dst_type


@define_op("extract_slice")
@@ -1297,12 +1315,8 @@ def indexing_dims(self) -> list[IndexSymbol]:
dst_indexing = [dim for dim in src_indexing if dim != self.dim]
return dst_indexing

@property
def type(self) -> Memory:
def infer_type(self):
if isinstance(self.arg, Sequence):
# Local import to break circular dep.
from ..wave.utils import all_equal

src_types = [get_custom(arg).type for arg in self.arg]
ref_shape = src_types[0].symbolic_shape
ref_dtype = src_types[0].dtype
@@ -1318,7 +1332,7 @@ def type(self) -> Memory:
src_type = get_custom(self.arg).type
reduced_dims = [dims for dims in src_type.symbolic_shape if dims != self.dim]
dst_type = Register[*reduced_dims, src_type.dtype]
return dst_type
self.type = dst_type

@property
def num_reduction_dims(self) -> int:
@@ -1376,10 +1390,9 @@ class CastOp(CustomOp, ABC):
def indexing_dims(self) -> list[IndexSymbol]:
return get_custom(self.arg).indexing_dims

@property
def type(self) -> Memory:
def infer_type(self):
src_shape = get_custom(self.arg).type.symbolic_shape
return Register[*src_shape, self.dtype]
self.type = Register[*src_shape, self.dtype]


@define_op("permute")
@@ -1397,13 +1410,12 @@ class Permute(CustomOp, ABC):
def indexing_dims(self) -> list[IndexExpr]:
return self.target_shape

@property
def type(self) -> Register:
def infer_type(self):
src_type = get_custom(self.arg).type
assert set(src_type.symbolic_shape) == set(
self.target_shape
), f"Target shape {self.target_shape} must be a permutation of source shape {src_type.symbolic_shape}"
return Register[*self.target_shape, src_type.dtype]
self.type = Register[*self.target_shape, src_type.dtype]


def _to_sequence(input: Any | Sequence[Any]) -> Sequence[Any]:
7 changes: 6 additions & 1 deletion iree/turbine/kernel/wave/expansion.py
@@ -336,7 +336,12 @@ def _expand_reduction(
# Add GetResult nodes for the corresponding dimensions
reduction.graph.inserting_after(reduction.fx_node)
new_node = GetResult(reduction.fx_node, len(new_output_args))
new_node.add_to_graph(reduction.graph)
# Usually we would rely on infer_types inside add_to_graph to figure out
# the type of the new node. However, in this case, the logic to determine
# the type requires the reduction node to have its init_args set, which has
# not happened yet (it happens later). So instead, since we have access to
# arg, we just set the type directly.
new_node.add_to_graph(reduction.graph, arg.type)
new_node.fx_node.name = get_expanded_name(new_node, dims)
context[
(reduction, get_indexed_dims(dims, expand_dims), arg_idx)
21 changes: 21 additions & 0 deletions iree/turbine/kernel/wave/type_inference.py
@@ -0,0 +1,21 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..ops.wave_ops import *
from .._support.tracing import CapturedTrace
import torch.fx as fx
from ...support.logging import get_logger

logger = get_logger("turbine.wave.type_inference")


def infer_types(trace: CapturedTrace | fx.Graph):
# Infer and set the types for all nodes in the graph.
for subgraph in trace.region_graph.subgraphs.values():
for node in subgraph.nodes:
custom = get_custom(node)
custom.infer_type()
logger.debug(f"Setting type for {custom.fx_node} = {custom.type}")
4 changes: 4 additions & 0 deletions iree/turbine/kernel/wave/wave.py
@@ -49,6 +49,7 @@
from .thread_shape_analysis import determine_thread_shapes
from .scheduling.schedule import schedule_graph
from .._support.indexing import IndexingContext, IndexExpr
from .type_inference import infer_types
import iree.turbine.kernel.lang as tkl
from .._support.tracing import (
CapturedTrace,
@@ -224,6 +225,9 @@ def _trace_and_get_kernel_signature(
# Initialize Vector shapes
self.hardware_constraints[0].subs_vector_shapes(idxc.subs)

# Do type inference.
infer_types(graph)

# Promote the placeholders to the appropriate address space.
promote_placeholders(graph, self.constraints)
hoist_allocs(graph)
3 changes: 3 additions & 0 deletions lit_tests/kernel/wave/barriers.py
@@ -14,6 +14,7 @@
from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel._support.tracing import CapturedTrace
from iree.turbine.kernel._support.indexing import IndexingContext
@@ -86,6 +87,7 @@ def test_read_write_equal_sizes():
graph: fx.Graph = trace.get_root_graph()
read_node = get_read_nodes(graph)[0]
IndexingContext.current().finalize()
infer_types(trace)
promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
set_node_indices(trace, constraints)
expand_graph(trace, constraints)
@@ -171,6 +173,7 @@ def test_gemm():
trace: CapturedTrace = gemm()
graph: fx.Graph = trace.get_subgraph("region_0")
IndexingContext.current().finalize()
infer_types(trace)
read_nodes = get_read_nodes(graph)
for read_node in read_nodes:
promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)