Skip to content

Commit

Permalink
forbid force_device_scalars=False
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm authored and inducer committed Sep 4, 2024
1 parent 66d4663 commit 867273a
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 120 deletions.
2 changes: 0 additions & 2 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@
PytestArrayContextFactory,
PytestPyOpenCLArrayContextFactory,
pytest_generate_tests_for_array_contexts,
pytest_generate_tests_for_pyopencl_array_context,
)
from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag

Expand Down Expand Up @@ -139,7 +138,6 @@
"multimapped_over_array_containers",
"outer",
"pytest_generate_tests_for_array_contexts",
"pytest_generate_tests_for_pyopencl_array_context",
"rec_map_array_container",
"rec_map_reduce_array_container",
"rec_multimap_array_container",
Expand Down
28 changes: 10 additions & 18 deletions arraycontext/impl/pyopencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(self,
queue: pyopencl.CommandQueue,
allocator: Optional[pyopencl.tools.AllocatorBase] = None,
wait_event_queue_length: Optional[int] = None,
force_device_scalars: bool = False) -> None:
force_device_scalars: Optional[bool] = None) -> None:
r"""
:arg wait_event_queue_length: The length of a queue of
:class:`~pyopencl.Event` objects that are maintained by the
Expand All @@ -105,21 +105,15 @@ def __init__(self,
For now, *wait_event_queue_length* should be regarded as an
experimental feature that may change or disappear at any minute.
:arg force_device_scalars: if *True*, scalar results returned from
reductions in :attr:`ArrayContext.np` will be kept on the device.
If *False*, the equivalent of :meth:`~ArrayContext.freeze` and
:meth:`~ArrayContext.to_numpy` is applied to transfer the results
to the host.
"""
if not force_device_scalars:
warn("Configuring the PyOpenCLArrayContext to return host scalars "
"from reductions is deprecated. "
"To configure the PyOpenCLArrayContext to return "
"device scalars, pass 'force_device_scalars=True' to the "
"constructor. "
"Support for returning host scalars will be removed in 2022.",
DeprecationWarning, stacklevel=2)
if force_device_scalars is not None:
warn(
"`force_device_scalars` is deprecated and will be removed in 2025.",
DeprecationWarning, stacklevel=2)

if not force_device_scalars:
raise ValueError(
"Passing force_device_scalars=False is not allowed.")

import pyopencl as cl
import pyopencl.array as cl_array
Expand All @@ -131,7 +125,6 @@ def __init__(self,
if wait_event_queue_length is None:
wait_event_queue_length = 10

self._force_device_scalars = force_device_scalars
self._wait_event_queue_length = wait_event_queue_length
self._kernel_name_to_wait_event_queue: Dict[str, List[cl.Event]] = {}

Expand Down Expand Up @@ -268,8 +261,7 @@ def call_loopy(self, t_unit, **kwargs):

def clone(self):
return type(self)(self.queue, self.allocator,
wait_event_queue_length=self._wait_event_queue_length,
force_device_scalars=self._force_device_scalars)
wait_event_queue_length=self._wait_event_queue_length)

# }}}

Expand Down
42 changes: 7 additions & 35 deletions arraycontext/impl/pyopencl/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,11 @@ def stack(self, arrays, axis=0):
# {{{ linear algebra

def vdot(self, x, y, dtype=None):
result = rec_multimap_reduce_array_container(
return rec_multimap_reduce_array_container(
sum,
partial(cl_array.vdot, dtype=dtype, queue=self._array_context.queue),
x, y)

if not self._array_context._force_device_scalars:
result = result.get()[()]
return result

# }}}

# {{{ logic functions
Expand All @@ -190,15 +186,11 @@ def _all(ary):
return np.int8(all([ary]))
return ary.all(queue=queue)

result = rec_map_reduce_array_container(
return rec_map_reduce_array_container(
partial(reduce, partial(cl_array.minimum, queue=queue)),
_all,
a)

if not self._array_context._force_device_scalars:
result = result.get()[()]
return result

def any(self, a):
queue = self._array_context.queue

Expand All @@ -207,15 +199,11 @@ def _any(ary):
return np.int8(any([ary]))
return ary.any(queue=queue)

result = rec_map_reduce_array_container(
return rec_map_reduce_array_container(
partial(reduce, partial(cl_array.maximum, queue=queue)),
_any,
a)

if not self._array_context._force_device_scalars:
result = result.get()[()]
return result

def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
actx = self._array_context
queue = actx.queue
Expand Down Expand Up @@ -251,11 +239,7 @@ def rec_equal(x: ArrayOrContainer, y: ArrayOrContainer) -> cl_array.Array:
in zip(serialized_x, serialized_y)],
true_ary)

result = rec_equal(a, b)
if not self._array_context._force_device_scalars:
result = result.get()[()]

return result
return rec_equal(a, b)

# FIXME: This should be documentation, not a comment.
# These are here mainly because some arrays may choose to interpret
Expand Down Expand Up @@ -305,11 +289,7 @@ def _rec_sum(ary):

return cl_array.sum(ary, dtype=dtype, queue=self._array_context.queue)

result = rec_map_reduce_array_container(sum, _rec_sum, a)

if not self._array_context._force_device_scalars:
result = result.get()[()]
return result
return rec_map_reduce_array_container(sum, _rec_sum, a)

def maximum(self, x, y):
return rec_multimap_array_container(
Expand All @@ -327,15 +307,11 @@ def _rec_max(ary):
raise NotImplementedError(f"Max. over '{axis}' axes not supported.")
return cl_array.max(ary, queue=queue)

result = rec_map_reduce_array_container(
return rec_map_reduce_array_container(
partial(reduce, partial(cl_array.maximum, queue=queue)),
_rec_max,
a)

if not self._array_context._force_device_scalars:
result = result.get()[()]
return result

max = amax

def minimum(self, x, y):
Expand All @@ -354,15 +330,11 @@ def _rec_min(ary):
raise NotImplementedError(f"Min. over '{axis}' axes not supported.")
return cl_array.min(ary, queue=queue)

result = rec_map_reduce_array_container(
return rec_map_reduce_array_container(
partial(reduce, partial(cl_array.minimum, queue=queue)),
_rec_min,
a)

if not self._array_context._force_device_scalars:
result = result.get()[()]
return result

min = amin

def absolute(self, a):
Expand Down
72 changes: 13 additions & 59 deletions arraycontext/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
.. autoclass:: PytestPyOpenCLArrayContextFactory
.. autofunction:: pytest_generate_tests_for_array_contexts
.. autofunction:: pytest_generate_tests_for_pyopencl_array_context
"""

__copyright__ = """
Expand Down Expand Up @@ -88,7 +87,16 @@ def get_command_queue(self):


class _PytestPyOpenCLArrayContextFactoryWithClass(PytestPyOpenCLArrayContextFactory):
force_device_scalars = True
# Deprecated, remove in 2025.
_force_device_scalars = True

@property
def force_device_scalars(self):
from warnings import warn
warn(
"force_device_scalars is deprecated and will be removed in 2025.",
DeprecationWarning, stacklevel=2)
return self._force_device_scalars

@property
def actx_class(self):
Expand Down Expand Up @@ -117,20 +125,14 @@ def __call__(self):

return self.actx_class(
queue,
allocator=alloc,
force_device_scalars=self.force_device_scalars)
allocator=alloc)

def __str__(self):
return (f"<{self.actx_class.__name__} "
f"for <pyopencl.Device '{self.device.name.strip()}' "
f"on '{self.device.platform.name.strip()}'>>")


class _PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars(
_PytestPyOpenCLArrayContextFactoryWithClass):
force_device_scalars = False


class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory):
@classmethod
def is_available(cls) -> bool:
Expand Down Expand Up @@ -245,8 +247,6 @@ def __str__(self):
_ARRAY_CONTEXT_FACTORY_REGISTRY: \
Dict[str, Type[PytestArrayContextFactory]] = {
"pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass,
"pyopencl-deprecated":
_PytestPyOpenCLArrayContextFactoryWithClassAndHostScalars,
"pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory,
"pytato:jax": _PytestPytatoJaxArrayContextFactory,
"eagerjax": _PytestEagerJaxArrayContextFactory,
Expand Down Expand Up @@ -285,22 +285,15 @@ def pytest_generate_tests_for_array_contexts(
"pyopencl",
])
to use the :mod:`pyopencl`-based array context. For :mod:`pyopencl`-based
contexts :func:`pyopencl.tools.pytest_generate_tests_for_pyopencl` is used
as a backend, which allows specifying the ``PYOPENCL_TEST`` environment
variable for device selection.
to use the :mod:`pyopencl`-based array context.
The environment variable ``ARRAYCONTEXT_TEST`` can also be used to
overwrite any chosen implementations through *factories*. This is a
comma-separated list of known array contexts.
Current supported implementations include:
* ``"pyopencl"``, which creates a :class:`~arraycontext.PyOpenCLArrayContext`
with ``force_device_scalars=True``.
* ``"pyopencl-deprecated"``, which creates a
:class:`~arraycontext.PyOpenCLArrayContext` with
``force_device_scalars=False``.
* ``"pyopencl"``, which creates a :class:`~arraycontext.PyOpenCLArrayContext`.
* ``"pytato-pyopencl"``, which creates a
:class:`~arraycontext.PytatoPyOpenCLArrayContext`.
Expand Down Expand Up @@ -404,45 +397,6 @@ def inner(metafunc):

return inner


def pytest_generate_tests_for_pyopencl_array_context(metafunc) -> None:
"""Parametrize tests for pytest to use a
:class:`~arraycontext.PyOpenCLArrayContext`.
Performs device enumeration analogously to
:func:`pyopencl.tools.pytest_generate_tests_for_pyopencl`.
Using the line:
.. code-block:: python
from arraycontext import (
pytest_generate_tests_for_pyopencl_array_context
as pytest_generate_tests)
in your pytest test scripts allows you to use the argument ``actx_factory``,
in your test functions, and they will automatically be
run once for each OpenCL device/platform in the system, as appropriate,
with an argument-less function that returns an
:class:`~arraycontext.ArrayContext` when called.
It also allows you to specify the ``PYOPENCL_TEST`` environment variable
for device selection.
"""

from warnings import warn
warn("pytest_generate_tests_for_pyopencl_array_context is deprecated. "
"Use 'pytest_generate_tests = "
"arraycontext.pytest_generate_tests_for_array_contexts"
"([\"pyopencl-deprecated\"])' instead. "
"pytest_generate_tests_for_pyopencl_array_context will stop working "
"in 2022.",
DeprecationWarning, stacklevel=2)

pytest_generate_tests_for_array_contexts([
"pyopencl-deprecated",
], factory_arg_name="actx_factory")(metafunc)

# }}}


Expand Down
7 changes: 1 addition & 6 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,12 +665,7 @@ def test_reductions_same_as_numpy(actx_factory, op):
actx_red = getattr(actx.np, op)(actx.from_numpy(ary))
actx_red = actx.to_numpy(actx_red)

from numbers import Number

if isinstance(actx, PyOpenCLArrayContext) and (not actx._force_device_scalars):
assert isinstance(actx_red, Number)
else:
assert actx_red.shape == ()
assert actx_red.shape == ()

assert np.allclose(np_red, actx_red)

Expand Down

0 comments on commit 867273a

Please sign in to comment.