diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 4e0ba830..c40117e8 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -90,7 +90,6 @@ PytestArrayContextFactory, PytestPyOpenCLArrayContextFactory, pytest_generate_tests_for_array_contexts, - pytest_generate_tests_for_pyopencl_array_context, ) from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag @@ -139,7 +138,6 @@ "multimapped_over_array_containers", "outer", "pytest_generate_tests_for_array_contexts", - "pytest_generate_tests_for_pyopencl_array_context", "rec_map_array_container", "rec_map_reduce_array_container", "rec_multimap_array_container", diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index de188cbd..6c326374 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -84,7 +84,7 @@ def __init__(self, queue: pyopencl.CommandQueue, allocator: Optional[pyopencl.tools.AllocatorBase] = None, wait_event_queue_length: Optional[int] = None, - force_device_scalars: bool = False) -> None: + force_device_scalars: Optional[bool] = None) -> None: r""" :arg wait_event_queue_length: The length of a queue of :class:`~pyopencl.Event` objects that are maintained by the @@ -105,21 +105,15 @@ def __init__(self, For now, *wait_event_queue_length* should be regarded as an experimental feature that may change or disappear at any minute. - - :arg force_device_scalars: if *True*, scalar results returned from - reductions in :attr:`ArrayContext.np` will be kept on the device. - If *False*, the equivalent of :meth:`~ArrayContext.freeze` and - :meth:`~ArrayContext.to_numpy` is applied to transfer the results - to the host. """ - if not force_device_scalars: - warn("Configuring the PyOpenCLArrayContext to return host scalars " - "from reductions is deprecated. " - "To configure the PyOpenCLArrayContext to return " - "device scalars, pass 'force_device_scalars=True' to the " - "constructor. " - "Support for returning host scalars will be removed in 2022.", - DeprecationWarning, stacklevel=2) + if force_device_scalars is not None: + warn( + "`force_device_scalars` is deprecated and will be removed in 2025.", + DeprecationWarning, stacklevel=2) + + if not force_device_scalars: + raise ValueError( + "Passing force_device_scalars=False is not allowed.") import pyopencl as cl import pyopencl.array as cl_array @@ -131,7 +125,6 @@ def __init__(self, if wait_event_queue_length is None: wait_event_queue_length = 10 - self._force_device_scalars = force_device_scalars self._wait_event_queue_length = wait_event_queue_length self._kernel_name_to_wait_event_queue: Dict[str, List[cl.Event]] = {} @@ -268,8 +261,7 @@ def call_loopy(self, t_unit, **kwargs): def clone(self): return type(self)(self.queue, self.allocator, - wait_event_queue_length=self._wait_event_queue_length, - force_device_scalars=self._force_device_scalars) + wait_event_queue_length=self._wait_event_queue_length) # }}} diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 848870a9..ac792452 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -169,15 +169,11 @@ def stack(self, arrays, axis=0): # {{{ linear algebra def vdot(self, x, y, dtype=None): - result = rec_multimap_reduce_array_container( + return rec_multimap_reduce_array_container( sum, partial(cl_array.vdot, dtype=dtype, queue=self._array_context.queue), x, y) - if not self._array_context._force_device_scalars: - result = result.get()[()] - return result - # }}} # {{{ logic functions @@ -190,15 +186,11 @@ def _all(ary): return np.int8(all([ary])) return ary.all(queue=queue) - result = rec_map_reduce_array_container( + return rec_map_reduce_array_container( partial(reduce, partial(cl_array.minimum, queue=queue)), _all, a) - if not self._array_context._force_device_scalars: - result = result.get()[()] - return result - def any(self, a): queue = self._array_context.queue @@ -207,15 +199,11 @@ def _any(ary): return np.int8(any([ary])) return ary.any(queue=queue) - result = rec_map_reduce_array_container( + return rec_map_reduce_array_container( partial(reduce, partial(cl_array.maximum, queue=queue)), _any, a) - if not self._array_context._force_device_scalars: - result = result.get()[()] - return result - def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array: actx = self._array_context queue = actx.queue @@ -251,11 +239,7 @@ def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> cl_array.Array: in zip(serialized_x, serialized_y)], true_ary) - result = rec_equal(a, b) - if not self._array_context._force_device_scalars: - result = result.get()[()] - - return result + return rec_equal(a, b) # FIXME: This should be documentation, not a comment. # These are here mainly because some arrays may choose to interpret @@ -305,11 +289,7 @@ def _rec_sum(ary): return cl_array.sum(ary, dtype=dtype, queue=self._array_context.queue) - result = rec_map_reduce_array_container(sum, _rec_sum, a) - - if not self._array_context._force_device_scalars: - result = result.get()[()] - return result + return rec_map_reduce_array_container(sum, _rec_sum, a) def maximum(self, x, y): return rec_multimap_array_container( @@ -327,15 +307,11 @@ def _rec_max(ary): raise NotImplementedError(f"Max. over '{axis}' axes not supported.") return cl_array.max(ary, queue=queue) - result = rec_map_reduce_array_container( + return rec_map_reduce_array_container( partial(reduce, partial(cl_array.maximum, queue=queue)), _rec_max, a) - if not self._array_context._force_device_scalars: - result = result.get()[()] - return result - max = amax def minimum(self, x, y): @@ -354,15 +330,11 @@ def _rec_min(ary): raise NotImplementedError(f"Min. over '{axis}' axes not supported.") return cl_array.min(ary, queue=queue) - result = rec_map_reduce_array_container( + return rec_map_reduce_array_container( partial(reduce, partial(cl_array.minimum, queue=queue)), _rec_min, a) - if not self._array_context._force_device_scalars: - result = result.get()[()] - return result - min = amin def absolute(self, a): diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 088c7e3e..c778154d 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -5,7 +5,6 @@ .. autoclass:: PytestPyOpenCLArrayContextFactory .. autofunction:: pytest_generate_tests_for_array_contexts -.. autofunction:: pytest_generate_tests_for_pyopencl_array_context """ __copyright__ = """ @@ -88,7 +87,16 @@ def get_command_queue(self): class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFactory): - force_device_scalars = True + # Deprecated, remove in 2025. + _force_device_scalars = True + + @property + def force_device_scalars(self): + from warnings import warn + warn( + "force_device_scalars is deprecated and will be removed in 2025.", + DeprecationWarning, stacklevel=2) + return self._force_device_scalars @property def actx_class(self): @@ -117,8 +125,7 @@ def __call__(self): return self.actx_class( queue, - allocator=alloc, - force_device_scalars=self.force_device_scalars) + allocator=alloc) def __str__(self): return (f"<{self.actx_class.__name__} " @@ -126,11 +133,6 @@ def __str__(self): f"on '{self.device.platform.name.strip()}'>>") -class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars( - _PytestPyOpenCLArrayContextFactoryWithClass): - force_device_scalars = False - - class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory): @classmethod def is_available(cls) -> bool: @@ -245,8 +247,6 @@ def __str__(self): _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestArrayContextFactory]] = { "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, - "pyopencl-deprecated": - _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars, "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, @@ -285,10 +285,7 @@ def pytest_generate_tests_for_array_contexts( "pyopencl", ]) - to use the :mod:`pyopencl`-based array context. For :mod:`pyopencl`-based - contexts :func:`pyopencl.tools.pytest_generate_tests_for_pyopencl` is used - as a backend, which allows specifying the ``PYOPENCL_TEST`` environment - variable for device selection. + to use the :mod:`pyopencl`-based array context. The environment variable ``ARRAYCONTEXT_TEST`` can also be used to overwrite any chosen implementations through *factories*. This is a @@ -296,11 +293,7 @@ def pytest_generate_tests_for_array_contexts( Current supported implementations include: - * ``"pyopencl"``, which creates a :class:`~arraycontext.PyOpenCLArrayContext` - with ``force_device_scalars=True``. - * ``"pyopencl-deprecated"``, which creates a - :class:`~arraycontext.PyOpenCLArrayContext` with - ``force_device_scalars=False``. + * ``"pyopencl"``, which creates a :class:`~arraycontext.PyOpenCLArrayContext`. * ``"pytato-pyopencl"``, which creates a :class:`~arraycontext.PytatoPyOpenCLArrayContext`. @@ -404,45 +397,6 @@ def inner(metafunc): return inner - -def pytest_generate_tests_for_pyopencl_array_context(metafunc) -> None: - """Parametrize tests for pytest to use a - :class:`~arraycontext.PyOpenCLArrayContext`. - - Performs device enumeration analogously to - :func:`pyopencl.tools.pytest_generate_tests_for_pyopencl`. - - Using the line: - - .. code-block:: python - - from arraycontext import ( - pytest_generate_tests_for_pyopencl_array_context - as pytest_generate_tests) - - in your pytest test scripts allows you to use the argument ``actx_factory``, - in your test functions, and they will automatically be - run once for each OpenCL device/platform in the system, as appropriate, - with an argument-less function that returns an - :class:`~arraycontext.ArrayContext` when called. - - It also allows you to specify the ``PYOPENCL_TEST`` environment variable - for device selection. - """ - - from warnings import warn - warn("pytest_generate_tests_for_pyopencl_array_context is deprecated. " - "Use 'pytest_generate_tests = " - "arraycontext.pytest_generate_tests_for_array_contexts" - "([\"pyopencl-deprecated\"])' instead. " - "pytest_generate_tests_for_pyopencl_array_context will stop working " - "in 2022.", - DeprecationWarning, stacklevel=2) - - pytest_generate_tests_for_array_contexts([ - "pyopencl-deprecated", - ], factory_arg_name="actx_factory")(metafunc) - # }}} diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 47d83903..7bea0dc4 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -665,12 +665,7 @@ def test_reductions_same_as_numpy(actx_factory, op): actx_red = getattr(actx.np, op)(actx.from_numpy(ary)) actx_red = actx.to_numpy(actx_red) - from numbers import Number - - if isinstance(actx, PyOpenCLArrayContext) and (not actx._force_device_scalars): - assert isinstance(actx_red, Number) - else: - assert actx_red.shape == () + assert actx_red.shape == () assert np.allclose(np_red, actx_red)