Skip to content

Commit

Permalink
cleanup: drop deprecated code
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Nov 16, 2023
1 parent b52e50c commit e5dca6d
Show file tree
Hide file tree
Showing 8 changed files with 5 additions and 552 deletions.
3 changes: 0 additions & 3 deletions devito/ir/ietxdsl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
from devito.ir.ietxdsl.iet_ssa import * # noqa
from devito.ir.ietxdsl.cgeneration import * # noqa
from devito.ir.ietxdsl.xdsl_passes import transform_devito_to_iet_ssa, transform_devito_xdsl_string # noqa
from devito.ir.ietxdsl.lowering import LowerIetForToScfFor, LowerIetForToScfParallel, DropIetComments, iet_to_standard_mlir # noqa
from devito.ir.ietxdsl.cluster_to_ssa import finalize_module_with_globals, convert_devito_stencil_to_xdsl_stencil # noqa
275 changes: 0 additions & 275 deletions devito/ir/ietxdsl/ietxdsl_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from devito import SpaceDimension
from devito.passes.iet.languages.openmp import OmpRegion

from devito.ir.ietxdsl import (MLContext, IET, Constant, Modi, Block, Statement,
PointerCast, Powi, Initialise, StructDecl, Call)
from devito.tools import as_list
from devito.tools.utils import as_tuple
from devito.types.basic import IndexedData
Expand All @@ -31,70 +29,6 @@

floatingPointLike = ContainerOf(AnyOf([Float16Type, Float32Type, Float64Type]))

def printHeaders(cgen, header_str, headers):
for header in headers:
cgen.printOperation(Statement.get(createStatement(header_str, header)))
cgen.printOperation(Statement.get(createStatement('')))


def printIncludes(cgen, header_str, headers):
for header in headers:
cgen.printOperation(Statement.get(
createStatement(header_str, '"' + header + '"')))
cgen.printOperation(Statement.get(createStatement('')))


def printStructs(cgen, struct_decs):
for struct in struct_decs:
cgen.printOperation(
StructDecl.get(struct.tpname, struct.fields, struct.declname,
struct.pad_bytes))


def print_calls(cgen, calldefs):

for node in calldefs:
call_name = str(node.root.name)

"""
(Pdb) calldefs[0].root.args['parameters']
[buf(x), x_size, f(t, x), otime, ox]
(Pdb) calldefs[0].root.args['parameters'][0]
buf(x)
(Pdb) calldefs[0].root.args['parameters'][0]._C_name
"""
try:
C_names = [str(i._C_name) for i in node.root.args['parameters']]
C_typenames = [str(i._C_typename) for i in node.root.args['parameters']]
C_typeqs = [str(i._C_type_qualifier) for i in node.root.args['parameters']]
prefix = node.root.prefix[0]
retval = node.root.retval
except:
print("Call not translated in calldefs")
return

call = Call(call_name, C_names, C_typenames, C_typeqs, prefix, retval)

cgen.printCall(call, True)


def createStatement(string="", val=None):
for t in as_tuple(val):
string = string + " " + t

return string


def collectStructs(parameters):
struct_decs = []
struct_strs = []
for i in parameters:
# Bypass a struct decl if it has te same _C_typename
if (i._C_typedecl is not None and str(i._C_typename) not in struct_strs):
struct_decs.append(i._C_typedecl)
struct_strs.append(i._C_typename)
return struct_decs


def calculateAddArguments(arguments):
# Get an input of arguments that are added. In case only one argument remains,
Expand Down Expand Up @@ -252,215 +186,6 @@ def add_to_block(expr, arg_by_expr: dict[Any, Operation], result):
assert False, f'unsupported expr {expr} of type {expr.func}'


def myVisit(node, block: Block, ssa_vals={}):
try:
bool_node = isinstance(
node, nodes.Node), f'Argument must be subclass of Node, found: {node}'
comment_node = isinstance(
node, cgen.Comment), f'Argument must be subclass of Node, found: {node}'
statement_node = isinstance(
node, cgen.Statement), f'Argument must be subclass of Node, found: {node}'
assert bool_node or comment_node or statement_node
except:
print("fail!")

if hasattr(node, 'is_Callable') and node.is_Callable:
return

if isinstance(node, nodes.CallableBody):
return

if isinstance(node, nodes.Expression):
b = Block([i32])
r = []
expr = node.expr
if node.init:
expr_name = expr.args[0]
add_to_block(expr.args[1], {Symbol(s): a for s, a in ssa_vals.items()}, r)

# init = Initialise.get(r[-1].results[0], r[-1].results[0], str(expr_name))
block.add_ops(r)
ssa_vals[str(expr_name)] = r[-1].results[0]
else:
add_to_block(expr, {Symbol(s): a for s, a in ssa_vals.items()}, r)
block.add_ops(r)
return


if isinstance(node, nodes.ExpressionBundle):
assert len(node.children) == 1
for idx in range(len(node.children[0])):
child = node.children[0][idx]
myVisit(child, block, ssa_vals)
return

if isinstance(node, nodes.Iteration):
assert len(node.children) == 1
assert len(node.children[0]) == 1

# Get index variable
dim = node.dim
assert len(node.limits) == 3, "limits should be a (min, max, step) tuple!"

start, end, step = node.limits
try:
step = int(step)
except:
raise ValueError("step must be int!")

# get start, end ssa values
start_ssa_val = ssa_vals[start.name]
end_ssa_val = ssa_vals[end.name]

step_op = arith.Constant.from_int_and_width(step, i32)

block.add_op(step_op)

props = [str(x) for x in node.properties]
pragmas = [str(x) for x in node.pragmas]

subindices = len(node.uindices)

# construct iet for operation
loop = iet_ssa.For.get(start_ssa_val, end_ssa_val, step_op, subindices, props, pragmas)

# extend context to include loop index
ssa_vals[node.index] = loop.block.args[0]

# TODO: add subindices to ctx
for i, uindex in enumerate(node.uindices):
ssa_vals[uindex.name] = loop.block.args[i+1]

# visit the iteration body, adding ops to the loop body
myVisit(node.children[0][0], loop.block, ssa_vals)

# add loop to program
block.add_op(loop)
return

if isinstance(node, nodes.Section):
assert len(node.children) == 1
assert len(node.children[0]) == 1
for content in node.ccode.contents:
if isinstance(content, cgen.Comment):
comment = Statement.get(content)
block.add_ops([comment])
else:
myVisit(node.children[0][0], block, ssa_vals)
return

if isinstance(node, nodes.HaloSpot):
assert len(node.children) == 1
try:
assert isinstance(node.children[0], nodes.Iteration)
except:
assert isinstance(node.children[0], OmpRegion)

myVisit(node.children[0], block, ssa_vals)
return

if isinstance(node, nodes.TimedList):
assert len(node.children) == 1
assert len(node.children[0]) == 1
header = Statement.get(node.header[0])
block.add_ops([header])
myVisit(node.children[0][0], block, ssa_vals)
footer = Statement.get(node.footer[0])
block.add_ops([footer])
return

if isinstance(node, nodes.PointerCast):
statement = node.ccode

assert node.defines[0]._C_name == node.obj._C_name, "This should not happen"

# We want to know the dimensions of the u_vec->data result
# we assume that the result will always be of dim:
# (u_vec->size[i]) for some i
# we further assume, that node.function.symbolic_shape
# is always (u_vec->size[0], u_vec->size[1], ... ,u_vec->size[rank])
# this means that this pretty hacky way works to get the indices of the dims
# in `u_vec->size`
shape = (node.function.symbolic_shape.index(shape) for shape in node.castshape)

arg = ssa_vals[node.function._C_name]
pointer_cast = PointerCast.get(
arg,
statement,
shape,
memref_type_from_indexed_data(node.obj)
)
block.add_ops([pointer_cast])
ssa_vals[node.obj._C_name] = pointer_cast.result
return

if isinstance(node, nodes.List):
# Problem: When a List is ecountered with only body, but no header or footer
# we have a problem
for h in node.header:
myVisit(h, block, ssa_vals)

for b in node.body:
myVisit(b, block, ssa_vals)

for f in node.footer:
myVisit(f, block, ssa_vals)

return

if isinstance(node, nodes.Call):
# Those parameters without associated types aren't printed in the Kernel header
call_name = str(node.name)

try:
C_names = [str(i._C_name) for i in node.arguments]
C_typenames = [str(i._C_typename) for i in node.arguments]
C_typeqs = [str(i._C_type_qualifier) for i in node.arguments]
prefix = ''
retval = ''
except:
# Needs to be fixed
comment = Statement.get(node)
block.add_ops([comment])
print(f"Call {node.name} instance translated as comment")
return

call = Call(call_name, C_names, C_typenames, C_typeqs, prefix, retval)
block.add_ops([call])

print(f"Call {node.name} translated")
return

if isinstance(node, nodes.Conditional):
# Those parameters without associated types aren't printed in the Kernel header
print("Conditional placement skipping")
return

if isinstance(node, nodes.Definition):
print("Translating definition")
comment = Statement.get(node)
block.add_ops([comment])
return

if isinstance(node, cgen.Comment):
# cgen.Comment as Statement
comment = Statement.get(node)
block.add_ops([comment])
return

if isinstance(node, cgen.Statement):
comment = Statement.get(node)
block.add_ops([comment])
return

if isinstance(node, cgen.Line):
comment = Statement.get(node)
block.add_ops([comment])
return

#raise TypeError(f'Unsupported type of node: {type(node)}, {vars(node)}')


def get_arg_types(symbols):
processed = []
for symbol in symbols:
Expand Down
Loading

0 comments on commit e5dca6d

Please sign in to comment.