Skip to content

Commit

Permalink
codegen: refactor Map.indexed()
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Jun 14, 2024
1 parent a52b998 commit b9c031a
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions pyop2/codegen/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ def shape(self):
def dtype(self):
return self.values.dtype

def indexed(self, multiindex, layer=None, permute=lambda x: x):
@property
def _permute(self):
return lambda x: x

def indexed(self, multiindex, layer=None):
n, i, f = multiindex
if layer is not None and self.offset is not None:
# For extruded mesh, prefetch the indirections for each map, so that they don't
Expand All @@ -84,7 +88,7 @@ def indexed(self, multiindex, layer=None, permute=lambda x: x):
base_key = None
if base_key not in self.prefetch:
j = Index()
base = Indexed(self.values, (n, permute(j)))
base = Indexed(self.values, (n, self._permute(j)))
self.prefetch[base_key] = Materialise(PackInst(), base, MultiIndex(j))

base = self.prefetch[base_key]
Expand Down Expand Up @@ -122,17 +126,17 @@ def indexed(self, multiindex, layer=None, permute=lambda x: x):
return Indexed(self.prefetch[key], (f, i)), (f, i)
else:
assert f.extent == 1 or f.extent is None
base = Indexed(self.values, (n, permute(i)))
base = Indexed(self.values, (n, self._permute(i)))
return base, (f, i)

def indexed_vector(self, n, shape, layer=None, permute=lambda x: x):
def indexed_vector(self, n, shape, layer=None):
shape = self.shape[1:] + shape
if self.interior_horizontal:
shape = (2, ) + shape
else:
shape = (1, ) + shape
f, i, j = (Index(e) for e in shape)
base, (f, i) = self.indexed((n, i, f), layer=layer, permute=permute)
base, (f, i) = self.indexed((n, i, f), layer=layer)
init = Sum(Product(base, Literal(numpy.int32(j.extent))), j)
pack = Materialise(PackInst(), init, MultiIndex(f, i, j))
multiindex = tuple(Index(e) for e in pack.shape)
Expand Down Expand Up @@ -168,13 +172,9 @@ def __init__(self, map_, permutation):
self.offset_quotient = map_.offset_quotient
self.permutation = NamedLiteral(permutation, parent=self.values, suffix=f"permutation{count}")

def indexed(self, multiindex, layer=None):
permute = lambda x: Indexed(self.permutation, (x,))
return super().indexed(multiindex, layer=layer, permute=permute)

def indexed_vector(self, n, shape, layer=None):
permute = lambda x: Indexed(self.permutation, (x,))
return super().indexed_vector(n, shape, layer=layer, permute=permute)
@property
def _permute(self):
return lambda x: Indexed(self.permutation, (x,))


class CMap(Map):
Expand Down

0 comments on commit b9c031a

Please sign in to comment.