Skip to content

Commit

Permalink
Merge branch 'master' into emilien/mixed-function-test
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas authored May 27, 2024
2 parents 98fd031 + 134f237 commit 3fa1f13
Show file tree
Hide file tree
Showing 58 changed files with 51 additions and 30 deletions.
3 changes: 3 additions & 0 deletions devito/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from devito.core.operator import CoreOperator, CustomOperator, ParTile
from devito.exceptions import InvalidOperator
from devito.ir.iet import Callable, MetaCall
from devito.ir.iet.nodes import Section
from devito.ir.iet.visitors import FindNodes
from devito.logger import info, perf
from devito.mpi import MPI
from devito.operator.profiling import create_profile
Expand Down Expand Up @@ -277,6 +279,7 @@ def _build(cls, expressions, **kwargs):
op._dimensions = set().union(*[e.dimensions for e in irs.expressions])
op._dtype, op._dspace = irs.clusters.meta
op._profiler = profiler
kwargs['xdsl_num_sections'] = len(FindNodes(Section).visit(irs.iet))
module = cls._lower_stencil(irs.expressions, **kwargs)
op._module = module

Expand Down
11 changes: 10 additions & 1 deletion devito/ir/ietxdsl/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,12 @@ class ExtractDevitoStencilConversion:
eqs: list[LoweredEq]
block: Block
temps: dict[tuple[DiscreteFunction, int], SSAValue]
symbol_values: dict[str, SSAValue]
time_offs: int

def __init__(self):
self.temps = dict()
self.symbol_values = dict()

time_offs: int

Expand All @@ -75,6 +77,13 @@ def convert_function_eq(self, eq: LoweredEq, **kwargs):
# Get the function carriers of the equation
self._build_step_body(step_dim, eq)

def convert_symbol_eq(self, symbol: Symbol, rhs: LoweredEq, **kwargs):
"""
Convert a symbol equation to xDSL.
"""
self.symbol_values[symbol.name] = self._visit_math_nodes(None, rhs, None)
self.symbol_values[symbol.name].name_hint = symbol.name

def _convert_eq(self, eq: LoweredEq, **kwargs):
"""
# Docs here Need rewriting
Expand Down Expand Up @@ -122,7 +131,7 @@ def _convert_eq(self, eq: LoweredEq, **kwargs):
f"Function of type {type(write_function.function)} not supported" # noqa
)
case Symbol():
self.convert_symbol_eq(write_function, eq, **kwargs)
self.convert_symbol_eq(write_function, eq.rhs, **kwargs)
case _:
raise NotImplementedError(f"LHS of type {type(write_function)} not supported") # noqa

Expand Down
2 changes: 2 additions & 0 deletions devito/ir/ietxdsl/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def apply_timers(module, **kwargs):
"""
Apply timers to a module
"""
if kwargs['xdsl_num_sections'] < 1:
return
name = kwargs.get("name", "Kernel")
grpa = GreedyRewritePatternApplier([MakeFunctionTimed(name)])
PatternRewriteWalker(grpa, walk_regions_first=True).rewrite_module(module)
15 changes: 1 addition & 14 deletions tests/test_xdsl_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from devito import (Grid, TensorTimeFunction, VectorTimeFunction, div, grad, diag, solve,
Operator, Eq, Constant, norm, SpaceDimension)
from devito.types import Symbol, Array, Function, TimeFunction
from devito.types import Array, Function, TimeFunction

from xdsl.dialects.scf import For, Yield
from xdsl.dialects.arith import Addi
Expand Down Expand Up @@ -789,19 +789,6 @@ def test_function_IV():

class TestOperatorUnsupported(object):

@pytest.mark.xfail(reason="Symbols are not supported in xDSL yet")
def test_symbol_I(self):
# Define a simple Devito a = 1 operator

a = Symbol('a')
eq0 = Eq(a, 1)

op = Operator([eq0], opt='xdsl')

op.apply()

assert a == 1

@pytest.mark.xfail(reason="stencil.return operation does not verify for i64")
def test_forward_assignment(self):
# simple forward assignment
Expand Down
21 changes: 20 additions & 1 deletion tests/test_xdsl_op_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
# flake8: noqa

from xdsl.dialects.scf import For, Yield
from xdsl.dialects.arith import Addi
from xdsl.dialects.arith import Addi, Constant
from xdsl.dialects.func import Call, Return
from xdsl.dialects.stencil import FieldType, ApplyOp, LoadOp, StoreOp
from xdsl.dialects.llvm import LLVMPointerType
from xdsl.printer import Printer

from devito.types.basic import Symbol

def test_udx():

# Define a simple Devito Operator
Expand Down Expand Up @@ -106,6 +108,23 @@ def test_u_and_v_conversion():
assert type(ops[8] == StoreOp)
assert type(ops[9] == Return)

def test_symbol_I():
# Define a simple Devito a = 1 operator

a = Symbol('a')
eq0 = Eq(a, 1)

op = Operator([eq0], opt='xdsl')

op.apply()

assert len(op._module.regions[0].blocks[0].ops.first.body.blocks[0].ops) == 2

ops = list(op._module.regions[0].blocks[0].ops.first.body.blocks[0].ops)
assert isinstance(ops[0], Constant)
assert ops[0].result.name_hint == a.name
assert type(ops[0] == Return)


# This test should fail, as we are trying to use an inplace operation
@pytest.mark.xfail(reason="Cannot store to a field that is loaded from")
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from devito import (Grid, TimeFunction, Eq, solve, Operator,
Constant, norm, configuration)
from examples.cfd import init_hat
from fast.bench_utils import plot_2dfunc
from xdsl_examples.bench_utils import plot_2dfunc


parser = argparse.ArgumentParser(description='Process arguments.')

Expand All @@ -22,10 +23,11 @@
type=int, help="Simulation time in millisecond")
parser.add_argument("-bls", "--blevels", default=2, type=int, nargs="+",
help="Block levels")
parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot2D")
parser.add_argument("-devito", "--devito", default=False, type=bool, help="Devito run")
parser.add_argument("-xdsl", "--xdsl", default=False, type=bool, help="xDSL run")
args = parser.parse_args()
parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot2D")


mpiconf = configuration['mpi']

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from devito import (Grid, TimeFunction, Eq, solve, Constant,
norm, Operator, configuration)
from fast.bench_utils import plot_3dfunc
from xdsl_examples.bench_utils import plot_3dfunc

parser = argparse.ArgumentParser(description='Process arguments.')

Expand Down
12 changes: 5 additions & 7 deletions xdsl-examples/elastic2d.py → xdsl_examples/elastic2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@
type=int, help="Space order of the simulation")
parser.add_argument("-nt", "--nt", default=40,
type=int, help="Simulation time in millisecond")
parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot3D")

parser.add_argument("-devito", "--devito", default=False, type=bool, help="Devito run")
parser.add_argument("-xdsl", "--xdsl", default=False, type=bool, help="xDSL run")
parser.add_argument("-plot", "--plot", default=False, type=bool, help="Plot2D")
Expand Down Expand Up @@ -170,11 +168,11 @@ def wavelet(self, f0, t):

if args.plot:
# Save the plotted images locally
plt.imsave('/home/gb4018/workspace/xdslproject/devito/fast/v0.pdf', v[0].data_with_halo[0], vmin=-.5*1e-2, vmax=.5*1e-2, cmap="seismic")
plt.imsave('/home/gb4018/workspace/xdslproject/devito/fast/v1.pdf', v[1].data_with_halo[0], vmin=-.5*1e-2, vmax=.5*1e-2, cmap="seismic")
plt.imsave('/home/gb4018/workspace/xdslproject/devito/fast/tau00.pdf', tau[0, 0].data_with_halo[0], vmin=-.5*1e-2, vmax=.5*1e-2, cmap="seismic")
plt.imsave('/home/gb4018/workspace/xdslproject/devito/fast/tau11.pdf', tau[1, 1].data_with_halo[0], vmin=-.5*1e-2, vmax=.5*1e-2, cmap="seismic")
plt.imsave('/home/gb4018/workspace/xdslproject/devito/fast/tau01.pdf', tau[0, 1].data_with_halo[0], vmin=-.5*1e-2, vmax=.5*1e-2, cmap="seismic")
plt.imsave('/home/gb4018/workspace/xdslproject/devito/xdsl-examples/v0.pdf', v[0].data_with_halo[0], vmin=-.5*1e-2, vmax=.5*1e-2, cmap="seismic")
plt.imsave('/home/gb4018/workspace/xdslproject/devito/xdsl-examples/v1.pdf', v[1].data_with_halo[0], vmin=-.5*1e-2, vmax=.5*1e-2, cmap="seismic")
plt.imsave('/home/gb4018/workspace/xdslproject/devito/xdsl-examples/tau00.pdf', tau[0, 0].data_with_halo[0], vmin=-.5*1e-2, vmax=.5*1e-2, cmap="seismic")
plt.imsave('/home/gb4018/workspace/xdslproject/devito/xdsl-examples/tau11.pdf', tau[1, 1].data_with_halo[0], vmin=-.5*1e-2, vmax=.5*1e-2, cmap="seismic")
plt.imsave('/home/gb4018/workspace/xdslproject/devito/xdsl-examples/tau01.pdf', tau[0, 1].data_with_halo[0], vmin=-.5*1e-2, vmax=.5*1e-2, cmap="seismic")

assert np.allclose(v_xdsl[0].data, v[0].data, rtol=1e-8)
assert np.allclose(v_xdsl[1].data, v[1].data, rtol=1e-8)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from devito import (TimeFunction, Eq, Operator, solve, configuration)
from examples.seismic import RickerSource
from examples.seismic import Model, TimeAxis
from fast.bench_utils import plot_2dfunc
from xdsl_examples.bench_utils import plot_2dfunc
from devito.tools import as_tuple

import argparse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
configuration)
from examples.seismic import RickerSource
from examples.seismic import Model, TimeAxis
from fast.bench_utils import plot_3dfunc
from xdsl_examples.bench_utils import plot_3dfunc
from devito.tools import as_tuple

import argparse
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Binary file added xdsl_examples/tau00.pdf
Binary file not shown.
Binary file added xdsl_examples/tau01.pdf
Binary file not shown.
Binary file added xdsl_examples/tau11.pdf
Binary file not shown.
Binary file added xdsl_examples/v0.pdf
Binary file not shown.
Binary file added xdsl_examples/v1.pdf
Binary file not shown.
2 changes: 1 addition & 1 deletion xdsl-examples/wave2d_b.py → xdsl_examples/wave2d_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from devito.tools import as_tuple

import argparse
from fast.bench_utils import plot_2dfunc
from xdsl_examples.bench_utils import plot_2dfunc

np.set_printoptions(threshold=np.inf)

Expand Down
3 changes: 2 additions & 1 deletion xdsl-examples/wave3d_b.py → xdsl_examples/wave3d_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from devito.tools import as_tuple

import argparse
from fast.bench_utils import plot_3dfunc
from xdsl_examples.bench_utils import plot_3dfunc


np.set_printoptions(threshold=np.inf)

Expand Down
File renamed without changes.

0 comments on commit 3fa1f13

Please sign in to comment.