Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement actx.np.zeros #276

Merged
merged 1 commit into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
TypeVar,
Union,
)
from warnings import warn

import numpy as np

Expand Down Expand Up @@ -249,10 +250,6 @@ class ArrayContext(ABC):

.. versionadded:: 2020.2

.. automethod:: empty
.. automethod:: zeros
.. automethod:: empty_like
.. automethod:: zeros_like
.. automethod:: from_numpy
.. automethod:: to_numpy
.. automethod:: call_loopy
Expand Down Expand Up @@ -293,9 +290,9 @@ class ArrayContext(ABC):
def __init__(self) -> None:
self.np = self._get_fake_numpy_namespace()

@abstractmethod
def _get_fake_numpy_namespace(self) -> Any:
from .fake_numpy import BaseFakeNumpyNamespace
return BaseFakeNumpyNamespace(self)
...

def __hash__(self) -> int:
raise TypeError(f"unhashable type: '{type(self).__name__}'")
Expand All @@ -306,22 +303,23 @@ def empty(self,
dtype: "np.dtype[Any]") -> Array:
pass

@abstractmethod
def zeros(self,
shape: Union[int, Tuple[int, ...]],
dtype: "np.dtype[Any]") -> Array:
pass
warn(f"{type(self).__name__}.zeros is deprecated and will stop "
"working in 2025. Use actx.np.zeros instead.",
DeprecationWarning, stacklevel=2)

return self.np.zeros(shape, dtype)
alexfikl marked this conversation as resolved.
Show resolved Hide resolved

def empty_like(self, ary: Array) -> Array:
from warnings import warn
warn(f"{type(self).__name__}.empty_like is deprecated and will stop "
"working in 2023. Prefer actx.np.zeros_like instead.",
DeprecationWarning, stacklevel=2)

return self.empty(shape=ary.shape, dtype=ary.dtype)

def zeros_like(self, ary: Array) -> Array:
from warnings import warn
warn(f"{type(self).__name__}.zeros_like is deprecated and will stop "
"working in 2023. Use actx.np.zeros_like instead.",
DeprecationWarning, stacklevel=2)
Expand Down
11 changes: 10 additions & 1 deletion arraycontext/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@


import operator
from abc import ABC, abstractmethod
from typing import Any

import numpy as np
Expand All @@ -34,7 +35,7 @@

# {{{ BaseFakeNumpyNamespace

class BaseFakeNumpyNamespace:
class BaseFakeNumpyNamespace(ABC):
def __init__(self, array_context):
self._array_context = array_context
self.linalg = self._get_fake_numpy_linalg_namespace()
Expand Down Expand Up @@ -95,6 +96,14 @@ def _get_fake_numpy_linalg_namespace(self):
# "interp",
})

@abstractmethod
def zeros(self, shape, dtype):
...

@abstractmethod
def zeros_like(self, ary):
...

def conjugate(self, x):
# NOTE: conjugate distributes over object arrays, but it looks for a
# `conjugate` ufunc, while some implementations only have the shorter
Expand Down
2 changes: 1 addition & 1 deletion arraycontext/impl/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _wrapper(ary):
def empty(self, shape, dtype):
from warnings import warn
warn(f"{type(self).__name__}.empty is deprecated and will stop "
"working in 2023. Prefer actx.zeros instead.",
"working in 2023. Prefer actx.np.zeros instead.",
DeprecationWarning, stacklevel=2)

import jax.numpy as jnp
Expand Down
3 changes: 3 additions & 0 deletions arraycontext/impl/jax/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def __getattr__(self, name):

# {{{ array creation routines

def zeros(self, shape, dtype):
return jnp.zeros(shape=shape, dtype=dtype)

def empty_like(self, ary):
from warnings import warn
warn(f"{type(self._array_context).__name__}.np.empty_like is "
Expand Down
2 changes: 1 addition & 1 deletion arraycontext/impl/pyopencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def _wrapper(ary):
def empty(self, shape, dtype):
from warnings import warn
warn(f"{type(self).__name__}.empty is deprecated and will stop "
"working in 2023. Prefer actx.zeros instead.",
"working in 2023. Prefer actx.np.zeros instead.",
DeprecationWarning, stacklevel=2)

import arraycontext.impl.pyopencl.taggable_cl_array as tga
Expand Down
6 changes: 6 additions & 0 deletions arraycontext/impl/pyopencl/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
rec_multimap_reduce_array_container,
)
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
from arraycontext.loopy import LoopyBasedFakeNumpyNamespace


Expand All @@ -60,6 +61,11 @@ def _get_fake_numpy_linalg_namespace(self):

# {{{ array creation routines

def zeros(self, shape, dtype) -> TaggableCLArray:
import arraycontext.impl.pyopencl.taggable_cl_array as tga
return tga.zeros(self._array_context.queue, shape, dtype,
allocator=self._array_context.allocator)

def empty_like(self, ary):
from warnings import warn
warn(f"{type(self._array_context).__name__}.np.empty_like is "
Expand Down
3 changes: 3 additions & 0 deletions arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def __getattr__(self, name):

# {{{ array creation routines

def zeros(self, shape, dtype):
return pt.zeros(shape, dtype)

def zeros_like(self, ary):
def _zeros_like(array):
return self._array_context.zeros(
Expand Down
12 changes: 6 additions & 6 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,7 @@ def test_leaf_array_type_broadcasting(actx_factory):
# test support for https://github.com/inducer/arraycontext/issues/49
actx = actx_factory()

foo = Foo(DOFArray(actx, (actx.zeros(3, dtype=np.float64) + 41, )))
foo = Foo(DOFArray(actx, (actx.np.zeros(3, dtype=np.float64) + 41, )))
bar = foo + 4
baz = foo + actx.from_numpy(4*np.ones((3, )))
qux = actx.from_numpy(4*np.ones((3, ))) + foo
Expand Down Expand Up @@ -1510,7 +1510,7 @@ def _twice(x):

actx = actx_factory()
ones = actx.thaw(actx.freeze(
actx.zeros(shape=(10, 4), dtype=np.float64) + 1
actx.np.zeros(shape=(10, 4), dtype=np.float64) + 1
))
np.testing.assert_allclose(actx.to_numpy(_twice(ones)),
actx.to_numpy(actx.compile(_twice)(ones)))
Expand Down Expand Up @@ -1573,7 +1573,7 @@ def test_taggable_cl_array_tags(actx_factory):
def test_to_numpy_on_frozen_arrays(actx_factory):
# See https://github.com/inducer/arraycontext/issues/159
actx = actx_factory()
u = actx.freeze(actx.zeros(10, dtype="float64")+1)
u = actx.freeze(actx.np.zeros(10, dtype="float64")+1)
np.testing.assert_allclose(actx.to_numpy(u), 1)
np.testing.assert_allclose(actx.to_numpy(u), 1)

Expand All @@ -1592,7 +1592,7 @@ class ExampleTag(Tag):
ary = tag_axes(actx, {0: ExampleTag()},
actx.tag(
ExampleTag(),
actx.zeros((20, 20), dtype=np.float64)))
actx.np.zeros((20, 20), dtype=np.float64)))

assert ary.tags_of_type(ExampleTag)
assert ary.axes[0].tags_of_type(ExampleTag)
Expand All @@ -1606,11 +1606,11 @@ def test_compile_anonymous_function(actx_factory):
actx = actx_factory()
f = actx.compile(lambda x: 2*x+40)
np.testing.assert_allclose(
actx.to_numpy(f(1+actx.zeros((10, 4), "float64"))),
actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
42)
f = actx.compile(partial(lambda x: 2*x+40))
np.testing.assert_allclose(
actx.to_numpy(f(1+actx.zeros((10, 4), "float64"))),
actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))),
42)


Expand Down
Loading