diff --git a/.github/workflows/ci-lit.yml b/.github/workflows/ci-lit.yml new file mode 100644 index 0000000000..c97f2b5794 --- /dev/null +++ b/.github/workflows/ci-lit.yml @@ -0,0 +1,42 @@ +name: CI - lit-based Testing + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +on: + # Trigger the workflow on push or pull request, + # but only for the master branch + push: + branches: + - master + pull_request: + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + container: papychacal/xdsl-llvm:98e674c9f16d677d95c67bc130e267fae331e43c + steps: + - name: Checkout Devito + uses: actions/checkout@v3 + + - name: Install native dependencies + run: | + apt-get update && apt install curl mpich -y + + - name: Upgrade pip + run: | + pip install --upgrade pip + + - name: Install requirements and xDSL + run: | + pip install -e .[tests] + pip install mpi4py + pip install git+https://github.com/xdslproject/xdsl@f8bb935880276cf077e0a80f1905105d0a98eb33 + + - name: Execute lit tests + run: | + export PYTHONPATH=$(pwd) + export PATH=/xdsl-sc/llvm-project/build/bin/:$PATH + lit -v tests/filecheck/ diff --git a/devito/xdsl_core/xdsl_cpu.py b/devito/xdsl_core/xdsl_cpu.py index 24e5d38ce5..d4d9f5e3db 100644 --- a/devito/xdsl_core/xdsl_cpu.py +++ b/devito/xdsl_core/xdsl_cpu.py @@ -501,7 +501,6 @@ def _jit_compile(self): # mlir-translate to translate to LLVM-IR mlir_translate_cmd = 'mlir-translate --mlir-to-llvmir' out = self.compile(mlir_translate_cmd, out2.getvalue()) - # Compile with clang and get LLVM-IR clang_cmd = f'{cc} {cflags} -o {self._tf.name} {self._interop_tf.name} -xir -' # noqa out = self.compile(clang_cmd, out) diff --git a/requirements-testing.txt b/requirements-testing.txt index 0f88276721..34b75980e4 100644 --- a/requirements-testing.txt +++ b/requirements-testing.txt @@ -5,4 +5,5 @@ codecov flake8>=2.1.0 nbval scipy +lit<19.0.0 pooch; python_version >= "3.8" diff --git a/tests/filecheck/.lit_test_times.txt b/tests/filecheck/.lit_test_times.txt new file mode 100644 index 0000000000..b30dcdd049 --- /dev/null +++ b/tests/filecheck/.lit_test_times.txt @@ -0,0 +1,3 @@ +3.295047e-01 example.mlir +3.190060e-01 shape_inference.mlir +5.976017e-01 xdsl_pipeline.mlir diff --git a/tests/filecheck/lit.cfg b/tests/filecheck/lit.cfg new file mode 100644 index 0000000000..0d9ed1830a --- /dev/null +++ b/tests/filecheck/lit.cfg @@ -0,0 +1,21 @@ +import lit.formats +import os + +config.test_source_root = os.path.dirname(__file__) +xdsl_src = os.path.dirname(os.path.dirname(config.test_source_root)) + +config.name = "xDSL" +config.test_format = lit.formats.ShTest(preamble_commands=[f"cd {xdsl_src}"]) +config.suffixes = ['.test', '.mlir', '.py'] + +xdsl_opt = "xdsl/tools/xdsl-opt" +xdsl_run = "xdsl/tools/xdsl_run.py" +irdl_to_pyrdl = "xdsl/tools/irdl_to_pyrdl.py" + +config.substitutions.append(('XDSL_ROUNDTRIP', "xdsl-opt %s --print-op-generic --split-input-file | xdsl-opt --split-input-file | filecheck %s")) +config.substitutions.append(("XDSL_GENERIC_ROUNDTRIP", "xdsl-opt %s --print-op-generic --split-input-file | filecheck %s --check-prefix=CHECK-GENERIC")) +if "COVERAGE" in lit_config.params: + config.substitutions.append(('xdsl-opt', f"coverage run {xdsl_opt}")) + config.substitutions.append(('xdsl-run', f"coverage run {xdsl_run}")) + config.substitutions.append(('irdl-to-pyrdl', f"coverage run {irdl_to_pyrdl}")) + config.substitutions.append(('python', f"coverage run")) diff --git a/tests/filecheck/shape_inference.mlir b/tests/filecheck/shape_inference.mlir new file mode 100644 index 0000000000..77194b9944 --- /dev/null +++ b/tests/filecheck/shape_inference.mlir @@ -0,0 +1,152 @@ +// RUN: xdsl-opt -p stencil-shape-inference %s | filecheck %s + +builtin.module { + func.func @Kernel(%f2_vec0 : !stencil.field<[-2,5]x[-2,5]xf32>, %f2_vec1 : !stencil.field<[-2,5]x[-2,5]xf32>, %timers : !llvm.ptr) { + %0 = func.call @timer_start() : () -> f64 + %time_m = arith.constant 0 : index + %time_M = arith.constant 1 : index + %1 = arith.constant 1 : index + %2 = arith.addi %time_M, %1 : index + %step = arith.constant 1 : index + %3, %4 = scf.for %time = %time_m to %2 step %step iter_args(%f2_t0 = %f2_vec0, %f2_t1 = %f2_vec1) -> (!stencil.field<[-2,5]x[-2,5]xf32>, !stencil.field<[-2,5]x[-2,5]xf32>) { + %f2_t0_temp = stencil.load %f2_t0 : !stencil.field<[-2,5]x[-2,5]xf32> -> !stencil.temp + %f2_t1_temp = stencil.apply(%f2_t0_blk = %f2_t0_temp : !stencil.temp) -> (!stencil.temp) { + %5 = arith.constant 5.000000e-01 : f32 + %h_x = arith.constant 5.000000e-01 : f32 + %6 = arith.constant -2 : i64 + %7 = "math.fpowi"(%h_x, %6) : (f32, i64) -> f32 + %8 = stencil.access %f2_t0_blk[-1, 0] : !stencil.temp + %9 = arith.mulf %7, %8 : f32 + %h_x_1 = arith.constant 5.000000e-01 : f32 + %10 = arith.constant -2 : i64 + %11 = "math.fpowi"(%h_x_1, %10) : (f32, i64) -> f32 + %12 = stencil.access %f2_t0_blk[1, 0] : !stencil.temp + %13 = arith.mulf %11, %12 : f32 + %14 = arith.constant -2.000000e+00 : f32 + %h_x_2 = arith.constant 5.000000e-01 : f32 + %15 = arith.constant -2 : i64 + %16 = "math.fpowi"(%h_x_2, %15) : (f32, i64) -> f32 + %17 = stencil.access %f2_t0_blk[0, 0] : !stencil.temp + %18 = arith.mulf %14, %16 : f32 + %19 = arith.mulf %18, %17 : f32 + %20 = arith.addf %9, %13 : f32 + %21 = arith.addf %20, %19 : f32 + %22 = arith.mulf %5, %21 : f32 + %23 = arith.constant 5.000000e-01 : f32 + %h_y = arith.constant 5.000000e-01 : f32 + %24 = arith.constant -2 : i64 + %25 = "math.fpowi"(%h_y, %24) : (f32, i64) -> f32 + %26 = stencil.access %f2_t0_blk[0, -1] : !stencil.temp + %27 = arith.mulf %25, %26 : f32 + %h_y_1 = arith.constant 5.000000e-01 : f32 + %28 = arith.constant -2 : i64 + %29 = "math.fpowi"(%h_y_1, %28) : (f32, i64) -> f32 + %30 = stencil.access %f2_t0_blk[0, 1] : !stencil.temp + %31 = arith.mulf %29, %30 : f32 + %32 = arith.constant -2.000000e+00 : f32 + %h_y_2 = arith.constant 5.000000e-01 : f32 + %33 = arith.constant -2 : i64 + %34 = "math.fpowi"(%h_y_2, %33) : (f32, i64) -> f32 + %35 = stencil.access %f2_t0_blk[0, 0] : !stencil.temp + %36 = arith.mulf %32, %34 : f32 + %37 = arith.mulf %36, %35 : f32 + %38 = arith.addf %27, %31 : f32 + %39 = arith.addf %38, %37 : f32 + %40 = arith.mulf %23, %39 : f32 + %dt = arith.constant 1.000000e-01 : f32 + %41 = arith.constant -1 : i64 + %42 = "math.fpowi"(%dt, %41) : (f32, i64) -> f32 + %43 = stencil.access %f2_t0_blk[0, 0] : !stencil.temp + %44 = arith.mulf %42, %43 : f32 + %45 = arith.addf %22, %40 : f32 + %46 = arith.addf %45, %44 : f32 + %dt_1 = arith.constant 1.000000e-01 : f32 + %47 = arith.mulf %46, %dt_1 : f32 + stencil.return %47 : f32 + } + %f2_t1_temp_1 = stencil.store %f2_t1_temp to %f2_t1 ([0, 0] : [3, 3]) : !stencil.temp to !stencil.field<[-2,5]x[-2,5]xf32> with_halo : !stencil.temp + scf.yield %f2_t1, %f2_t0 : !stencil.field<[-2,5]x[-2,5]xf32>, !stencil.field<[-2,5]x[-2,5]xf32> + } + %5 = func.call @timer_end(%0) : (f64) -> f64 + "llvm.store"(%5, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> () + func.return + } + func.func private @timer_start() -> f64 + func.func private @timer_end(f64) -> f64 +} + +// CHECK: builtin.module { +// CHECK-NEXT: func.func @Kernel(%f2_vec0 : !stencil.field<[-2,5]x[-2,5]xf32>, %f2_vec1 : !stencil.field<[-2,5]x[-2,5]xf32>, %timers : !llvm.ptr) { +// CHECK-NEXT: %0 = func.call @timer_start() : () -> f64 +// CHECK-NEXT: %time_m = arith.constant 0 : index +// CHECK-NEXT: %time_M = arith.constant 1 : index +// CHECK-NEXT: %1 = arith.constant 1 : index +// CHECK-NEXT: %2 = arith.addi %time_M, %1 : index +// CHECK-NEXT: %step = arith.constant 1 : index +// CHECK-NEXT: %3, %4 = scf.for %time = %time_m to %2 step %step iter_args(%f2_t0 = %f2_vec0, %f2_t1 = %f2_vec1) -> (!stencil.field<[-2,5]x[-2,5]xf32>, !stencil.field<[-2,5]x[-2,5]xf32>) { +// CHECK-NEXT: %f2_t0_temp = stencil.load %f2_t0 : !stencil.field<[-2,5]x[-2,5]xf32> -> !stencil.temp<[-1,4]x[-1,4]xf32> +// CHECK-NEXT: %f2_t1_temp = stencil.apply(%f2_t0_blk = %f2_t0_temp : !stencil.temp<[-1,4]x[-1,4]xf32>) -> (!stencil.temp<[0,3]x[0,3]xf32>) { +// CHECK-NEXT: %5 = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %h_x = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %6 = arith.constant -2 : i64 +// CHECK-NEXT: %7 = "math.fpowi"(%h_x, %6) : (f32, i64) -> f32 +// CHECK-NEXT: %8 = stencil.access %f2_t0_blk[-1, 0] : !stencil.temp<[-1,4]x[-1,4]xf32> +// CHECK-NEXT: %9 = arith.mulf %7, %8 : f32 +// CHECK-NEXT: %h_x_1 = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %10 = arith.constant -2 : i64 +// CHECK-NEXT: %11 = "math.fpowi"(%h_x_1, %10) : (f32, i64) -> f32 +// CHECK-NEXT: %12 = stencil.access %f2_t0_blk[1, 0] : !stencil.temp<[-1,4]x[-1,4]xf32> +// CHECK-NEXT: %13 = arith.mulf %11, %12 : f32 +// CHECK-NEXT: %14 = arith.constant -2.000000e+00 : f32 +// CHECK-NEXT: %h_x_2 = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %15 = arith.constant -2 : i64 +// CHECK-NEXT: %16 = "math.fpowi"(%h_x_2, %15) : (f32, i64) -> f32 +// CHECK-NEXT: %17 = stencil.access %f2_t0_blk[0, 0] : !stencil.temp<[-1,4]x[-1,4]xf32> +// CHECK-NEXT: %18 = arith.mulf %14, %16 : f32 +// CHECK-NEXT: %19 = arith.mulf %18, %17 : f32 +// CHECK-NEXT: %20 = arith.addf %9, %13 : f32 +// CHECK-NEXT: %21 = arith.addf %20, %19 : f32 +// CHECK-NEXT: %22 = arith.mulf %5, %21 : f32 +// CHECK-NEXT: %23 = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %h_y = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %24 = arith.constant -2 : i64 +// CHECK-NEXT: %25 = "math.fpowi"(%h_y, %24) : (f32, i64) -> f32 +// CHECK-NEXT: %26 = stencil.access %f2_t0_blk[0, -1] : !stencil.temp<[-1,4]x[-1,4]xf32> +// CHECK-NEXT: %27 = arith.mulf %25, %26 : f32 +// CHECK-NEXT: %h_y_1 = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %28 = arith.constant -2 : i64 +// CHECK-NEXT: %29 = "math.fpowi"(%h_y_1, %28) : (f32, i64) -> f32 +// CHECK-NEXT: %30 = stencil.access %f2_t0_blk[0, 1] : !stencil.temp<[-1,4]x[-1,4]xf32> +// CHECK-NEXT: %31 = arith.mulf %29, %30 : f32 +// CHECK-NEXT: %32 = arith.constant -2.000000e+00 : f32 +// CHECK-NEXT: %h_y_2 = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %33 = arith.constant -2 : i64 +// CHECK-NEXT: %34 = "math.fpowi"(%h_y_2, %33) : (f32, i64) -> f32 +// CHECK-NEXT: %35 = stencil.access %f2_t0_blk[0, 0] : !stencil.temp<[-1,4]x[-1,4]xf32> +// CHECK-NEXT: %36 = arith.mulf %32, %34 : f32 +// CHECK-NEXT: %37 = arith.mulf %36, %35 : f32 +// CHECK-NEXT: %38 = arith.addf %27, %31 : f32 +// CHECK-NEXT: %39 = arith.addf %38, %37 : f32 +// CHECK-NEXT: %40 = arith.mulf %23, %39 : f32 +// CHECK-NEXT: %dt = arith.constant 1.000000e-01 : f32 +// CHECK-NEXT: %41 = arith.constant -1 : i64 +// CHECK-NEXT: %42 = "math.fpowi"(%dt, %41) : (f32, i64) -> f32 +// CHECK-NEXT: %43 = stencil.access %f2_t0_blk[0, 0] : !stencil.temp<[-1,4]x[-1,4]xf32> +// CHECK-NEXT: %44 = arith.mulf %42, %43 : f32 +// CHECK-NEXT: %45 = arith.addf %22, %40 : f32 +// CHECK-NEXT: %46 = arith.addf %45, %44 : f32 +// CHECK-NEXT: %dt_1 = arith.constant 1.000000e-01 : f32 +// CHECK-NEXT: %47 = arith.mulf %46, %dt_1 : f32 +// CHECK-NEXT: stencil.return %47 : f32 +// CHECK-NEXT: } +// CHECK-NEXT: %f2_t1_temp_1 = stencil.store %f2_t1_temp to %f2_t1 ([0, 0] : [3, 3]) : !stencil.temp<[0,3]x[0,3]xf32> to !stencil.field<[-2,5]x[-2,5]xf32> with_halo : !stencil.temp +// CHECK-NEXT: scf.yield %f2_t1, %f2_t0 : !stencil.field<[-2,5]x[-2,5]xf32>, !stencil.field<[-2,5]x[-2,5]xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %5 = func.call @timer_end(%0) : (f64) -> f64 +// CHECK-NEXT: "llvm.store"(%5, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> () +// CHECK-NEXT: func.return +// CHECK-NEXT: } +// CHECK-NEXT: func.func private @timer_start() -> f64 +// CHECK-NEXT: func.func private @timer_end(f64) -> f64 +// CHECK-NEXT: } + diff --git a/tests/filecheck/xdsl_pipeline.mlir b/tests/filecheck/xdsl_pipeline.mlir new file mode 100644 index 0000000000..5243a3ca7c --- /dev/null +++ b/tests/filecheck/xdsl_pipeline.mlir @@ -0,0 +1,178 @@ +// RUN: xdsl-opt -p stencil-shape-inference,convert-stencil-to-ll-mlir,scf-parallel-loop-tiling{parallel-loop-tile-sizes=64,0},printf-to-llvm,canonicalize %s | filecheck %s + +builtin.module { + func.func @Kernel(%f2_vec0 : !stencil.field<[-2,5]x[-2,5]xf32>, %f2_vec1 : !stencil.field<[-2,5]x[-2,5]xf32>, %timers : !llvm.ptr) { + %0 = func.call @timer_start() : () -> f64 + %time_m = arith.constant 0 : index + %time_M = arith.constant 1 : index + %1 = arith.constant 1 : index + %2 = arith.addi %time_M, %1 : index + %step = arith.constant 1 : index + %3, %4 = scf.for %time = %time_m to %2 step %step iter_args(%f2_t0 = %f2_vec0, %f2_t1 = %f2_vec1) -> (!stencil.field<[-2,5]x[-2,5]xf32>, !stencil.field<[-2,5]x[-2,5]xf32>) { + %f2_t0_temp = stencil.load %f2_t0 : !stencil.field<[-2,5]x[-2,5]xf32> -> !stencil.temp + %f2_t1_temp = stencil.apply(%f2_t0_blk = %f2_t0_temp : !stencil.temp) -> (!stencil.temp) { + %5 = arith.constant 5.000000e-01 : f32 + %h_x = arith.constant 5.000000e-01 : f32 + %6 = arith.constant -2 : i64 + %7 = "math.fpowi"(%h_x, %6) : (f32, i64) -> f32 + %8 = stencil.access %f2_t0_blk[-1, 0] : !stencil.temp + %9 = arith.mulf %7, %8 : f32 + %h_x_1 = arith.constant 5.000000e-01 : f32 + %10 = arith.constant -2 : i64 + %11 = "math.fpowi"(%h_x_1, %10) : (f32, i64) -> f32 + %12 = stencil.access %f2_t0_blk[1, 0] : !stencil.temp + %13 = arith.mulf %11, %12 : f32 + %14 = arith.constant -2.000000e+00 : f32 + %h_x_2 = arith.constant 5.000000e-01 : f32 + %15 = arith.constant -2 : i64 + %16 = "math.fpowi"(%h_x_2, %15) : (f32, i64) -> f32 + %17 = stencil.access %f2_t0_blk[0, 0] : !stencil.temp + %18 = arith.mulf %14, %16 : f32 + %19 = arith.mulf %18, %17 : f32 + %20 = arith.addf %9, %13 : f32 + %21 = arith.addf %20, %19 : f32 + %22 = arith.mulf %5, %21 : f32 + %23 = arith.constant 5.000000e-01 : f32 + %h_y = arith.constant 5.000000e-01 : f32 + %24 = arith.constant -2 : i64 + %25 = "math.fpowi"(%h_y, %24) : (f32, i64) -> f32 + %26 = stencil.access %f2_t0_blk[0, -1] : !stencil.temp + %27 = arith.mulf %25, %26 : f32 + %h_y_1 = arith.constant 5.000000e-01 : f32 + %28 = arith.constant -2 : i64 + %29 = "math.fpowi"(%h_y_1, %28) : (f32, i64) -> f32 + %30 = stencil.access %f2_t0_blk[0, 1] : !stencil.temp + %31 = arith.mulf %29, %30 : f32 + %32 = arith.constant -2.000000e+00 : f32 + %h_y_2 = arith.constant 5.000000e-01 : f32 + %33 = arith.constant -2 : i64 + %34 = "math.fpowi"(%h_y_2, %33) : (f32, i64) -> f32 + %35 = stencil.access %f2_t0_blk[0, 0] : !stencil.temp + %36 = arith.mulf %32, %34 : f32 + %37 = arith.mulf %36, %35 : f32 + %38 = arith.addf %27, %31 : f32 + %39 = arith.addf %38, %37 : f32 + %40 = arith.mulf %23, %39 : f32 + %dt = arith.constant 1.000000e-01 : f32 + %41 = arith.constant -1 : i64 + %42 = "math.fpowi"(%dt, %41) : (f32, i64) -> f32 + %43 = stencil.access %f2_t0_blk[0, 0] : !stencil.temp + %44 = arith.mulf %42, %43 : f32 + %45 = arith.addf %22, %40 : f32 + %46 = arith.addf %45, %44 : f32 + %dt_1 = arith.constant 1.000000e-01 : f32 + %47 = arith.mulf %46, %dt_1 : f32 + stencil.return %47 : f32 + } + %f2_t1_temp_1 = stencil.store %f2_t1_temp to %f2_t1 ([0, 0] : [3, 3]) : !stencil.temp to !stencil.field<[-2,5]x[-2,5]xf32> with_halo : !stencil.temp + scf.yield %f2_t1, %f2_t0 : !stencil.field<[-2,5]x[-2,5]xf32>, !stencil.field<[-2,5]x[-2,5]xf32> + } + %5 = func.call @timer_end(%0) : (f64) -> f64 + "llvm.store"(%5, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> () + func.return + } + func.func private @timer_start() -> f64 + func.func private @timer_end(f64) -> f64 +} + + +// CHECK: builtin.module { +// CHECK-NEXT: func.func @Kernel(%f2_vec0 : memref<7x7xf32>, %f2_vec1 : memref<7x7xf32>, %timers : !llvm.ptr) { +// CHECK-NEXT: %0 = func.call @timer_start() : () -> f64 +// CHECK-NEXT: %time_m = arith.constant 0 : index +// CHECK-NEXT: %time_M = arith.constant 1 : index +// CHECK-NEXT: %1 = arith.constant 1 : index +// CHECK-NEXT: %2 = arith.addi %time_M, %1 : index +// CHECK-NEXT: %step = arith.constant 1 : index +// CHECK-NEXT: %3, %4 = scf.for %time = %time_m to %2 step %step iter_args(%f2_t0 = %f2_vec0, %f2_t1 = %f2_vec1) -> (memref<7x7xf32>, memref<7x7xf32>) { +// CHECK-NEXT: %f2_t1_storeview = "memref.subview"(%f2_t1) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (memref<7x7xf32>) -> memref<3x3xf32, strided<[7, 1], offset: 16>> +// CHECK-NEXT: %f2_t0_loadview = "memref.subview"(%f2_t0) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (memref<7x7xf32>) -> memref<5x5xf32, strided<[7, 1], offset: 16>> +// CHECK-NEXT: %5 = arith.constant 0 : index +// CHECK-NEXT: %6 = arith.constant 0 : index +// CHECK-NEXT: %7 = arith.constant 1 : index +// CHECK-NEXT: %8 = arith.constant 1 : index +// CHECK-NEXT: %9 = arith.constant 3 : index +// CHECK-NEXT: %10 = arith.constant 3 : index +// CHECK-NEXT: %11 = arith.constant 0 : index +// CHECK-NEXT: %12 = arith.constant 64 : index +// CHECK-NEXT: %13 = arith.muli %7, %12 : index +// CHECK-NEXT: "scf.parallel"(%5, %9, %13) <{"operandSegmentSizes" = array}> ({ +// CHECK-NEXT: ^0(%14 : index): +// CHECK-NEXT: %15 = "affine.min"(%12, %9, %14) <{"map" = affine_map<(d0, d1, d2) -> (d0, (d1 + (d2 * -1)))>}> : (index, index, index) -> index +// CHECK-NEXT: "scf.parallel"(%11, %6, %15, %10, %7, %8) <{"operandSegmentSizes" = array}> ({ +// CHECK-NEXT: ^1(%16 : index, %17 : index): +// CHECK-NEXT: %18 = arith.addi %14, %16 : index +// CHECK-NEXT: %19 = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %h_x = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %20 = arith.constant -2 : i64 +// CHECK-NEXT: %21 = "math.fpowi"(%h_x, %20) : (f32, i64) -> f32 +// CHECK-NEXT: %22 = arith.constant -1 : index +// CHECK-NEXT: %23 = arith.addi %18, %22 : index +// CHECK-NEXT: %24 = memref.load %f2_t0_loadview[%23, %17] : memref<5x5xf32, strided<[7, 1], offset: 16>> +// CHECK-NEXT: %25 = arith.mulf %21, %24 : f32 +// CHECK-NEXT: %h_x_1 = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %26 = arith.constant -2 : i64 +// CHECK-NEXT: %27 = "math.fpowi"(%h_x_1, %26) : (f32, i64) -> f32 +// CHECK-NEXT: %28 = arith.constant 1 : index +// CHECK-NEXT: %29 = arith.addi %18, %28 : index +// CHECK-NEXT: %30 = memref.load %f2_t0_loadview[%29, %17] : memref<5x5xf32, strided<[7, 1], offset: 16>> +// CHECK-NEXT: %31 = arith.mulf %27, %30 : f32 +// CHECK-NEXT: %32 = arith.constant -2.000000e+00 : f32 +// CHECK-NEXT: %h_x_2 = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %33 = arith.constant -2 : i64 +// CHECK-NEXT: %34 = "math.fpowi"(%h_x_2, %33) : (f32, i64) -> f32 +// CHECK-NEXT: %35 = memref.load %f2_t0_loadview[%18, %17] : memref<5x5xf32, strided<[7, 1], offset: 16>> +// CHECK-NEXT: %36 = arith.mulf %32, %34 : f32 +// CHECK-NEXT: %37 = arith.mulf %36, %35 : f32 +// CHECK-NEXT: %38 = arith.addf %25, %31 : f32 +// CHECK-NEXT: %39 = arith.addf %38, %37 : f32 +// CHECK-NEXT: %40 = arith.mulf %19, %39 : f32 +// CHECK-NEXT: %41 = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %h_y = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %42 = arith.constant -2 : i64 +// CHECK-NEXT: %43 = "math.fpowi"(%h_y, %42) : (f32, i64) -> f32 +// CHECK-NEXT: %44 = arith.constant -1 : index +// CHECK-NEXT: %45 = arith.addi %17, %44 : index +// CHECK-NEXT: %46 = memref.load %f2_t0_loadview[%18, %45] : memref<5x5xf32, strided<[7, 1], offset: 16>> +// CHECK-NEXT: %47 = arith.mulf %43, %46 : f32 +// CHECK-NEXT: %h_y_1 = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %48 = arith.constant -2 : i64 +// CHECK-NEXT: %49 = "math.fpowi"(%h_y_1, %48) : (f32, i64) -> f32 +// CHECK-NEXT: %50 = arith.constant 1 : index +// CHECK-NEXT: %51 = arith.addi %17, %50 : index +// CHECK-NEXT: %52 = memref.load %f2_t0_loadview[%18, %51] : memref<5x5xf32, strided<[7, 1], offset: 16>> +// CHECK-NEXT: %53 = arith.mulf %49, %52 : f32 +// CHECK-NEXT: %54 = arith.constant -2.000000e+00 : f32 +// CHECK-NEXT: %h_y_2 = arith.constant 5.000000e-01 : f32 +// CHECK-NEXT: %55 = arith.constant -2 : i64 +// CHECK-NEXT: %56 = "math.fpowi"(%h_y_2, %55) : (f32, i64) -> f32 +// CHECK-NEXT: %57 = memref.load %f2_t0_loadview[%18, %17] : memref<5x5xf32, strided<[7, 1], offset: 16>> +// CHECK-NEXT: %58 = arith.mulf %54, %56 : f32 +// CHECK-NEXT: %59 = arith.mulf %58, %57 : f32 +// CHECK-NEXT: %60 = arith.addf %47, %53 : f32 +// CHECK-NEXT: %61 = arith.addf %60, %59 : f32 +// CHECK-NEXT: %62 = arith.mulf %41, %61 : f32 +// CHECK-NEXT: %dt = arith.constant 1.000000e-01 : f32 +// CHECK-NEXT: %63 = arith.constant -1 : i64 +// CHECK-NEXT: %64 = "math.fpowi"(%dt, %63) : (f32, i64) -> f32 +// CHECK-NEXT: %65 = memref.load %f2_t0_loadview[%18, %17] : memref<5x5xf32, strided<[7, 1], offset: 16>> +// CHECK-NEXT: %66 = arith.mulf %64, %65 : f32 +// CHECK-NEXT: %67 = arith.addf %40, %62 : f32 +// CHECK-NEXT: %68 = arith.addf %67, %66 : f32 +// CHECK-NEXT: %dt_1 = arith.constant 1.000000e-01 : f32 +// CHECK-NEXT: %69 = arith.mulf %68, %dt_1 : f32 +// CHECK-NEXT: memref.store %69, %f2_t1_storeview[%18, %17] : memref<3x3xf32, strided<[7, 1], offset: 16>> +// CHECK-NEXT: scf.yield +// CHECK-NEXT: }) : (index, index, index, index, index, index) -> () +// CHECK-NEXT: scf.yield +// CHECK-NEXT: }) : (index, index, index) -> () +// CHECK-NEXT: scf.yield %f2_t1, %f2_t0 : memref<7x7xf32>, memref<7x7xf32> +// CHECK-NEXT: } +// CHECK-NEXT: %70 = func.call @timer_end(%0) : (f64) -> f64 +// CHECK-NEXT: "llvm.store"(%70, %timers) <{"ordering" = 0 : i64}> : (f64, !llvm.ptr) -> () +// CHECK-NEXT: func.return +// CHECK-NEXT: } +// CHECK-NEXT: func.func private @timer_start() -> f64 +// CHECK-NEXT: func.func private @timer_end(f64) -> f64 +// CHECK-NEXT: } +// CHECK-NEXT: diff --git a/tests/test_xdsl_base.py b/tests/test_xdsl_base.py index 0625a8d1e6..6b50731d0f 100644 --- a/tests/test_xdsl_base.py +++ b/tests/test_xdsl_base.py @@ -8,7 +8,7 @@ from examples.seismic.source import RickerSource, TimeAxis from xdsl.dialects.scf import For, Yield -from xdsl.dialects.arith import Addi +from xdsl.dialects.arith import Addi, Addf, Mulf from xdsl.dialects.arith import Constant as xdslconstant from xdsl.dialects.func import Call, Return from xdsl.dialects.stencil import FieldType, ApplyOp, LoadOp, StoreOp @@ -22,8 +22,12 @@ def test_xdsl_I(): u = TimeFunction(name='u', grid=grid) eq = Eq(u.forward, u + 1) - op = Operator([eq], opt='xdsl') - op.apply(time_M=1) + opx = Operator([eq], opt='xdsl') + opx.apply(time_M=1) + + assert len([op for op in opx._module.walk() if isinstance(op, Addi)]) == 1 + assert len([op for op in opx._module.walk() if isinstance(op, Addf)]) == 1 + assert (u.data[1, :] == 1.).all() assert (u.data[0, :] == 2.).all() @@ -89,8 +93,11 @@ def test_diffusion_2D(): f2 = TimeFunction(name='f2', grid=grid, space_order=2) f2.data[:] = 1 eqn = Eq(f2.dt, 0.5 * f2.laplace) - op = Operator(Eq(f2.forward, solve(eqn, f2.forward)), opt='xdsl') - op.apply(time_M=1, dt=0.1) + opx = Operator(Eq(f2.forward, solve(eqn, f2.forward)), opt='xdsl') + opx.apply(time_M=1, dt=0.1) + + assert len([op for op in opx._module.walk() if isinstance(op, Addf)]) == 6 + assert len([op for op in opx._module.walk() if isinstance(op, Mulf)]) == 12 assert np.isclose(f.data, f2.data, rtol=1e-06).all()