diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 66e10ff0..9366b260 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -159,34 +159,40 @@ class ComplainingNumpyNonObjectArray(metaclass=ComplainingNumpyNonObjectArrayMet def with_container_arithmetic( - *, - bcast_number: bool = True, - _bcast_actx_array_type: Optional[bool] = None, - bcast_obj_array: Optional[bool] = None, - bcast_numpy_array: bool = False, - bcast_container_types: Optional[Tuple[type, ...]] = None, - arithmetic: bool = True, - matmul: bool = False, - bitwise: bool = False, - shift: bool = False, - _cls_has_array_context_attr: Optional[bool] = None, - eq_comparison: Optional[bool] = None, - rel_comparison: Optional[bool] = None) -> Callable[[type], type]: + *, + number_bcasts_across: Optional[bool] = None, + bcasts_across_obj_array: Optional[bool] = None, + container_types_bcast_across: Optional[Tuple[type, ...]] = None, + arithmetic: bool = True, + matmul: bool = False, + bitwise: bool = False, + shift: bool = False, + _cls_has_array_context_attr: Optional[bool] = None, + eq_comparison: Optional[bool] = None, + rel_comparison: Optional[bool] = None, + + # deprecated: + bcast_number: Optional[bool] = None, + bcast_obj_array: Optional[bool] = None, + bcast_numpy_array: bool = False, + _bcast_actx_array_type: Optional[bool] = None, + bcast_container_types: Optional[Tuple[type, ...]] = None, + ) -> Callable[[type], type]: """A class decorator that implements built-in operators for array containers by propagating the operations to the elements of the container. - :arg bcast_number: If *True*, numbers broadcast over the container + :arg number_bcasts_across: If *True*, numbers broadcast over the container (with the container as the 'outer' structure). - :arg bcast_obj_array: If *True*, this container will be broadcast + :arg bcasts_across_obj_array: If *True*, this container will be broadcast across :mod:`numpy` object arrays (with the object array as the 'outer' structure). - Add :class:`numpy.ndarray` to *bcast_container_types* to achieve + Add :class:`numpy.ndarray` to *container_types_bcast_across* to achieve the 'reverse' broadcasting. - :arg bcast_container_types: A sequence of container types that will broadcast + :arg container_types_bcast_across: A sequence of container types that will broadcast across this container, with this container as the 'outer' structure. :class:`numpy.ndarray` is permitted to be part of this sequence to - indicate that object arrays (and *only* object arrays) will be broadcasat. - In this case, *bcast_obj_array* must be *False*. + indicate that object arrays (and *only* object arrays) will be broadcast. + In this case, *bcasts_across_obj_array* must be *False*. :arg arithmetic: Implement the conventional arithmetic operators, including ``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as :func:`abs`. @@ -241,8 +247,71 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): # {{{ handle inputs - if bcast_obj_array is None: - raise TypeError("bcast_obj_array must be specified") + if rel_comparison and eq_comparison is None: + eq_comparison = True + + if eq_comparison is None: + raise TypeError("eq_comparison must be specified") + + # {{{ handle bcast_number + + if bcast_number is not None: + if number_bcasts_across is not None: + raise TypeError( + "may specify at most one of 'bcast_number' and " + "'number_bcasts_across'") + + warn("'bcast_number' is deprecated and will be unsupported from 2025. " + "Use 'number_bcasts_across', with equivalent meaning.", + DeprecationWarning, stacklevel=2) + number_bcasts_across = bcast_number + else: + if number_bcasts_across is None: + number_bcasts_across = True + + del bcast_number + + # }}} + + # {{{ handle bcast_obj_array + + if bcast_obj_array is not None: + if bcasts_across_obj_array is not None: + raise TypeError( + "may specify at most one of 'bcast_obj_array' and " + "'bcasts_across_obj_array'") + + warn("'bcast_obj_array' is deprecated and will be unsupported from 2025. " + "Use 'bcasts_across_obj_array', with equivalent meaning.", + DeprecationWarning, stacklevel=2) + bcasts_across_obj_array = bcast_obj_array + else: + if bcasts_across_obj_array is None: + raise TypeError("bcasts_across_obj_array must be specified") + + del bcast_obj_array + + # }}} + + # {{{ handle bcast_container_types + + if bcast_container_types is not None: + if container_types_bcast_across is not None: + raise TypeError( + "may specify at most one of 'bcast_container_types' and " + "'container_types_bcast_across'") + + warn("'bcast_container_types' is deprecated and will be unsupported from 2025. " + "Use 'container_types_bcast_across', with equivalent meaning.", + DeprecationWarning, stacklevel=2) + container_types_bcast_across = bcast_container_types + else: + if container_types_bcast_across is None: + container_types_bcast_across = () + + del bcast_container_types + + # }}} if rel_comparison is None: raise TypeError("rel_comparison must be specified") @@ -255,36 +324,27 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'" " cannot be both set.") - if rel_comparison and eq_comparison is None: - eq_comparison = True - - if eq_comparison is None: - raise TypeError("eq_comparison must be specified") - - if not bcast_obj_array and bcast_numpy_array: + if not bcasts_across_obj_array and bcast_numpy_array: raise TypeError("bcast_obj_array must be set if bcast_numpy_array is") if bcast_numpy_array: def numpy_pred(name: str) -> str: return f"is_numpy_array({name})" - elif bcast_obj_array: + elif bcasts_across_obj_array: def numpy_pred(name: str) -> str: return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'" else: def numpy_pred(name: str) -> str: return "False" # optimized away - if bcast_container_types is None: - bcast_container_types = () - - if np.ndarray in bcast_container_types and bcast_obj_array: + if np.ndarray in container_types_bcast_across and bcasts_across_obj_array: raise ValueError("If numpy.ndarray is part of bcast_container_types, " "bcast_obj_array must be False.") numpy_check_types: list[type] = [NumpyObjectArray, ComplainingNumpyNonObjectArray] - bcast_container_types = tuple( + container_types_bcast_across = tuple( new_ct - for old_ct in bcast_container_types + for old_ct in container_types_bcast_across for new_ct in (numpy_check_types if old_ct is np.ndarray @@ -334,7 +394,7 @@ def wrap(cls: Any) -> Any: if bcast_actx_array_type is None: if cls_has_array_context_attr: - if bcast_number: + if number_bcasts_across: bcast_actx_array_type = cls_has_array_context_attr else: bcast_actx_array_type = False @@ -409,14 +469,14 @@ def is_numpy_array(arg): """) gen("") - if bcast_container_types: - for i, bct in enumerate(bcast_container_types): + if container_types_bcast_across: + for i, bct in enumerate(container_types_bcast_across): gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}") gen("") - outer_bcast_type_names = tuple( - f"_bctype{i}" for i in range(len(bcast_container_types))) - if bcast_number: - outer_bcast_type_names += ("Number",) + container_type_names_bcast_across = tuple( + f"_bctype{i}" for i in range(len(container_types_bcast_across))) + if number_bcasts_across: + container_type_names_bcast_across += ("Number",) def same_key(k1: T, k2: T) -> T: assert k1 == k2 @@ -428,9 +488,14 @@ def tup_str(t: Tuple[str, ...]) -> str: else: return "({},)".format(", ".join(t)) - gen(f"cls._outer_bcast_types = {tup_str(outer_bcast_type_names)}") + gen(f"cls._outer_bcast_types = {tup_str(container_type_names_bcast_across)}") + gen("cls._container_types_bcast_across = " + f"{tup_str(container_type_names_bcast_across)}") + gen(f"cls._bcast_numpy_array = {bcast_numpy_array}") - gen(f"cls._bcast_obj_array = {bcast_obj_array}") + + gen(f"cls._bcast_obj_array = {bcasts_across_obj_array}") + gen(f"cls._bcasts_across_obj_array = {bcasts_across_obj_array}") gen("") # {{{ unary operators @@ -535,9 +600,9 @@ def {fname}(arg1): result[i] = {op_str.format("arg1", "arg2[i]")} return result - if {bool(outer_bcast_type_names)}: # optimized away + if {bool(container_type_names_bcast_across)}: # optimized away if isinstance(arg2, - {tup_str(outer_bcast_type_names + {tup_str(container_type_names_bcast_across + bcast_actx_ary_types)}): if __debug__: if isinstance(arg2, {tup_str(bcast_actx_ary_types)}): @@ -584,9 +649,9 @@ def {fname}(arg2, arg1): for i in np.ndindex(arg1.shape): result[i] = {op_str.format("arg1[i]", "arg2")} return result - if {bool(outer_bcast_type_names)}: # optimized away + if {bool(container_type_names_bcast_across)}: # optimized away if isinstance(arg1, - {tup_str(outer_bcast_type_names + {tup_str(container_type_names_bcast_across + bcast_actx_ary_types)}): if __debug__: if isinstance(arg1, diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 107539b4..47d83903 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -117,7 +117,7 @@ def _acf(): # {{{ stand-in DOFArray implementation @with_container_arithmetic( - bcast_obj_array=True, + bcasts_across_obj_array=True, bitwise=True, rel_comparison=True, _cls_has_array_context_attr=True, @@ -208,7 +208,7 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: # {{{ nested containers -@with_container_arithmetic(bcast_obj_array=False, +@with_container_arithmetic(bcasts_across_obj_array=False, eq_comparison=False, rel_comparison=False, _cls_has_array_context_attr=True, _bcast_actx_array_type=False) @@ -231,7 +231,7 @@ def array_context(self): @with_container_arithmetic( - bcast_obj_array=False, + bcasts_across_obj_array=False, bcast_container_types=(DOFArray, np.ndarray), matmul=True, rel_comparison=True, @@ -1225,7 +1225,7 @@ def test_norm_ord_none(actx_factory, ndim): # {{{ test_actx_compile helpers -@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True) +@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True) @dataclass_array_container @dataclass(frozen=True) class Velocity2D: @@ -1355,7 +1355,7 @@ def test_container_equality(actx_factory): # {{{ test_no_leaf_array_type_broadcasting @with_container_arithmetic( - bcast_obj_array=True, + bcasts_across_obj_array=True, rel_comparison=True, _cls_has_array_context_attr=True, _bcast_actx_array_type=False) @@ -1459,7 +1459,7 @@ def equal(a, b): # {{{ test_array_container_with_numpy -@with_container_arithmetic(bcast_obj_array=True, rel_comparison=True) +@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True) @dataclass_array_container @dataclass(frozen=True) class ArrayContainerWithNumpy: