Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Bcast object-ified broacasting rules #280

Draft
wants to merge 29 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
14d4476
Fix some newly-flagged UP031 issues
inducer Aug 25, 2024
ee96ff5
Drop deprecated actx.{empty,zeros}{,_like}
inducer Aug 5, 2024
0c24aad
Fix a return type in ArgSizeLimitingPytatoLoopyPyOpenCLTarget
inducer Aug 6, 2024
02ab097
Separate doc page for actx abstraction from doc page for implementations
inducer Aug 5, 2024
c888489
Give up on precisely typing Array.__getitem__
inducer Jul 31, 2024
cd124ba
Fix doc upload script to properly sync deletions
inducer Jul 31, 2024
5d8158d
Deprecate with_container_arithmetic's bcast_numpy_array arg
kaushikcfd Sep 27, 2021
228ef16
Implements NumpyArrayContext
kaushikcfd Sep 26, 2021
1dc8c94
ArrayContainer fixes for numpy arrays as leaf classes
kaushikcfd Sep 26, 2021
51b46bd
arithmetic fixes to account for np.ndarray being a leaf array
kaushikcfd Sep 27, 2021
6308dc1
test NumpyArrayContext
kaushikcfd Sep 26, 2021
b5ea270
test tweaks for NumpyArrayContext
kaushikcfd Sep 27, 2021
80c0672
Numpy actx: add arange, linspace
matthiasdiener May 24, 2024
6d3b02a
Numpy actx: add zeros_like, reshape
matthiasdiener Jun 20, 2023
4125e02
Numpy actx: better freeze/thaw
matthiasdiener Jun 20, 2023
5da96a8
Numpy actx: Narrow array_types to non-obj arrays
inducer Jul 31, 2024
aa53572
Numpy actx: improve type annotations
inducer Jul 31, 2024
cf3f4fb
Array container arithemtic: drop deprecated fail-safe actx retrieval
inducer Jul 12, 2024
1af76ce
Skip tagging test for numpy actx
inducer Jul 31, 2024
b58e38e
Skip numpy conversion tests when using the numpy actx
inducer Aug 1, 2024
eca314f
Don't expect unflatten failure from numpy array for numpy actx
inducer Jul 31, 2024
4b4ee86
Container serialization: iterable -> sequence, plus type aliases
inducer Jul 31, 2024
3d36c07
Improve, type, fix array_equal across all array contexts
inducer Jul 31, 2024
58acd1f
Clarify that actx.array_types allows ABCs
inducer Jul 31, 2024
955820f
Rework dataclass array container arithmetic
inducer Jul 31, 2024
2bcc921
Switch to __array_ufunc__ in tests as a way to avoid numpy broadcasting
inducer Aug 6, 2024
66a202d
outer: disallow non-object numpy arrays
inducer Aug 6, 2024
6251b61
Fix ruff C409 failures
inducer Aug 11, 2024
e86a0e7
Introduce Bcast object-ified broacasting rules
inducer Aug 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
ArrayContainer,
ArrayContainerT,
NotAnArrayContainerError,
SerializationKey,
SerializedContainer,
deserialize_container,
get_container_context_opt,
get_container_context_recursively,
Expand All @@ -41,7 +43,15 @@
register_multivector_as_array_container,
serialize_container,
)
from .container.arithmetic import with_container_arithmetic
from .container.arithmetic import (
Bcast,
Bcast1Level,
Bcast2Levels,
Bcast3Levels,
BcastNLevels,
BcastUntilActxArray,
with_container_arithmetic,
)
from .container.dataclass import dataclass_array_container
from .container.traversal import (
flat_size_and_dtype,
Expand Down Expand Up @@ -78,6 +88,7 @@
tag_axes,
)
from .impl.jax import EagerJAXArrayContext
from .impl.numpy import NumpyArrayContext
from .impl.pyopencl import PyOpenCLArrayContext
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
from .loopy import make_loopy_program
Expand All @@ -91,7 +102,6 @@


__all__ = (
"Array",
"Array",
"ArrayContainer",
"ArrayContainerT",
Expand All @@ -101,18 +111,26 @@
"ArrayOrContainerOrScalarT",
"ArrayOrContainerT",
"ArrayT",
"Bcast",
"Bcast1Level",
"Bcast2Levels",
"Bcast3Levels",
"BcastNLevels",
"BcastUntilActxArray",
"CommonSubexpressionTag",
"EagerJAXArrayContext",
"ElementwiseMapKernelTag",
"NotAnArrayContainerError",
"NumpyArrayContext",
"PyOpenCLArrayContext",
"PytatoJAXArrayContext",
"PytatoPyOpenCLArrayContext",
"PytestArrayContextFactory",
"PytestPyOpenCLArrayContextFactory",
"Scalar",
"Scalar",
"ScalarLike",
"SerializationKey",
"SerializedContainer",
"dataclass_array_container",
"deserialize_container",
"flat_size_and_dtype",
Expand Down Expand Up @@ -146,7 +164,7 @@
"to_numpy",
"unflatten",
"with_array_context",
"with_container_arithmetic"
"with_container_arithmetic",
)


Expand Down
56 changes: 43 additions & 13 deletions arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

Serialization/deserialization
-----------------------------

.. autoclass:: SerializationKey
.. autoclass:: SerializedContainer
.. autofunction:: is_array_container_type
.. autofunction:: serialize_container
.. autofunction:: deserialize_container
Expand Down Expand Up @@ -39,6 +42,14 @@
.. class:: ArrayOrContainerT

:canonical: arraycontext.ArrayOrContainerT

.. class:: SerializationKey

:canonical: arraycontext.SerializationKey

.. class:: SerializedContainer

:canonical: arraycontext.SerializedContainer
"""

from __future__ import annotations
Expand Down Expand Up @@ -69,12 +80,23 @@
"""

from functools import singledispatch
from typing import TYPE_CHECKING, Any, Iterable, Optional, Protocol, Tuple, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Hashable,
Iterable,
Optional,
Protocol,
Sequence,
Tuple,
TypeVar,
)

# For use in singledispatch type annotations, because sphinx can't figure out
# what 'np' is.
import numpy
import numpy as np
from typing_extensions import TypeAlias

from arraycontext.context import ArrayContext

Expand Down Expand Up @@ -142,23 +164,27 @@ class NotAnArrayContainerError(TypeError):
""":class:`TypeError` subclass raised when an array container is expected."""


SerializationKey: TypeAlias = Hashable
SerializedContainer: TypeAlias = Sequence[Tuple[SerializationKey, "ArrayOrContainer"]]


@singledispatch
def serialize_container(
ary: ArrayContainer) -> Iterable[Tuple[Any, ArrayOrContainer]]:
r"""Serialize the array container into an iterable over its components.
ary: ArrayContainer) -> SerializedContainer:
r"""Serialize the array container into a sequence over its components.

The order of the components and their identifiers are entirely under
the control of the container class. However, the order is required to be
deterministic, i.e. two calls to :func:`serialize_container` on
array containers of the same types with the same number of
sub-arrays must result in an iterable with the keys in the same
sub-arrays must result in a sequence with the keys in the same
order.

If *ary* is mutable, the serialization function is not required to ensure
that the serialization result reflects the array state at the time of the
call to :func:`serialize_container`.

:returns: an :class:`Iterable` of 2-tuples where the first
:returns: a :class:`Sequence` of 2-tuples where the first
entry is an identifier for the component and the second entry
is an array-like component of the :class:`ArrayContainer`.
Components can themselves be :class:`ArrayContainer`\ s, allowing
Expand All @@ -172,13 +198,13 @@ def serialize_container(
@singledispatch
def deserialize_container(
template: ArrayContainerT,
iterable: Iterable[Tuple[Any, Any]]) -> ArrayContainerT:
"""Deserialize an iterable into an array container.
serialized: SerializedContainer) -> ArrayContainerT:
"""Deserialize a sequence into an array container following a *template*.

:param template: an instance of an existing object that
can be used to aid in the deserialization. For a similar choice
see :attr:`~numpy.class.__array_finalize__`.
:param iterable: an iterable that mirrors the output of
:param serialized: a sequence that mirrors the output of
:meth:`serialize_container`.
"""
raise NotAnArrayContainerError(
Expand Down Expand Up @@ -218,7 +244,11 @@ def is_array_container(ary: Any) -> bool:
"cheaper option, see is_array_container_type.",
DeprecationWarning, stacklevel=2)
return (serialize_container.dispatch(ary.__class__)
is not serialize_container.__wrapped__) # type:ignore[attr-defined]
is not serialize_container.__wrapped__ # type:ignore[attr-defined]
# numpy values with scalar elements aren't array containers
and not (isinstance(ary, np.ndarray)
and ary.dtype.kind != "O")
)


@singledispatch
Expand All @@ -238,7 +268,7 @@ def get_container_context_opt(ary: ArrayContainer) -> Optional[ArrayContext]:

@serialize_container.register(np.ndarray)
def _serialize_ndarray_container(
ary: numpy.ndarray) -> Iterable[Tuple[Any, ArrayOrContainer]]:
ary: numpy.ndarray) -> SerializedContainer:
if ary.dtype.char != "O":
raise NotAnArrayContainerError(
f"cannot serialize '{type(ary).__name__}' with dtype '{ary.dtype}'")
Expand All @@ -252,20 +282,20 @@ def _serialize_ndarray_container(
for j in range(ary.shape[1])
]
else:
return np.ndenumerate(ary)
return list(np.ndenumerate(ary))


@deserialize_container.register(np.ndarray)
# https://github.com/python/mypy/issues/13040
def _deserialize_ndarray_container( # type: ignore[misc]
template: numpy.ndarray,
iterable: Iterable[Tuple[Any, ArrayOrContainer]]) -> numpy.ndarray:
serialized: SerializedContainer) -> numpy.ndarray:
# disallow subclasses
assert type(template) is np.ndarray
assert template.dtype.char == "O"

result = type(template)(template.shape, dtype=object)
for i, subary in iterable:
for i, subary in serialized:
result[i] = subary

return result
Expand Down
Loading
Loading