Skip to content

Commit

Permalink
Global with no comm is a Constant (#701)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Connor Ward <[email protected]>
Co-authored-by: David A. Ham <[email protected]>
  • Loading branch information
3 people authored Jun 14, 2023
1 parent edae288 commit d230953
Show file tree
Hide file tree
Showing 13 changed files with 330 additions and 211 deletions.
24 changes: 19 additions & 5 deletions pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pyop2.caching import Cached
from pyop2.configuration import configuration
from pyop2.datatypes import IntType, as_ctypes
from pyop2.types import IterationRegion
from pyop2.types import IterationRegion, Constant, READ
from pyop2.utils import cached_property, get_petsc_dir


Expand Down Expand Up @@ -277,13 +277,27 @@ def __init__(self, local_kernel, arguments, *,
return

if not len(local_kernel.accesses) == len(arguments):
raise ValueError("Number of arguments passed to the local "
"and global kernels do not match")
raise ValueError(
"Number of arguments passed to the local and global kernels"
" do not match"
)

if any(
isinstance(garg, Constant) and larg.access is not READ
for larg, garg in zip(local_kernel.arguments, arguments)
):
raise ValueError(
"Constants can only ever be read in a parloop, not modified"
)

if pass_layer_arg and not extruded:
raise ValueError("Cannot request layer argument for non-extruded iteration")
raise ValueError(
"Cannot request layer argument for non-extruded iteration"
)
if constant_layers and not extruded:
raise ValueError("Cannot request constant_layers argument for non-extruded iteration")
raise ValueError(
"Cannot request constant_layers argument for non-extruded iteration"
)

self.local_kernel = local_kernel
self.arguments = arguments
Expand Down
4 changes: 2 additions & 2 deletions pyop2/op2.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@
from pyop2.logger import debug, info, warning, error, critical, set_log_level
from pyop2.mpi import MPI, COMM_WORLD, collective

from pyop2.types import (
from pyop2.types import ( # noqa: F401
Set, ExtrudedSet, MixedSet, Subset, DataSet, MixedDataSet,
Map, MixedMap, PermutedMap, ComposedMap, Sparsity, Halo,
Global, GlobalDataSet,
Global, Constant, GlobalDataSet,
Dat, MixedDat, DatView, Mat
)
from pyop2.types import (READ, WRITE, RW, INC, MIN, MAX,
Expand Down
5 changes: 4 additions & 1 deletion pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,12 +426,15 @@ def _iop_kernel(self, op, globalp, other_is_self, dtype):

def _iop(self, other, op):
from pyop2.parloop import parloop
from pyop2.types.glob import Global
from pyop2.types.glob import Global, Constant

globalp = False
if np.isscalar(other):
other = Global(1, data=other, comm=self.comm)
globalp = True
elif isinstance(other, Constant):
other = Global(other, comm=self.comm)
globalp = True
elif other is not self:
self._check_shape(other)
args = [self(Access.INC)]
Expand Down
Loading

0 comments on commit d230953

Please sign in to comment.