diff --git a/devito/ir/ietxdsl/__init__.py b/devito/ir/ietxdsl/__init__.py index 0923e614e2..2ba899e861 100644 --- a/devito/ir/ietxdsl/__init__.py +++ b/devito/ir/ietxdsl/__init__.py @@ -1,5 +1,2 @@ -from devito.ir.ietxdsl.iet_ssa import * # noqa -from devito.ir.ietxdsl.cgeneration import * # noqa -from devito.ir.ietxdsl.xdsl_passes import transform_devito_to_iet_ssa, transform_devito_xdsl_string # noqa from devito.ir.ietxdsl.lowering import LowerIetForToScfFor, LowerIetForToScfParallel, DropIetComments, iet_to_standard_mlir # noqa from devito.ir.ietxdsl.cluster_to_ssa import finalize_module_with_globals, convert_devito_stencil_to_xdsl_stencil # noqa diff --git a/devito/ir/ietxdsl/ietxdsl_functions.py b/devito/ir/ietxdsl/ietxdsl_functions.py index b4a9498fb1..e031661686 100644 --- a/devito/ir/ietxdsl/ietxdsl_functions.py +++ b/devito/ir/ietxdsl/ietxdsl_functions.py @@ -11,8 +11,6 @@ from devito import SpaceDimension from devito.passes.iet.languages.openmp import OmpRegion -from devito.ir.ietxdsl import (MLContext, IET, Constant, Modi, Block, Statement, - PointerCast, Powi, Initialise, StructDecl, Call) from devito.tools import as_list from devito.tools.utils import as_tuple from devito.types.basic import IndexedData @@ -31,70 +29,6 @@ floatingPointLike = ContainerOf(AnyOf([Float16Type, Float32Type, Float64Type])) -def printHeaders(cgen, header_str, headers): - for header in headers: - cgen.printOperation(Statement.get(createStatement(header_str, header))) - cgen.printOperation(Statement.get(createStatement(''))) - - -def printIncludes(cgen, header_str, headers): - for header in headers: - cgen.printOperation(Statement.get( - createStatement(header_str, '"' + header + '"'))) - cgen.printOperation(Statement.get(createStatement(''))) - - -def printStructs(cgen, struct_decs): - for struct in struct_decs: - cgen.printOperation( - StructDecl.get(struct.tpname, struct.fields, struct.declname, - struct.pad_bytes)) - - -def print_calls(cgen, calldefs): - - for node in calldefs: - call_name = str(node.root.name) - - """ - (Pdb) calldefs[0].root.args['parameters'] - [buf(x), x_size, f(t, x), otime, ox] - (Pdb) calldefs[0].root.args['parameters'][0] - buf(x) - (Pdb) calldefs[0].root.args['parameters'][0]._C_name - """ - try: - C_names = [str(i._C_name) for i in node.root.args['parameters']] - C_typenames = [str(i._C_typename) for i in node.root.args['parameters']] - C_typeqs = [str(i._C_type_qualifier) for i in node.root.args['parameters']] - prefix = node.root.prefix[0] - retval = node.root.retval - except: - print("Call not translated in calldefs") - return - - call = Call(call_name, C_names, C_typenames, C_typeqs, prefix, retval) - - cgen.printCall(call, True) - - -def createStatement(string="", val=None): - for t in as_tuple(val): - string = string + " " + t - - return string - - -def collectStructs(parameters): - struct_decs = [] - struct_strs = [] - for i in parameters: - # Bypass a struct decl if it has te same _C_typename - if (i._C_typedecl is not None and str(i._C_typename) not in struct_strs): - struct_decs.append(i._C_typedecl) - struct_strs.append(i._C_typename) - return struct_decs - def calculateAddArguments(arguments): # Get an input of arguments that are added. In case only one argument remains, @@ -252,215 +186,6 @@ def add_to_block(expr, arg_by_expr: dict[Any, Operation], result): assert False, f'unsupported expr {expr} of type {expr.func}' -def myVisit(node, block: Block, ssa_vals={}): - try: - bool_node = isinstance( - node, nodes.Node), f'Argument must be subclass of Node, found: {node}' - comment_node = isinstance( - node, cgen.Comment), f'Argument must be subclass of Node, found: {node}' - statement_node = isinstance( - node, cgen.Statement), f'Argument must be subclass of Node, found: {node}' - assert bool_node or comment_node or statement_node - except: - print("fail!") - - if hasattr(node, 'is_Callable') and node.is_Callable: - return - - if isinstance(node, nodes.CallableBody): - return - - if isinstance(node, nodes.Expression): - b = Block([i32]) - r = [] - expr = node.expr - if node.init: - expr_name = expr.args[0] - add_to_block(expr.args[1], {Symbol(s): a for s, a in ssa_vals.items()}, r) - - # init = Initialise.get(r[-1].results[0], r[-1].results[0], str(expr_name)) - block.add_ops(r) - ssa_vals[str(expr_name)] = r[-1].results[0] - else: - add_to_block(expr, {Symbol(s): a for s, a in ssa_vals.items()}, r) - block.add_ops(r) - return - - - if isinstance(node, nodes.ExpressionBundle): - assert len(node.children) == 1 - for idx in range(len(node.children[0])): - child = node.children[0][idx] - myVisit(child, block, ssa_vals) - return - - if isinstance(node, nodes.Iteration): - assert len(node.children) == 1 - assert len(node.children[0]) == 1 - - # Get index variable - dim = node.dim - assert len(node.limits) == 3, "limits should be a (min, max, step) tuple!" - - start, end, step = node.limits - try: - step = int(step) - except: - raise ValueError("step must be int!") - - # get start, end ssa values - start_ssa_val = ssa_vals[start.name] - end_ssa_val = ssa_vals[end.name] - - step_op = arith.Constant.from_int_and_width(step, i32) - - block.add_op(step_op) - - props = [str(x) for x in node.properties] - pragmas = [str(x) for x in node.pragmas] - - subindices = len(node.uindices) - - # construct iet for operation - loop = iet_ssa.For.get(start_ssa_val, end_ssa_val, step_op, subindices, props, pragmas) - - # extend context to include loop index - ssa_vals[node.index] = loop.block.args[0] - - # TODO: add subindices to ctx - for i, uindex in enumerate(node.uindices): - ssa_vals[uindex.name] = loop.block.args[i+1] - - # visit the iteration body, adding ops to the loop body - myVisit(node.children[0][0], loop.block, ssa_vals) - - # add loop to program - block.add_op(loop) - return - - if isinstance(node, nodes.Section): - assert len(node.children) == 1 - assert len(node.children[0]) == 1 - for content in node.ccode.contents: - if isinstance(content, cgen.Comment): - comment = Statement.get(content) - block.add_ops([comment]) - else: - myVisit(node.children[0][0], block, ssa_vals) - return - - if isinstance(node, nodes.HaloSpot): - assert len(node.children) == 1 - try: - assert isinstance(node.children[0], nodes.Iteration) - except: - assert isinstance(node.children[0], OmpRegion) - - myVisit(node.children[0], block, ssa_vals) - return - - if isinstance(node, nodes.TimedList): - assert len(node.children) == 1 - assert len(node.children[0]) == 1 - header = Statement.get(node.header[0]) - block.add_ops([header]) - myVisit(node.children[0][0], block, ssa_vals) - footer = Statement.get(node.footer[0]) - block.add_ops([footer]) - return - - if isinstance(node, nodes.PointerCast): - statement = node.ccode - - assert node.defines[0]._C_name == node.obj._C_name, "This should not happen" - - # We want to know the dimensions of the u_vec->data result - # we assume that the result will always be of dim: - # (u_vec->size[i]) for some i - # we further assume, that node.function.symbolic_shape - # is always (u_vec->size[0], u_vec->size[1], ... ,u_vec->size[rank]) - # this means that this pretty hacky way works to get the indices of the dims - # in `u_vec->size` - shape = (node.function.symbolic_shape.index(shape) for shape in node.castshape) - - arg = ssa_vals[node.function._C_name] - pointer_cast = PointerCast.get( - arg, - statement, - shape, - memref_type_from_indexed_data(node.obj) - ) - block.add_ops([pointer_cast]) - ssa_vals[node.obj._C_name] = pointer_cast.result - return - - if isinstance(node, nodes.List): - # Problem: When a List is ecountered with only body, but no header or footer - # we have a problem - for h in node.header: - myVisit(h, block, ssa_vals) - - for b in node.body: - myVisit(b, block, ssa_vals) - - for f in node.footer: - myVisit(f, block, ssa_vals) - - return - - if isinstance(node, nodes.Call): - # Those parameters without associated types aren't printed in the Kernel header - call_name = str(node.name) - - try: - C_names = [str(i._C_name) for i in node.arguments] - C_typenames = [str(i._C_typename) for i in node.arguments] - C_typeqs = [str(i._C_type_qualifier) for i in node.arguments] - prefix = '' - retval = '' - except: - # Needs to be fixed - comment = Statement.get(node) - block.add_ops([comment]) - print(f"Call {node.name} instance translated as comment") - return - - call = Call(call_name, C_names, C_typenames, C_typeqs, prefix, retval) - block.add_ops([call]) - - print(f"Call {node.name} translated") - return - - if isinstance(node, nodes.Conditional): - # Those parameters without associated types aren't printed in the Kernel header - print("Conditional placement skipping") - return - - if isinstance(node, nodes.Definition): - print("Translating definition") - comment = Statement.get(node) - block.add_ops([comment]) - return - - if isinstance(node, cgen.Comment): - # cgen.Comment as Statement - comment = Statement.get(node) - block.add_ops([comment]) - return - - if isinstance(node, cgen.Statement): - comment = Statement.get(node) - block.add_ops([comment]) - return - - if isinstance(node, cgen.Line): - comment = Statement.get(node) - block.add_ops([comment]) - return - - #raise TypeError(f'Unsupported type of node: {type(node)}, {vars(node)}') - - def get_arg_types(symbols): processed = [] for symbol in symbols: diff --git a/devito/ir/ietxdsl/xdsl_passes.py b/devito/ir/ietxdsl/xdsl_passes.py deleted file mode 100644 index 44e56c5a71..0000000000 --- a/devito/ir/ietxdsl/xdsl_passes.py +++ /dev/null @@ -1,211 +0,0 @@ -from devito import Operator - -from devito.ir import PointerCast, FindNodes -from devito.ir.iet import FindSymbols -from devito.ir.iet.nodes import CallableBody, MetaCall, Definition, Dereference, Prodder # noqa - -from devito.ir.ietxdsl import (MLContext, IET, CGeneration, - ietxdsl_functions, Callable) - -from devito.ir.ietxdsl.ietxdsl_functions import collectStructs, get_arg_types -from xdsl.dialects.builtin import Builtin, i32 -from xdsl.dialects import builtin, func -from xdsl.ir import Block, Region - - -def transform_devito_xdsl_string(op: Operator): - - """ - Transform a Devito Operator to an XDSL code string. - Parameters - ---------- - op : Operator - A Devito Operator. - Returns - ------- - cgen.str - A cgen string with the transformed code. - """ - - ctx = MLContext() - Builtin(ctx) - iet = IET(ctx) - - cgen = CGeneration() - - # Print headers/includes/Structs - ietxdsl_functions.printHeaders(cgen, "#define", op._headers) - ietxdsl_functions.printIncludes(cgen, "#include", op._includes) - ietxdsl_functions.printStructs(cgen, collectStructs(op.parameters)) - - # Check for the existence of funcs in the operator (print devito metacalls) - op_funcs = [value for _, value in op._func_table.items()] - # Print calls - ietxdsl_functions.print_calls(cgen, op_funcs) - # Visit and print the main kernel - call_obj = _op_to_func(op) - cgen.printCallable(call_obj) - - # After finishing kernels, now we check the rest of the functions - module = builtin.ModuleOp.from_region_or_ops([call_obj]) - for op_func in op_funcs: - op = op_func.root - name = op.name - - # Those parameters without associated types aren't printed in the Kernel header - op_param_names = [s._C_name for s in FindSymbols('defines').visit(op)] - op_header_params = [i._C_name for i in list(op.parameters)] - op_types = [i._C_typename for i in list(op.parameters)] - op_type_qs = [i._C_type_qualifier for i in list(op.parameters)] - prefix = '-'.join(op.prefix) - retval = str(op.retval) - # import pdb;pdb.set_trace() - b = Block([i32] * len(op_param_names)) - d = {name: register for name, register in zip(op_param_names, b.args)} - - # Add Allocs - for op_alloc in op.body.allocs: - ietxdsl_functions.myVisit(op_alloc, block=b, ssa_vals=d) - - cgen.print('') - # Add obj defs - for op_obj in op.body.objs: - ietxdsl_functions.myVisit(op_obj, block=b, ssa_vals=d) - - # import pdb;pdb.set_trace() - - # Add Casts - for cast in FindNodes(PointerCast).visit(op.body): - ietxdsl_functions.myVisit(cast, block=b, ssa_vals=d) - - call_obj = Callable.get(name, op_param_names, op_header_params, op_types, - op_type_qs, retval, prefix, b) - - for body_i in op.body.body: - # Comments - if body_i.args.get('body') != (): - for body_j in body_i.body: - # Casts - ietxdsl_functions.myVisit(body_j, block=b, ssa_vals=d) - else: - ietxdsl_functions.myVisit(body_i, block=b, ssa_vals=d) - - # print Kernel - - # Add frees - for op_free in op.body.frees: - ietxdsl_functions.myVisit(op_free, block=b, ssa_vals=d) - - cgen.printCallable(call_obj) - module.regions[0].blocks[0].add_op(call_obj) - - from xdsl.printer import Printer - Printer().print(module) - return cgen.str() - - -def _op_to_func(op: Operator): - # Visit the Operator body - assert isinstance(op.body, CallableBody) - - - # Scan an Operator - # Those parameters without associated types aren't printed in the Kernel header - # # import pdb;pdb.set_trace() - op_symbols = FindSymbols('defines').visit(op) - op_param_names = [s._C_name for s in op_symbols] - op_header_params = [i._C_name for i in list(op.parameters)] - op_types = [i._C_typename for i in list(op.parameters)] - op_type_qs = [i._C_type_qualifier for i in list(op.parameters)] - prefix = '-'.join(op.prefix) - retv = str(op.retval) - - # # import pdb;pdb.set_trace() - - # Game is here we start a dict from op params, focus - arg_types = get_arg_types(op.parameters) - # b = Block([i32] * len(op_param_names)) - block = Block(arg_types) - ssa_val_dict = {param._C_name: val for param, val in zip(op.parameters, block.args)} - - # Add Casts - for cast in FindNodes(PointerCast).visit(op.body): - ietxdsl_functions.myVisit(cast, block=block, ssa_vals=ssa_val_dict) - - for i in op.body.body: - # Comments - # # import pdb;pdb.set_trace() - if i.args.get('body') != (): - for body_j in i.body: - # Casts - ietxdsl_functions.myVisit(body_j, block=block, ssa_vals=ssa_val_dict) - else: - ietxdsl_functions.myVisit(i, block=block, ssa_vals=ssa_val_dict) - - # add a trailing return - block.add_op(func.Return()) - - func_op = func.FuncOp.from_region(str(op.name), arg_types, [], Region([block])) - - func_op.attributes['param_names'] = builtin.ArrayAttr([ - builtin.StringAttr(str(param._C_name)) for param in op.parameters - ]) - - return func_op - - -def transform_devito_to_iet_ssa(op: Operator): - # Check for the existence of funcs in the operator (print devito metacalls) - op_funcs = [value for _, value in op._func_table.items()] - # Print calls - call_obj = _op_to_func(op) - - # After finishing kernels, now we check the rest of the functions - module = builtin.ModuleOp.from_region_or_ops([call_obj]) - for op_func in op_funcs: - op = op_func.root - name = op.name - - # Those parameters without associated types aren't printed in the Kernel header - op_param_names = [s._C_name for s in FindSymbols('defines').visit(op)] - op_header_params = [i._C_name for i in list(op.parameters)] - op_types = [i._C_typename for i in list(op.parameters)] - op_type_qs = [i._C_type_qualifier for i in list(op.parameters)] - prefix = '-'.join(op.prefix) - retval = str(op.retval) - b = Block([i32] * len(op_param_names)) - d = {name: register for name, register in zip(op_param_names, b.args)} - - # Add Allocs - for op_alloc in op.body.allocs: - ietxdsl_functions.myVisit(op_alloc, block=b, ssa_vals=d) - - # Add obj defs - for op_obj in op.body.objs: - ietxdsl_functions.myVisit(op_obj, block=b, ssa_vals=d) - - # Add Casts - for cast in FindNodes(PointerCast).visit(op.body): - ietxdsl_functions.myVisit(cast, block=b, ssa_vals=d) - - call_obj = Callable.get(name, op_param_names, op_header_params, op_types, - op_type_qs, retval, prefix, b) - - for body_i in op.body.body: - # Comments - if body_i.args.get('body') != (): - for body_j in body_i.body: - # Casts - ietxdsl_functions.myVisit(body_j, block=b, ssa_vals=d) - else: - ietxdsl_functions.myVisit(body_i, block=b, ssa_vals=d) - - # print Kernel - - # Add frees - for op_free in op.body.frees: - ietxdsl_functions.myVisit(op_free, block=b, ssa_vals=d) - - module.regions[0].blocks[0].add_op(call_obj) - - return module diff --git a/devito/xdslpasses/__init__.py b/devito/xdslpasses/__init__.py index 20219aecd7..4eb86f53f3 100644 --- a/devito/xdslpasses/__init__.py +++ b/devito/xdslpasses/__init__.py @@ -1 +1 @@ -from .iet import Callable, CGeneration # noqa \ No newline at end of file +from .iet import Callable # noqa \ No newline at end of file diff --git a/devito/xdslpasses/iet/__init__.py b/devito/xdslpasses/iet/__init__.py index b7c22c7a74..feeeddfa40 100644 --- a/devito/xdslpasses/iet/__init__.py +++ b/devito/xdslpasses/iet/__init__.py @@ -1 +1 @@ -from .parpragma import Callable, CGeneration # noqa +from .parpragma import Callable # noqa diff --git a/tests/test_xdsl_iet.py b/tests/test_xdsl_iet.py index c77b06b44c..b0898e3a55 100644 --- a/tests/test_xdsl_iet.py +++ b/tests/test_xdsl_iet.py @@ -1,16 +1,13 @@ from devito import Grid, TimeFunction, Eq, Operator from devito.tools import as_tuple -from devito.ir.ietxdsl import (MLContext, CGeneration, Powi, IET, Callable, - Block, Iteration, Initialise, - floatingPointLike) - from devito.ir.iet import retrieve_iteration_tree from xdsl.dialects.builtin import ModuleOp, Builtin, i32, f32 from xdsl.printer import Printer from xdsl.dialects.arith import Addi, Constant, Subi +from xdsl.dialects.experimental.math import FPowIOp from xdsl.dialects import memref, arith from xdsl.ir import Operation, Block, Region import pytest @@ -53,62 +50,9 @@ def test_powi(): mod = ModuleOp([ cst1 := Constant.from_int_and_width(1, i32), - ut1 := Powi.get(cst1, cst1), + ut1 := FPowIOp.get(cst1, cst1), ]) - # printer = Printer() - # printer.print_op(mod) - - -def test_blockIteration(): - - mod = ModuleOp([ - - Iteration.get(["affine", "sequential"], ("time_m", "time_M", "1"),"time_loop", - Block.from_callable([ - i32, i32, i32 - ], lambda time, t0, t1: [ - Iteration. - get(["affine", "parallel", "skewable"], - ("x_m", "x_M", "1"),"x_loop", - Block.from_callable([i32], lambda x: [ - Iteration.get( - [ - "affine", - "parallel", "skewable", "vector-dim" - ], ("y_m", "y_M", "1"),"y_loop", - Block.from_callable([i32], lambda y: [ - cst1 := Constant.from_int_and_width(1, i32), - x1 := Addi(x, cst1), - y1 := Addi(y, cst1), - ])) - ])) - ])) - ]) - - printer = Printer() - printer.print_op(mod) - - -def test_callable(): - - a = Constant.from_int_and_width(1, i32) - b = Constant.from_int_and_width(2, i32) - - # Operation to add these constants - c = Addi(a, b) - - block0 = Block([a, b, c]) - - mod = ModuleOp([ - Callable.get( - "kernel", ["u"], ["u"], ["struct dataobj*"], ["restrict"], "int", "", - block0) - ]) - - printer = Printer() - printer.print_op(mod) - @pytest.mark.xfail(reason="Deprecated, will be dropped") def test_devito_iet(): diff --git a/tests/test_xdsl_operator.py b/tests/test_xdsl_operator.py index 9af768b3e3..b7401b70e6 100644 --- a/tests/test_xdsl_operator.py +++ b/tests/test_xdsl_operator.py @@ -1,7 +1,6 @@ from devito import Grid, TimeFunction, Eq, XDSLOperator, Operator -from devito.ir.ietxdsl.xdsl_passes import transform_devito_xdsl_string -# flake8: noqa from devito.operator.xdsl_operator import XDSLOperator +# flake8: noqa def test_create_xdsl_operator(): diff --git a/tests/test_xdsl_simple.py b/tests/test_xdsl_simple.py index 3605dad644..2d80ce46b1 100644 --- a/tests/test_xdsl_simple.py +++ b/tests/test_xdsl_simple.py @@ -1,5 +1,4 @@ from devito import Grid, TimeFunction, Eq, XDSLOperator, Operator -from devito.ir.ietxdsl.xdsl_passes import transform_devito_xdsl_string # flake8: noqa from devito.operator.xdsl_operator import XDSLOperator