From 03229d3d851e7377b7b05800cd3cad7cccb092e3 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 Jul 2024 16:57:52 -0500 Subject: [PATCH] Rework dataclass array container arithmetic - Deprecate automatic broadcasting of array context arrays - Introduce Bcast as an object to represent broadcast rules - Warn about uses of numpy array broadcasting, deprecated earlier - Clarify documentation, warning wording --- arraycontext/__init__.py | 17 +- arraycontext/container/arithmetic.py | 381 +++++++++++++++++++++++---- arraycontext/container/traversal.py | 4 +- test/test_arraycontext.py | 82 +++++- 4 files changed, 423 insertions(+), 61 deletions(-) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index e8e6e9f3..e8e45ea9 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -43,7 +43,15 @@ register_multivector_as_array_container, serialize_container, ) -from .container.arithmetic import with_container_arithmetic +from .container.arithmetic import ( + Bcast, + Bcast1, + Bcast2, + Bcast3, + BcastNLevels, + BcastUntilActxArray, + with_container_arithmetic, +) from .container.dataclass import dataclass_array_container from .container.traversal import ( flat_size_and_dtype, @@ -103,6 +111,12 @@ "ArrayOrContainerOrScalarT", "ArrayOrContainerT", "ArrayT", + "Bcast", + "Bcast1", + "Bcast2", + "Bcast3", + "BcastNLevels", + "BcastUntilActxArray", "CommonSubexpressionTag", "EagerJAXArrayContext", "ElementwiseMapKernelTag", @@ -151,7 +165,6 @@ "unflatten", "with_array_context", "with_container_arithmetic", - "with_container_arithmetic" ) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index dbfdd5a6..212c7cda 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -2,12 +2,26 @@ from __future__ import annotations -""" +__doc__ = """ .. currentmodule:: arraycontext + .. autofunction:: with_container_arithmetic -""" +.. autoclass:: Bcast +.. autoclass:: BcastNLevels +.. autoclass:: BcastUntilActxArray -import enum +.. function:: Bcast1 + + Like :class:`BcastNLevels` with *nlevels* set to 1. + +.. function:: Bcast2 + + Like :class:`BcastNLevels` with *nlevels* set to 2. + +.. function:: Bcast3 + + Like :class:`BcastNLevels` with *nlevels* set to 3. +""" __copyright__ = """ @@ -34,10 +48,18 @@ THE SOFTWARE. """ -from typing import Any, Callable, Optional, Tuple, TypeVar, Union +import enum +from abc import ABC, abstractmethod +from dataclasses import FrozenInstanceError +from functools import partial +from numbers import Number +from typing import Any, Callable, ClassVar, Optional, Tuple, TypeVar, Union +from warnings import warn import numpy as np +from arraycontext.context import ArrayContext, ArrayOrContainer + # {{{ with_container_arithmetic @@ -99,8 +121,8 @@ def _format_unary_op_str(op_str: str, arg1: Union[Tuple[str, ...], str]) -> str: def _format_binary_op_str(op_str: str, - arg1: Union[Tuple[str, ...], str], - arg2: Union[Tuple[str, ...], str]) -> str: + arg1: Union[Tuple[str, str], str], + arg2: Union[Tuple[str, str], str]) -> str: if isinstance(arg1, tuple) and isinstance(arg2, tuple): import sys if sys.version_info >= (3, 10): @@ -127,6 +149,134 @@ def _format_binary_op_str(op_str: str, return op_str.format(arg1, arg2) +class NumpyObjectArrayMetaclass(type): + def __instancecheck__(cls, instance: Any) -> bool: + return isinstance(instance, np.ndarray) and instance.dtype == object + + +class NumpyObjectArray(metaclass=NumpyObjectArrayMetaclass): + pass + + +class Bcast: + """ + A wrapper object to force arithmetic generated by :func:`with_container_arithmetic` + to broadcast *arg* across a container (with the container as the 'outer' structure). + Since array containers are often nested in complex ways, different subclasses + implement different rules on how broadcasting interacts with the hierarchy, + with :class:`BcastNLevels` and :class:`BcastUntilActxArray` representing + the most common. + """ + arg: ArrayOrContainer + + # Accessing this attribute is cheaper than isinstance, so use that + # to distinguish _BcastWithNextOperand and _BcastWithoutNextOperand. + _with_next_operand: ClassVar[bool] + + def __init__(self, arg: ArrayOrContainer) -> None: + object.__setattr__(self, "arg", arg) + + def __setattr__(self, name: str, value: Any) -> None: + raise FrozenInstanceError() + + def __delattr__(self, name: str) -> None: + raise FrozenInstanceError() + + +class _BcastWithNextOperand(Bcast, ABC): + """ + A :class:`Bcast` object that gets to see who the next operand will be, in + order to decide whether wrapping the child in :class:`Bcast` is still necessary. + This is much more flexible, but also considerably more expensive, than + :class:`_BcastWithoutNextOperand`. + """ + + _with_next_operand = True + + # purposefully undocumented + @abstractmethod + def _rewrap(self, other_operand: ArrayOrContainer) -> ArrayOrContainer: + ... + + +class _BcastWithoutNextOperand(Bcast, ABC): + """ + A :class:`Bcast` object that does not get to see who the next operand will be. + """ + _with_next_operand = False + + # purposefully undocumented + @abstractmethod + def _rewrap(self) -> ArrayOrContainer: + ... + + +class BcastNLevels(_BcastWithoutNextOperand): + """ + A broadcasting rule that lets *arg* broadcast against *nlevels* "levels" of + array containers. Use :func:`Bcast1`, :func:`Bcast2`, :func:`Bcast3` as + convenient aliases for the common cases. + + Usage example:: + + container + Bcast2(actx_array) + + .. note:: + + :mod:`numpy` object arrays do not count against the number of levels. + + .. automethod:: __init__ + """ + nlevels: int + + def __init__(self, nlevels: int, arg: ArrayOrContainer) -> None: + if nlevels < 1: + raise ValueError("nlevels is expected to be one or greater.") + + super().__init__(arg) + object.__setattr__(self, "nlevels", nlevels) + + def _rewrap(self) -> ArrayOrContainer: + if self.nlevels == 1: + return self.arg + else: + return BcastNLevels(self.nlevels-1, self.arg) + + +Bcast1 = partial(BcastNLevels, 1) +Bcast2 = partial(BcastNLevels, 2) +Bcast3 = partial(BcastNLevels, 3) + + +class BcastUntilActxArray(_BcastWithNextOperand): + """ + A broadcast rule that broadcasts *arg* across array containers until + the 'opposite' operand is one of the :attr:`~arraycontext.ArrayContext.array_types` + of *actx*, or a :class:`~numbers.Number`. + + Suggested usage pattern:: + + bcast = functools.partial(BcastUntilActxArray, actx) + + container + bcast(actx_array) + + .. automethod:: __init__ + """ + actx: ArrayContext + + def __init__(self, + actx: ArrayContext, + arg: ArrayOrContainer) -> None: + super().__init__(arg) + object.__setattr__(self, "actx", actx) + + def _rewrap(self, other_operand: ArrayOrContainer) -> ArrayOrContainer: + if isinstance(other_operand, (*self.actx.array_types, Number)): + return self.arg + else: + return BcastUntilActxArray(self.actx, self.arg) + + def with_container_arithmetic( *, bcast_number: bool = True, @@ -146,22 +296,16 @@ def with_container_arithmetic( :arg bcast_number: If *True*, numbers broadcast over the container (with the container as the 'outer' structure). - :arg _bcast_actx_array_type: If *True*, instances of base array types of the - container's array context are broadcasted over the container. Can be - *True* only if the container has *_cls_has_array_context_attr* set. - Defaulted to *bcast_number* if *_cls_has_array_context_attr* is set, - else *False*. - :arg bcast_obj_array: If *True*, :mod:`numpy` object arrays broadcast over - the container. (with the container as the 'inner' structure) - :arg bcast_numpy_array: If *True*, any :class:`numpy.ndarray` will broadcast - over the container. (with the container as the 'inner' structure) - If this is set to *True*, *bcast_obj_array* must also be *True*. + :arg bcast_obj_array: If *True*, this container will be broadcast + across :mod:`numpy` object arrays + (with the object array as the 'outer' structure). + Add :class:`numpy.ndarray` to *bcast_container_types* to achieve + the 'reverse' broadcasting. :arg bcast_container_types: A sequence of container types that will broadcast - over this container (with this container as the 'outer' structure). + across this container, with this container as the 'outer' structure. :class:`numpy.ndarray` is permitted to be part of this sequence to - indicate that, in such broadcasting situations, this container should - be the 'outer' structure. In this case, *bcast_obj_array* - (and consequently *bcast_numpy_array*) must be *False*. + indicate that object arrays (and *only* object arrays) will be broadcasat. + In this case, *bcast_obj_array* must be *False*. :arg arithmetic: Implement the conventional arithmetic operators, including ``**``, :func:`divmod`, and ``//``. Also includes ``+`` and ``-`` as well as :func:`abs`. @@ -181,6 +325,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__` Each operator class also includes the "reverse" operators if applicable. + .. note:: + + For the generated binary arithmetic operators, if certain types + should be broadcast over the container (with the container as the + 'outer' structure) but are not handled in this way by their types, + you may wrap them in one of the :class:`Bcast` variants to achieve + the desired semantics. + .. note:: To generate the code implementing the operators, this function relies on @@ -203,6 +355,35 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): should nest "outside" :func:dataclass_array_container`. """ + # Hard-won design lessons: + # + # - Anything that special-cases np.ndarray by type is broken by design because: + # - np.ndarray is an array context array. + # - numpy object arrays can be array containers. + # Using NumpyObjectArray and NumpyNonObjectArray *may* be better? + # They're new, so there is no operational experience with them. + # + # - Broadcast rules are hard to change once established, particularly + # because one cannot grep for their use. + # + # Possible advantages of the "Bcast" broadcast-rule-as-object design: + # + # - If one rule does not fit the user's need, they can straightforwardly use + # another. + # + # - It's straightforward to find where certain broadcast rules are used. + # + # - The broadcast rule can contain more state. For example, it's now easy + # for the rule to know what array context should be used to determine + # actx array types. + # + # Possible downsides of the "Bcast" broadcast-rule-as-object design: + # + # - User code is a bit more wordy. + # + # - Rewrapping has the potential to be costly, especially in + # _with_next_operand mode. + # {{{ handle inputs if bcast_obj_array is None: @@ -212,9 +393,8 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): raise TypeError("rel_comparison must be specified") if bcast_numpy_array: - from warnings import warn warn("'bcast_numpy_array=True' is deprecated and will be unsupported" - " from December 2021", DeprecationWarning, stacklevel=2) + " from 2025.", DeprecationWarning, stacklevel=2) if _bcast_actx_array_type: raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'" @@ -231,7 +411,7 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): if bcast_numpy_array: def numpy_pred(name: str) -> str: - return f"isinstance({name}, np.ndarray)" + return f"is_numpy_array({name})" elif bcast_obj_array: def numpy_pred(name: str) -> str: return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'" @@ -241,11 +421,14 @@ def numpy_pred(name: str) -> str: if bcast_container_types is None: bcast_container_types = () - bcast_container_types_count = len(bcast_container_types) if np.ndarray in bcast_container_types and bcast_obj_array: raise ValueError("If numpy.ndarray is part of bcast_container_types, " "bcast_obj_array must be False.") + bcast_container_types = tuple( + NumpyObjectArray if ct is np.ndarray else ct + for ct in bcast_container_types + ) desired_op_classes = set() if arithmetic: @@ -264,10 +447,15 @@ def numpy_pred(name: str) -> str: # }}} def wrap(cls: Any) -> Any: - cls_has_array_context_attr: bool | None = \ - _cls_has_array_context_attr - bcast_actx_array_type: bool | None = \ - _bcast_actx_array_type + if not hasattr(cls, "__array_ufunc__"): + warn(f"{cls} does not have __array_ufunc__ set. " + "This will cause numpy to attempt broadcasting, in a way that " + "is likely undesired. " + f"To avoid this, set __array_ufunc__ = None in {cls}.", + stacklevel=2) + + cls_has_array_context_attr: bool | None = _cls_has_array_context_attr + bcast_actx_array_type: bool | None = _bcast_actx_array_type if cls_has_array_context_attr is None: if hasattr(cls, "array_context"): @@ -275,8 +463,8 @@ def wrap(cls: Any) -> Any: f"{cls} has an 'array_context' attribute, but it does not " "set '_cls_has_array_context_attr' to *True* when calling " "with_container_arithmetic. This is being interpreted " - "as 'array_context' being permitted to fail with an exception, " - "which is no longer allowed. " + "as '.array_context' being permitted to fail " + "with an exception, which is no longer allowed. " f"If {cls.__name__}.array_context will not fail, pass " "'_cls_has_array_context_attr=True'. " "If you do not want container arithmetic to make " @@ -294,6 +482,28 @@ def wrap(cls: Any) -> Any: raise TypeError("_bcast_actx_array_type can be True only if " "_cls_has_array_context_attr is set.") + if bcast_actx_array_type: + if _bcast_actx_array_type: + warn( + f"Broadcasting array context array types across {cls} " + "has been explicitly " + "enabled. As of 2025, this will stop working. " + "Express these operations using arraycontext.Bcast variants " + "instead." + "To opt out now (and avoid this warning), " + "pass _bcast_actx_array_type=False. ", + DeprecationWarning, stacklevel=2) + else: + warn( + f"Broadcasting array context array types across {cls} " + "has been implicitly " + "enabled. As of 2025, this will no longer work. " + "Express these operations using arraycontext.Bcast variants " + "instead." + "To opt out now (and avoid this warning), " + "pass _bcast_actx_array_type=False.", + DeprecationWarning, stacklevel=2) + if (not hasattr(cls, "_serialize_init_arrays_code") or not hasattr(cls, "_deserialize_init_arrays_code")): raise TypeError(f"class '{cls.__name__}' must provide serialization " @@ -304,10 +514,10 @@ def wrap(cls: Any) -> Any: from pytools.codegen import CodeGenerator, Indentation gen = CodeGenerator() - gen(""" + gen(f""" from numbers import Number import numpy as np - from arraycontext import ArrayContainer + from arraycontext import ArrayContainer, Bcast from warnings import warn def _raise_if_actx_none(actx): @@ -315,6 +525,25 @@ def _raise_if_actx_none(actx): raise ValueError("array containers with frozen arrays " "cannot be operated upon") return actx + + def is_numpy_array(arg): + if isinstance(arg, np.ndarray): + if arg.dtype != "O": + warn("Operand is a non-object numpy array, " + "and the broadcasting behavior of this array container " + "({cls}) " + "is influenced by this because of its use of " + "the deprecated bcast_numpy_array. This broadcasting " + "behavior will change in 2025. If you would like the " + "broadcasting behavior to stay the same, make sure " + "to convert the passed numpy array to an " + "object array, or use arraycontext.Bcast to achieve " + "the desired broadcasting semantics.", + DeprecationWarning, stacklevel=2) + return True + else: + return False + """) gen("") @@ -323,7 +552,7 @@ def _raise_if_actx_none(actx): gen(f"from {bct.__module__} import {bct.__qualname__} as _bctype{i}") gen("") outer_bcast_type_names = tuple([ - f"_bctype{i}" for i in range(bcast_container_types_count) + f"_bctype{i}" for i in range(len(bcast_container_types)) ]) if bcast_number: outer_bcast_type_names += ("Number",) @@ -384,8 +613,6 @@ def {fname}(arg1): continue - # {{{ "forward" binary operators - zip_init_args = cls._deserialize_init_arrays_code("arg1", { same_key(key_arg1, key_arg2): _format_binary_op_str(op_str, expr_arg1, expr_arg2) @@ -393,11 +620,45 @@ def {fname}(arg1): cls._serialize_init_arrays_code("arg1").items(), cls._serialize_init_arrays_code("arg2").items()) }) - bcast_same_cls_init_args = cls._deserialize_init_arrays_code("arg1", { + bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", { key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2") for key_arg1, expr_arg1 in cls._serialize_init_arrays_code("arg1").items() }) + bcast_init_args_arg2_is_outer = cls._deserialize_init_arrays_code("arg2", { + key_arg2: _format_binary_op_str(op_str, "arg1", expr_arg2) + for key_arg2, expr_arg2 in + cls._serialize_init_arrays_code("arg2").items() + }) + + def get_operand(arg: Union[tuple[str, str], str]) -> str: + if isinstance(arg, tuple): + entry, _container = arg + return entry + else: + return arg + + bcast_init_args_arg1_is_outer_with_rewrap = \ + cls._deserialize_init_arrays_code("arg1", { + key_arg1: + _format_binary_op_str( + op_str, expr_arg1, + f"arg2._rewrap({get_operand(expr_arg1)})") + for key_arg1, expr_arg1 in + cls._serialize_init_arrays_code("arg1").items() + }) + bcast_init_args_arg2_is_outer_with_rewrap = \ + cls._deserialize_init_arrays_code("arg2", { + key_arg2: + _format_binary_op_str( + op_str, + f"arg1._rewrap({get_operand(expr_arg2)})", + expr_arg2) + for key_arg2, expr_arg2 in + cls._serialize_init_arrays_code("arg2").items() + }) + + # {{{ "forward" binary operators gen(f"def {fname}(arg1, arg2):") with Indentation(gen): @@ -424,7 +685,7 @@ def {fname}(arg1): if bcast_actx_array_type: if __debug__: - bcast_actx_ary_types = ( + bcast_actx_ary_types: tuple[str, ...] = ( "*_raise_if_actx_none(" "arg1.array_context).array_types",) else: @@ -444,7 +705,24 @@ def {fname}(arg1): if isinstance(arg2, {tup_str(outer_bcast_type_names + bcast_actx_ary_types)}): - return cls({bcast_same_cls_init_args}) + if __debug__: + if isinstance(arg2, {tup_str(bcast_actx_ary_types)}): + warn("Broadcasting {cls} over array " + f"context array type {{type(arg2)}} is deprecated " + "and will no longer work in 2025. " + "Use arraycontext.Bcast to achieve the desired " + "broadcasting semantics.", + DeprecationWarning, stacklevel=2) + + return cls({bcast_init_args_arg1_is_outer}) + + if isinstance(arg2, Bcast): + if arg2._with_next_operand: + return cls({bcast_init_args_arg1_is_outer_with_rewrap}) + else: + arg2 = arg2._rewrap() + return cls({bcast_init_args_arg1_is_outer}) + return NotImplemented """) gen(f"cls.__{dunder_name}__ = {fname}") @@ -456,12 +734,6 @@ def {fname}(arg1): if reversible: fname = f"_{cls.__name__.lower()}_r{dunder_name}" - bcast_init_args = cls._deserialize_init_arrays_code("arg2", { - key_arg2: _format_binary_op_str( - op_str, "arg1", expr_arg2) - for key_arg2, expr_arg2 in - cls._serialize_init_arrays_code("arg2").items() - }) if bcast_actx_array_type: if __debug__: @@ -487,7 +759,26 @@ def {fname}(arg2, arg1): if isinstance(arg1, {tup_str(outer_bcast_type_names + bcast_actx_ary_types)}): - return cls({bcast_init_args}) + if __debug__: + if isinstance(arg1, + {tup_str(bcast_actx_ary_types)}): + warn("Broadcasting {cls} over array " + f"context array type {{type(arg1)}} " + "is deprecated " + "and will no longer work in 2025.", + "Use arraycontext.Bcast to achieve the " + "desired broadcasting semantics.", + DeprecationWarning, stacklevel=2) + + return cls({bcast_init_args_arg2_is_outer}) + + if isinstance(arg1, Bcast): + if arg1._with_next_operand: + return cls({bcast_init_args_arg2_is_outer_with_rewrap}) + else: + arg1 = arg1._rewrap() + return cls({bcast_init_args_arg2_is_outer}) + return NotImplemented cls.__r{dunder_name}__ = {fname}""") diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 31b3bcf5..100f0775 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -43,6 +43,8 @@ from __future__ import annotations +from arraycontext.container.arithmetic import NumpyObjectArray + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees @@ -964,7 +966,7 @@ def treat_as_scalar(x: Any) -> bool: return ( not isinstance(x, np.ndarray) # This condition is whether "ndarrays should broadcast inside x". - and np.ndarray not in x.__class__._outer_bcast_types) + and NumpyObjectArray not in x.__class__._outer_bcast_types) if treat_as_scalar(a) or treat_as_scalar(b): return a*b diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 94d7d748..e06b158d 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -22,6 +22,7 @@ import logging from dataclasses import dataclass +from functools import partial from typing import Union import numpy as np @@ -33,7 +34,11 @@ from arraycontext import ( ArrayContainer, ArrayContext, + Bcast1, + Bcast2, + BcastUntilActxArray, EagerJAXArrayContext, + NumpyArrayContext, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container, @@ -116,10 +121,10 @@ def _acf(): @with_container_arithmetic( bcast_obj_array=True, - bcast_numpy_array=True, bitwise=True, rel_comparison=True, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) class DOFArray: def __init__(self, actx, data): if not (actx is None or isinstance(actx, ArrayContext)): @@ -207,7 +212,8 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: @with_container_arithmetic(bcast_obj_array=False, eq_comparison=False, rel_comparison=False, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) @dataclass_array_container @dataclass(frozen=True) class MyContainer: @@ -229,7 +235,8 @@ def array_context(self): bcast_container_types=(DOFArray, np.ndarray), matmul=True, rel_comparison=True, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) @dataclass_array_container @dataclass(frozen=True) class MyContainerDOFBcast: @@ -936,8 +943,6 @@ def test_container_arithmetic(actx_factory): def _check_allclose(f, arg1, arg2, atol=5.0e-14): assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol - from functools import partial - from arraycontext import rec_multimap_array_container for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: rec_multimap_array_container( @@ -1354,9 +1359,9 @@ def test_container_equality(actx_factory): @with_container_arithmetic( bcast_obj_array=True, - bcast_numpy_array=True, rel_comparison=True, - _cls_has_array_context_attr=True) + _cls_has_array_context_attr=True, + _bcast_actx_array_type=False) @dataclass_array_container @dataclass(frozen=True) class Foo: @@ -1373,10 +1378,27 @@ def test_leaf_array_type_broadcasting(actx_factory): # test support for https://github.com/inducer/arraycontext/issues/49 actx = actx_factory() - foo = Foo(DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, ))) + dof_ary = DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, )) + foo = Foo(dof_ary) bar = foo + 4 - baz = foo + actx.from_numpy(4*np.ones((3, ))) - qux = actx.from_numpy(4*np.ones((3, ))) + foo + + bcast = partial(BcastUntilActxArray, actx) + + actx_ary = actx.from_numpy(4*np.ones((3, ))) + with pytest.raises(TypeError): + foo + actx_ary + + baz = foo + Bcast2(actx_ary) + qux = Bcast2(actx_ary) + foo + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(baz.u[0])) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(qux.u[0])) + + baz = foo + bcast(actx_ary) + qux = bcast(actx_ary) + foo np.testing.assert_allclose(actx.to_numpy(bar.u[0]), actx.to_numpy(baz.u[0])) @@ -1384,6 +1406,29 @@ def test_leaf_array_type_broadcasting(actx_factory): np.testing.assert_allclose(actx.to_numpy(bar.u[0]), actx.to_numpy(qux.u[0])) + mc = MyContainer( + name="hi", + mass=dof_ary, + momentum=make_obj_array([dof_ary, dof_ary]), + enthalpy=dof_ary) + + with pytest.raises(TypeError): + mc_op = mc + actx_ary + + mom_op = mc.momentum + Bcast1(actx_ary) + np.testing.assert_allclose(45, actx.to_numpy(mom_op[0][0])) + + mc_op = mc + Bcast2(actx_ary) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.mass[0])) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.momentum[1][0])) + + mom_op = mc.momentum + bcast(actx_ary) + np.testing.assert_allclose(45, actx.to_numpy(mom_op[0][0])) + + mc_op = mc + bcast(actx_ary) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.mass[0])) + np.testing.assert_allclose(45, actx.to_numpy(mc_op.momentum[1][0])) + def _actx_allows_scalar_broadcast(actx): if not isinstance(actx, PyOpenCLArrayContext): return True @@ -1394,8 +1439,11 @@ def _actx_allows_scalar_broadcast(actx): return cl.version.VERSION > (2021, 2, 5) if _actx_allows_scalar_broadcast(actx): - quux = foo + actx.from_numpy(np.array(4)) - quuz = actx.from_numpy(np.array(4)) + foo + with pytest.raises(TypeError): + foo + actx.from_numpy(np.array(4)) + + quuz = Bcast2(actx.from_numpy(np.array(4))) + foo + quux = foo + Bcast2(actx.from_numpy(np.array(4))) np.testing.assert_allclose(actx.to_numpy(bar.u[0]), actx.to_numpy(quux.u[0])) @@ -1403,6 +1451,14 @@ def _actx_allows_scalar_broadcast(actx): np.testing.assert_allclose(actx.to_numpy(bar.u[0]), actx.to_numpy(quuz.u[0])) + quuz = bcast(actx.from_numpy(np.array(4))) + foo + quux = foo + bcast(actx.from_numpy(np.array(4))) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(quux.u[0])) + + np.testing.assert_allclose(actx.to_numpy(bar.u[0]), + actx.to_numpy(quuz.u[0])) # }}}