Skip to content

Commit

Permalink
Reverse stencil.apply inputs and try to name accordingly.
Browse files Browse the repository at this point in the history
Some tweaks of buffer handling on nd_nwave_devito_nodamp.py.
  • Loading branch information
PapyChacal committed Jul 22, 2023
1 parent d9c4239 commit c389f96
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
2 changes: 1 addition & 1 deletion devito/ir/ietxdsl/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,8 @@ def match_and_rewrite(self, op: iet_ssa.Stencil, rewriter: PatternRewriter, /):

for field in op.input_indices:
rewriter.insert_op_before_matched_op(load_op := stencil.LoadOp.get(field))
input_temps.append(load_op.res)
load_op.res.name_hint = field.name_hint + "_temp"
input_temps.insert(0, load_op.res)

rewriter.replace_matched_op(
[
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/ietxdsl/iet_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def get(
stencil.TempType(len(shape), typ)
] * (time_buffers - 1))

for block_arg, idx_arg in zip(block.args, time_indices):
for block_arg, idx_arg in zip(block.args, reversed(inputs)):
name = SSAValue.get(idx_arg).name_hint
if name is None:
continue
Expand Down
11 changes: 7 additions & 4 deletions fast/nd_nwave_devito_nodamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,20 +115,23 @@ def plot_3dfunc(u):
initdata = u.data[:]

# Run more with no sources now (Not supported in xdsl)
xdslop = XDSLOperator([stencil], name='XDSLOperator')
xdslop = Operator([stencil], name='DevitoOperator')
xdslop.apply(time=time_range.num-1, dt=model.critical_dt)

if len(shape) == 3:
if args.plot:
plot_3dfunc(u)

print(norm(u))

devito_output = u.copy()
print(f"devito output norm: {norm(devito_output)}")

# Reset initial data
u.data[:] = initdata

# Run more with no sources now (Not supported in xdsl)
xdslop = XDSLOperator([stencil])
xdslop = XDSLOperator([stencil], name='xDSLOperator')
xdslop.apply(time=time_range.num-1, dt=model.critical_dt)

print(norm(u))
xdsl_output = u.copy()
print(f"xdsl output norm: {norm(xdsl_output)}")

0 comments on commit c389f96

Please sign in to comment.