Skip to content

Commit

Permalink
flake8 add some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebisbas committed Nov 17, 2023
1 parent 15aaeab commit bd3b60a
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 107 deletions.
1 change: 1 addition & 0 deletions devito/ir/ietxdsl/cluster_to_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from devito.ir.ietxdsl import iet_ssa
from devito.ir.ietxdsl.ietxdsl_functions import dtypes_to_xdsltypes

# flake8: noqa

class ExtractDevitoStencilConversion:
"""
Expand Down
12 changes: 6 additions & 6 deletions devito/ir/ietxdsl/iet_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ def get_llvm_struct_type():
return llvm.LLVMStructType.from_type_list([
llvm.LLVMPointerType.opaque(), # data
llvm.LLVMPointerType.typed(builtin.i32), # size
#llvm.LLVMPointerType.typed(builtin.i32), # npsize
#llvm.LLVMPointerType.typed(builtin.i32), # dsize
#llvm.LLVMPointerType.typed(builtin.i32), # hsize
#llvm.LLVMPointerType.typed(builtin.i32), # hofs
#llvm.LLVMPointerType.typed(builtin.i32), # oofs
#llvm.LLVMPointerType.opaque(), # dmap
# llvm.LLVMPointerType.typed(builtin.i32), # npsize
# llvm.LLVMPointerType.typed(builtin.i32), # dsize
# llvm.LLVMPointerType.typed(builtin.i32), # hsize
# llvm.LLVMPointerType.typed(builtin.i32), # hofs
# llvm.LLVMPointerType.typed(builtin.i32), # oofs
# llvm.LLVMPointerType.opaque(), # dmap
])


Expand Down
2 changes: 2 additions & 0 deletions devito/operator/xdsl_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

from xdsl.printer import Printer

# flake8: noqa

__all__ = ['XDSLOperator']

Expand All @@ -62,6 +63,7 @@
}
"""


def generate_tiling_arg(nb_tiled_dims: int):
"""
Generate the tile-sizes arg for the convert-stencil-to-ll-mlir pass. Generating no argument if the diled_dims arg is 0
Expand Down
2 changes: 2 additions & 0 deletions fast/diffusion_2D_wBCs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

mpiconf = configuration['mpi']

# flake8: noqa

# Some variable declarations
nx, ny = args.shape
nt = args.nt
Expand Down
16 changes: 8 additions & 8 deletions fast/setup_wave2d.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
# Script to save initial data for the Acoustic wave execution benchmark
# Based on the implementation of the Devito acoustic example implementation
# Not using Devito's source injection abstraction
import sys
import numpy as np

from devito import (TimeFunction, Eq, Operator, solve, norm,
XDSLOperator, configuration)
from devito import (TimeFunction, Eq, Operator, solve, configuration)
from examples.seismic import RickerSource
from examples.seismic import Model, TimeAxis, plot_image
from examples.seismic import Model, TimeAxis
from fast.bench_utils import plot_2dfunc
from devito.tools import as_tuple

Expand Down Expand Up @@ -40,10 +38,12 @@
nt = args.nt

shape = (args.shape) # Number of grid point (nx, ny, nz)
spacing = as_tuple(10.0 for _ in range(len(shape))) # Grid spacing in m. The domain size is now 1km by 1km
origin = as_tuple(0.0 for _ in range(len(shape))) # What is the location of the top left corner.
# This is necessary to define
# the absolute location of the source and receivers
# Grid spacing in m. The domain size is now 1km by 1km
spacing = as_tuple(10.0 for _ in range(len(shape)))
# What is the location of the top left corner.
origin = as_tuple(0.0 for _ in range(len(shape)))
# This is necessary to define the absolute location of the
# source and receivers

# Define a velocity profile. The velocity is in km/s
v = np.empty(shape, dtype=np.float32)
Expand Down
1 change: 1 addition & 0 deletions fast/setup_wave3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
parser.add_argument("-xdsl", "--xdsl", default=False, type=bool, help="xDSL run")
args = parser.parse_args()

# flake8: noqa

mpiconf = configuration['mpi']

Expand Down
93 changes: 0 additions & 93 deletions tests/test_xdsl_iet.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,99 +54,6 @@ def test_powi():
])


@pytest.mark.xfail(reason="Deprecated, will be dropped")
def test_devito_iet():
grid = Grid(shape=(3, 3))
u = TimeFunction(name='u', grid=grid)
eq = Eq(u.forward, u + 1)
op = Operator([eq])

iters = retrieve_iteration_tree(op.body)

t_limits = as_tuple([str(i) for i in iters[0][0].limits])
t_props = [str(i) for i in iters[0][0].properties]

x_limits = as_tuple([str(i) for i in iters[0][1].limits])
x_props = [str(i) for i in iters[0][1].properties]

y_limits = [str(i) for i in iters[0][2].limits]
y_props = [str(i) for i in iters[0][2].properties]

ctx = MLContext()
Builtin(ctx)
iet = IET(ctx)

mod = ModuleOp.from_region_or_ops([
Callable.get("kernel", ["u"], ["u"],["struct dataobj*"], ["restrict"], "int", "",
Block.from_callable([i32], lambda u: [
Iteration.get(t_props, t_limits, iters[0][0].dim.name,
Block.from_callable([i32, i32, i32],
lambda time, t0, t1: [
Iteration.get(x_props, x_limits, iters[0][1].dim.name,
Block.from_callable([i32], lambda x: [
Iteration.get(y_props, y_limits, iters[0][2].dim.name,
Block.from_callable([i32], lambda y: [
cst1 := Constant.from_int_and_width(1, i32),
x1 := Addi.get(x, cst1),
y1 := Addi.get(y, cst1),
#ut0 := Idx.get(u, t0),
#ut0x1 := Idx.get(ut0, x1),
#ut0x1y1 := Idx.get(ut0x1, y1),
#rhs := Addi.get(ut0x1y1, cst1),
#ut1 := Idx.get(u, t1),
#ut1x1 := Idx.get(ut1, x1),
#lhs := Idx.get(ut1x1, y1),
#Assign.build([lhs, rhs])
]))
]))
]))
]))
])

printer = Printer()
printer.print_op(mod)


@pytest.mark.xfail(reason="Deprecated, will be dropped")
def test_mfe_memref():
ctx = MLContext()
Builtin(ctx)
iet = IET(ctx)

memref_f32_rank2 = memref.MemRefType.from_element_type_and_shape(
f32, [-1, -1])

mod = ModuleOp.from_region_or_ops([
Callable.get(
"kernel", ["u"],["u"],["struct dataobj*"], ["restrict"], "int", "",
Block.from_callable([memref_f32_rank2], lambda u: [
Iteration
.get(["affine", "sequential"], ("time_m", "time_M", "1"),"time_loop",
Block.from_callable([
i32, i32, i32
], lambda time, t0, t1: [
Iteration.get(
[
"affine",
"parallel", "skewable", "vector-dim"
], ("y_m", "y_M", "1"),"y_loop",
Block.from_callable([i32], lambda y: [
cst1 := Constant.from_int_and_width(1, i32),
y1 := Addi.get(y, cst1),
ut0 := memref.Load.get(u, [t0, y1]),
# ut0 := Idx.get(u, t0),
rhs := Addi.get(ut0, cst1),
memref.Store.get(rhs, u, [t1, y1])
# Assign.build([lhs, rhs])
]))
]))
]))
])

printer = Printer()
# import pdb;pdb.set_trace()
printer.print_op(mod)

def test_mfe2():

mod = ModuleOp([
Expand Down

0 comments on commit bd3b60a

Please sign in to comment.