Skip to content

Commit

Permalink
Reuse state to simplify signatures, also more Pythonic constructs.
Browse files Browse the repository at this point in the history
  • Loading branch information
PapyChacal committed Feb 25, 2024
1 parent 2120a8e commit a49422f
Showing 1 changed file with 44 additions and 68 deletions.
112 changes: 44 additions & 68 deletions devito/ir/ietxdsl/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _convert_eq(self, eq: LoweredEq):
# an equation may look like this:
# u[x+1,y+1,z] = (u[x,y,z+1] + u[x+2,y+2,z+1]) / 2
if isinstance(eq.lhs, Symbol):
return func.FuncOp.external(eq.lhs.name, [], [builtin.i32])
return xdsl_func.FuncOp.external(eq.lhs.name, [], [builtin.i32])
assert isinstance(eq.lhs, Indexed)

# u(t, x, y)
Expand All @@ -73,30 +73,34 @@ def _convert_eq(self, eq: LoweredEq):

# Get all functions used in the equation
functions = OrderedSet(*(f.function for f in retrieve_function_carriers(eq)))
# We identify time buffers by their function and positive time offset.
# We store a list of those here to help the following steps.
self.time_buffers = [(f, i) for f in functions for i in range(f.time_size)]

# For each used function, define as many fields as its time_size
fields = [[field_from_function(f)]*f.time_size for f in functions]
# Flatten the list for function signature
flat_fields = [f for ff in fields for f in ff]
fields_types = [field_from_function(f) for (f, _) in self.time_buffers]
# Create a function with the fields as arguments
f = func.FuncOp("apply_kernel", ([*flat_fields], []))
xdsl_func = func.FuncOp("apply_kernel", (fields_types, []))
# Define nice argument names to try and stay sane while debugging
arg_names = [f"{f.name}_vec_{i}" for f in functions for i in range(f.time_size)]
for i, arg_name in enumerate(arg_names):
f.body.block.args[i].name_hint = arg_name

with ImplicitBuilder(f.body.block):
self._build_step_loop(step_dim, functions, eq)
# And store in self.function_args a mapping from time_buffers to their
# corresponding function arguments.
self.function_args = {}
for i, (f, t) in enumerate(self.time_buffers):
xdsl_func.body.block.args[i].name_hint = f"{f.name}_vec_{t}"
self.function_args[(f,t)] = xdsl_func.body.block.args[i]

with ImplicitBuilder(xdsl_func.body.block):
self._build_step_loop(step_dim, eq)
# func wants a return
func.Return()
return f
def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, temps: dict[tuple[DiscreteFunction, int], SSAValue], output_indexed:Indexed) -> SSAValue:

return xdsl_func
def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, output_indexed:Indexed) -> SSAValue:
# Handle Indexeds
if isinstance(node, Indexed):
space_offsets = [node.indices[d] - output_indexed.indices[d] for d in node.function.space_dimensions]
# import pdb; pdb.set_trace()
temp = temps[(node.function, (node.indices[dim] - dim) % node.function.time_size)]
temp = self.apply_temps[(node.function, (node.indices[dim] - dim) % node.function.time_size)]
access = stencil.AccessOp.get(temp, space_offsets)
return access.res
# Handle Integers
Expand All @@ -113,7 +117,7 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, temps: dict[tupl
return symb.result
# Handle Add Mul
elif isinstance(node, (Add, Mul)):
args = [self._visit_math_nodes(dim, arg, temps, output_indexed) for arg in node.args]
args = [self._visit_math_nodes(dim, arg, output_indexed) for arg in node.args]
# add casts when necessary
# get first element out, store the rest in args
# this makes the reduction easier
Expand All @@ -131,7 +135,7 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, temps: dict[tupl
return carry
# Handle Pow
elif isinstance(node, Pow):
args = [self._visit_math_nodes(dim, arg, temps, output_indexed) for arg in node.args]
args = [self._visit_math_nodes(dim, arg, output_indexed) for arg in node.args]
assert len(args) == 2, "can't pow with != 2 args!"
base, ex = args
if is_int(base):
Expand All @@ -157,55 +161,43 @@ def _visit_math_nodes(self, dim: SteppingDimension, node: Expr, temps: dict[tupl
else:
raise NotImplementedError(f"Unknown math: {node}", node)


def _build_step_body(self, dim: SteppingDimension, functions:Iterable[DiscreteFunction], eq:LoweredEq) -> None:
iter_args = ImplicitBuilder.get().insertion_point.block.args[1:]

def _build_step_body(self, dim: SteppingDimension, eq:LoweredEq) -> None:
output_function = eq.lhs.function
output_time_offset = (eq.lhs.indices[dim] - dim) % eq.lhs.function.time_size

function_fields : dict[tuple[DiscreteFunction, int], SSAValue] = {}
handled_args = 0
for f in functions:
for i in range(f.time_size):
function_fields[(f, i)] = iter_args[handled_args + i]
handled_args += f.time_size

function_temps : dict[tuple[DiscreteFunction, int], SSAValue] = {}
for (function, time_offset), field in function_fields.items():
# import pdb; pdb.set_trace()
if (function, time_offset) == (output_function, output_time_offset):
continue
load = stencil.LoadOp.get(field)
load.res.name_hint = f"{function.name}_t{time_offset}_temp"
loop_temps = {
(f, t): stencil.LoadOp.get(a).res
for (f, t), a in self.block_args.items()
if (f, t) != (output_function, output_time_offset)
}
for (f,t), a in loop_temps.items():
a.name_hint = f"{f.name}_t{t}_temp"

function_temps[(function, time_offset)] = load.res

shape = output_function.grid.shape_local
apply = stencil.ApplyOp.get(
function_temps.values(),
Block(arg_types=[a.type for a in function_temps.values()]),
loop_temps.values(),
Block(arg_types=[a.type for a in loop_temps.values()]),
result_types=[stencil.TempType(len(shape), element_type=dtypes_to_xdsltypes[output_function.dtype])]
)
for apply_arg, apply_op in zip(apply.region.block.args, apply.operands):
apply_arg.name_hint = apply_op.name_hint[:-4]+"_blk"

apply_temps = {k:v for k,v in zip(function_temps.keys(), apply.region.block.args)}
self.apply_temps = {k:v for k,v in zip(loop_temps.keys(), apply.region.block.args)}

with ImplicitBuilder(apply.region.block):
stencil.ReturnOp.get([self._visit_math_nodes(dim, eq.rhs, apply_temps, eq.lhs)])
stencil.ReturnOp.get([self._visit_math_nodes(dim, eq.rhs, eq.lhs)])
# TODO Think about multiple outputs
stencil.StoreOp.get(
apply.res[0],
function_fields[output_function, output_time_offset],
self.block_args[output_function, output_time_offset],
stencil.IndexAttr.get(*([0] * len(shape))),
stencil.IndexAttr.get(*shape),
)

def _build_step_loop(
self,
dim: SteppingDimension,
functions: Iterable[DiscreteFunction],
eq: LoweredEq,
) -> scf.For:
# Bounds and step boilerplate
Expand All @@ -220,34 +212,19 @@ def _build_step_loop(
except:
raise ValueError("step must be int!")

# Get the function arguments for iteration arguments
func_args = ImplicitBuilder.get().insertion_point.block.args

iter_args = list(self.function_args.values())
# Create the for loop
loop = scf.For(lb, arith.Addi(ub, one), step, func_args, Block(arg_types=[builtin.IndexType()] + [a.type for a in func_args]))
loop = scf.For(lb, arith.Addi(ub, one), step, iter_args, Block(arg_types=[builtin.IndexType(), *(a.type for a in iter_args)]))
loop.body.block.args[0].name_hint = "time"
handled_args = 0
for f in functions:
for i in range(f.time_size):
loop.body.block.args[1+handled_args+i].name_hint = f"{f.name}_t{i}"
handled_args += f.time_size
with ImplicitBuilder(loop.body.block):

self._build_step_body(dim, functions, eq)

# Compute the yield order to implement the buffer swap in MLIR
yield_args = []
# Skip the induction variable
handled_args = 1
# For each Devito function
for f in functions:
# Get their corresponding iteration fields
fargs = loop.body.block.args[handled_args:handled_args + f.time_size]
# Yield them swapped
yield_args += fargs[1:]
yield_args.append(fargs[0])

handled_args += f.time_size
self.block_args = {(f,t) : loop.body.block.args[1+i] for i, (f,t) in enumerate(self.time_buffers)}
for ((f,t), arg) in self.block_args.items():
arg.name_hint = f"{f.name}_t{t}"

with ImplicitBuilder(loop.body.block):
self._build_step_body(dim, eq)
# Swap buffers through scf.yield
yield_args = [self.block_args[(f, (t+1)%f.time_size)] for (f, t) in self.block_args.keys()]
scf.Yield(*yield_args)

def convert(self, eqs: Iterable[Eq]) -> builtin.ModuleOp:
Expand All @@ -256,7 +233,6 @@ def convert(self, eqs: Iterable[Eq]) -> builtin.ModuleOp:
Region([Block([self._convert_eq(eq) for eq in eqs])])
)


def _ensure_same_type(self, *vals: SSAValue):
if all(isinstance(val.type, builtin.IntegerAttr) for val in vals):
return vals
Expand Down

0 comments on commit a49422f

Please sign in to comment.