Skip to content

Commit

Permalink
with_container_arithmetic: Rename arguments to signal who broadcasts …
Browse files Browse the repository at this point in the history
…across who

Names suggested by @majosm
  • Loading branch information
inducer committed Sep 4, 2024
1 parent 8b1b795 commit 510dc1b
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 54 deletions.
161 changes: 113 additions & 48 deletions arraycontext/container/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,34 +159,40 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet


def with_container_arithmetic(
*,
bcast_number: bool = True,
_bcast_actx_array_type: Optional[bool] = None,
bcast_obj_array: Optional[bool] = None,
bcast_numpy_array: bool = False,
bcast_container_types: Optional[Tuple[type, ...]] = None,
arithmetic: bool = True,
matmul: bool = False,
bitwise: bool = False,
shift: bool = False,
_cls_has_array_context_attr: Optional[bool] = None,
eq_comparison: Optional[bool] = None,
rel_comparison: Optional[bool] = None) -> Callable[[type], type]:
*,
number_bcasts_across: Optional[bool] = None,
bcasts_across_obj_array: Optional[bool] = None,
container_types_bcast_across: Optional[Tuple[type, ...]] = None,
arithmetic: bool = True,
matmul: bool = False,
bitwise: bool = False,
shift: bool = False,
_cls_has_array_context_attr: Optional[bool] = None,
eq_comparison: Optional[bool] = None,
rel_comparison: Optional[bool] = None,

# deprecated:
bcast_number: Optional[bool] = None,
bcast_obj_array: Optional[bool] = None,
bcast_numpy_array: bool = False,
_bcast_actx_array_type: Optional[bool] = None,
bcast_container_types: Optional[Tuple[type, ...]] = None,
) -> Callable[[type], type]:
"""A class decorator that implements built-in operators for array containers
by propagating the operations to the elements of the container.
:arg bcast_number: If *True*, numbers broadcast over the container
:arg number_bcasts_across: If *True*, numbers broadcast over the container
(with the container as the 'outer' structure).
:arg bcast_obj_array: If *True*, this container will be broadcast
:arg bcasts_across_obj_array: If *True*, this container will be broadcast
across :mod:`numpy` object arrays
(with the object array as the 'outer' structure).
Add :class:`numpy.ndarray` to *bcast_container_types* to achieve
Add :class:`numpy.ndarray` to *container_types_bcast_across* to achieve
the 'reverse' broadcasting.
:arg bcast_container_types: A sequence of container types that will broadcast
:arg container_types_bcast_across: A sequence of container types that will broadcast
across this container, with this container as the 'outer' structure.
:class:`numpy.ndarray` is permitted to be part of this sequence to
indicate that object arrays (and *only* object arrays) will be broadcasat.
In this case, *bcast_obj_array* must be *False*.
indicate that object arrays (and *only* object arrays) will be broadcast.
In this case, *bcasts_across_obj_array* must be *False*.
:arg arithmetic: Implement the conventional arithmetic operators, including
``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as
:func:`abs`.
Expand Down Expand Up @@ -241,8 +247,71 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):

# {{{ handle inputs

if bcast_obj_array is None:
raise TypeError("bcast_obj_array must be specified")
if rel_comparison and eq_comparison is None:
eq_comparison = True

if eq_comparison is None:
raise TypeError("eq_comparison must be specified")

# {{{ handle bcast_number

if bcast_number is not None:
if number_bcasts_across is not None:
raise TypeError(
"may specify at most one of 'bcast_number' and "
"'number_bcasts_across'")

warn("'bcast_number' is deprecated and will be unsupported from 2025. "
"Use 'number_bcasts_across', with equivalent meaning.",
DeprecationWarning, stacklevel=2)
number_bcasts_across = bcast_number
else:
if number_bcasts_across is None:
number_bcasts_across = True

del bcast_number

# }}}

# {{{ handle bcast_obj_array

if bcast_obj_array is not None:
if bcasts_across_obj_array is not None:
raise TypeError(
"may specify at most one of 'bcast_obj_array' and "
"'bcasts_across_obj_array'")

warn("'bcast_obj_array' is deprecated and will be unsupported from 2025. "
"Use 'bcasts_across_obj_array', with equivalent meaning.",
DeprecationWarning, stacklevel=2)
bcasts_across_obj_array = bcast_obj_array
else:
if bcasts_across_obj_array is None:
raise TypeError("bcasts_across_obj_array must be specified")

del bcast_obj_array

# }}}

# {{{ handle bcast_container_types

if bcast_container_types is not None:
if container_types_bcast_across is not None:
raise TypeError(
"may specify at most one of 'bcast_container_types' and "
"'container_types_bcast_across'")

warn("'bcast_container_types' is deprecated and will be unsupported from 2025. "
"Use 'container_types_bcast_across', with equivalent meaning.",
DeprecationWarning, stacklevel=2)
container_types_bcast_across = bcast_container_types
else:
if container_types_bcast_across is None:
container_types_bcast_across = ()

del bcast_container_types

# }}}

if rel_comparison is None:
raise TypeError("rel_comparison must be specified")
Expand All @@ -255,36 +324,27 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'"
" cannot be both set.")

if rel_comparison and eq_comparison is None:
eq_comparison = True

if eq_comparison is None:
raise TypeError("eq_comparison must be specified")

if not bcast_obj_array and bcast_numpy_array:
if not bcasts_across_obj_array and bcast_numpy_array:
raise TypeError("bcast_obj_array must be set if bcast_numpy_array is")

if bcast_numpy_array:
def numpy_pred(name: str) -> str:
return f"is_numpy_array({name})"
elif bcast_obj_array:
elif bcasts_across_obj_array:
def numpy_pred(name: str) -> str:
return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'"
else:
def numpy_pred(name: str) -> str:
return "False" # optimized away

if bcast_container_types is None:
bcast_container_types = ()

if np.ndarray in bcast_container_types and bcast_obj_array:
if np.ndarray in container_types_bcast_across and bcasts_across_obj_array:
raise ValueError("If numpy.ndarray is part of bcast_container_types, "
"bcast_obj_array must be False.")

numpy_check_types: list[type] = [NumpyObjectArray, ComplainingNumpyNonObjectArray]
bcast_container_types = tuple(
container_types_bcast_across = tuple(
new_ct
for old_ct in bcast_container_types
for old_ct in container_types_bcast_across
for new_ct in
(numpy_check_types
if old_ct is np.ndarray
Expand Down Expand Up @@ -334,7 +394,7 @@ def wrap(cls: Any) -> Any:

if bcast_actx_array_type is None:
if cls_has_array_context_attr:
if bcast_number:
if number_bcasts_across:
bcast_actx_array_type = cls_has_array_context_attr
else:
bcast_actx_array_type = False
Expand Down Expand Up @@ -409,14 +469,14 @@ def is_numpy_array(arg):
""")
gen("")

if bcast_container_types:
for i, bct in enumerate(bcast_container_types):
if container_types_bcast_across:
for i, bct in enumerate(container_types_bcast_across):
gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}")
gen("")
outer_bcast_type_names = tuple(
f"_bctype{i}" for i in range(len(bcast_container_types)))
if bcast_number:
outer_bcast_type_names += ("Number",)
container_type_names_bcast_across = tuple(
f"_bctype{i}" for i in range(len(container_types_bcast_across)))
if number_bcasts_across:
container_type_names_bcast_across += ("Number",)

def same_key(k1: T, k2: T) -> T:
assert k1 == k2
Expand All @@ -428,9 +488,14 @@ def tup_str(t: Tuple[str, ...]) -> str:
else:
return "({},)".format(", ".join(t))

gen(f"cls._outer_bcast_types = {tup_str(outer_bcast_type_names)}")
gen(f"cls._outer_bcast_types = {tup_str(container_type_names_bcast_across)}")
gen("cls._container_types_bcast_across = "
f"{tup_str(container_type_names_bcast_across)}")

gen(f"cls._bcast_numpy_array = {bcast_numpy_array}")
gen(f"cls._bcast_obj_array = {bcast_obj_array}")

gen(f"cls._bcast_obj_array = {bcasts_across_obj_array}")
gen(f"cls._bcasts_across_obj_array = {bcasts_across_obj_array}")
gen("")

# {{{ unary operators
Expand Down Expand Up @@ -535,9 +600,9 @@ def {fname}(arg1):
result[i] = {op_str.format("arg1", "arg2[i]")}
return result
if {bool(outer_bcast_type_names)}: # optimized away
if {bool(container_type_names_bcast_across)}: # optimized away
if isinstance(arg2,
{tup_str(outer_bcast_type_names
{tup_str(container_type_names_bcast_across
+ bcast_actx_ary_types)}):
if __debug__:
if isinstance(arg2, {tup_str(bcast_actx_ary_types)}):
Expand Down Expand Up @@ -584,9 +649,9 @@ def {fname}(arg2, arg1):
for i in np.ndindex(arg1.shape):
result[i] = {op_str.format("arg1[i]", "arg2")}
return result
if {bool(outer_bcast_type_names)}: # optimized away
if {bool(container_type_names_bcast_across)}: # optimized away
if isinstance(arg1,
{tup_str(outer_bcast_type_names
{tup_str(container_type_names_bcast_across
+ bcast_actx_ary_types)}):
if __debug__:
if isinstance(arg1,
Expand Down
12 changes: 6 additions & 6 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _acf():
# {{{ stand-in DOFArray implementation

@with_container_arithmetic(
bcast_obj_array=True,
bcasts_across_obj_array=True,
bitwise=True,
rel_comparison=True,
_cls_has_array_context_attr=True,
Expand Down Expand Up @@ -208,7 +208,7 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type:

# {{{ nested containers

@with_container_arithmetic(bcast_obj_array=False,
@with_container_arithmetic(bcasts_across_obj_array=False,
eq_comparison=False, rel_comparison=False,
_cls_has_array_context_attr=True,
_bcast_actx_array_type=False)
Expand All @@ -231,7 +231,7 @@ def array_context(self):


@with_container_arithmetic(
bcast_obj_array=False,
bcasts_across_obj_array=False,
bcast_container_types=(DOFArray, np.ndarray),
matmul=True,
rel_comparison=True,
Expand Down Expand Up @@ -1225,7 +1225,7 @@ def test_norm_ord_none(actx_factory, ndim):

# {{{ test_actx_compile helpers

@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True)
@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True)
@dataclass_array_container
@dataclass(frozen=True)
class Velocity2D:
Expand Down Expand Up @@ -1355,7 +1355,7 @@ def test_container_equality(actx_factory):
# {{{ test_no_leaf_array_type_broadcasting

@with_container_arithmetic(
bcast_obj_array=True,
bcasts_across_obj_array=True,
rel_comparison=True,
_cls_has_array_context_attr=True,
_bcast_actx_array_type=False)
Expand Down Expand Up @@ -1459,7 +1459,7 @@ def equal(a, b):

# {{{ test_array_container_with_numpy

@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True)
@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True)
@dataclass_array_container
@dataclass(frozen=True)
class ArrayContainerWithNumpy:
Expand Down

0 comments on commit 510dc1b

Please sign in to comment.