Skip to content

Commit

Permalink
Type TemporaryVariable methods
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Aug 6, 2024
1 parent 51d59f4 commit 123caba
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 37 deletions.
41 changes: 23 additions & 18 deletions loopy/kernel/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,15 @@
from warnings import warn

import numpy as np # noqa
from typing_extensions import TypeAlias

from pytools import ImmutableRecord
from pytools.tag import Tag, Taggable

from loopy.diagnostic import LoopyError
from loopy.tools import is_integer
from loopy.types import LoopyType
from loopy.typing import ExpressionT, ShapeType
from loopy.typing import ExpressionT, ShapeType, auto


if TYPE_CHECKING:
Expand Down Expand Up @@ -593,29 +594,33 @@ def convert_computed_to_fixed_dim_tags(name, num_user_axes, num_target_axes,

# {{{ array base class (for arguments and temporary arrays)

def _pymbolic_parse_if_necessary(x):
if isinstance(x, str):
from pymbolic import parse
return parse(x)
else:
return x
ToShapeLikeConvertible: TypeAlias = (Tuple[ExpressionT | str, ...]
| ExpressionT | type[auto] | str | tuple[str, ...])


def _parse_shape_or_strides(x):
import loopy as lp
def _parse_shape_or_strides(
x: ToShapeLikeConvertible,
) -> ShapeType | type[auto]:
from pymbolic import parse

if x == "auto":
warn("use of 'auto' as a shape or stride won't work "
"any more--use loopy.auto instead",
stacklevel=3)
x = _pymbolic_parse_if_necessary(x)
if isinstance(x, lp.auto):
return x
assert not isinstance(x, list)
raise ValueError("use of 'auto' as a shape or stride won't work "
"any more--use loopy.auto instead")

if x is auto:
return auto

if isinstance(x, str):
x = parse(x)

if isinstance(x, list):
raise ValueError("shape can't be a list")

if not isinstance(x, tuple):
assert x is not lp.auto
assert x is not auto
x = (x,)

return tuple(_pymbolic_parse_if_necessary(xi) for xi in x)
return tuple(parse(xi) if isinstance(xi, str) else xi for xi in x)


class ArrayBase(ImmutableRecord, Taggable):
Expand Down
56 changes: 40 additions & 16 deletions loopy/kernel/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,11 +682,28 @@ class TemporaryVariable(ArrayBase):
"_base_storage_access_may_be_aliasing",
)

def __init__(self, name, dtype=None, shape=auto, address_space=None,
dim_tags=None, offset=0, dim_names=None, strides=None, order=None,
base_indices=None, storage_shape=None,
base_storage=None, initializer=None, read_only=False,
_base_storage_access_may_be_aliasing=False, **kwargs):
def __init__(
self,
name: str,
dtype: ToLoopyTypeConvertible = None,
shape: Union[ShapeType, Type["auto"], None] = auto,
address_space: Union[AddressSpace, Type[auto], None] = None,
dim_tags: Optional[Sequence[ArrayDimImplementationTag]] = None,
offset: Union[ExpressionT, str, None] = 0,
dim_names: Optional[Tuple[str, ...]] = None,
strides: Optional[Tuple[ExpressionT, ...]] = None,
order: str | None = None,

base_indices: Optional[Tuple[ExpressionT, ...]] = None,
storage_shape: ShapeType | None = None,

base_storage: Optional[str] = None,
initializer: Optional[np.ndarray] = None,
read_only: bool = False,

_base_storage_access_may_be_aliasing: bool = False,
**kwargs: Any
) -> None:
"""
:arg dtype: :class:`loopy.auto` or a :class:`numpy.dtype`
:arg shape: :class:`loopy.auto` or a shape tuple
Expand All @@ -696,12 +713,6 @@ def __init__(self, name, dtype=None, shape=auto, address_space=None,
if address_space is None:
address_space = auto

if address_space is None:
raise LoopyError(
"temporary variable '%s': "
"address_space must not be None"
% name)

if initializer is None:
pass
elif isinstance(initializer, np.ndarray):
Expand Down Expand Up @@ -736,7 +747,12 @@ def __init__(self, name, dtype=None, shape=auto, address_space=None,
if order is None:
order = "C"

if base_indices is None and shape is not auto:
if shape is not None:
from loopy.kernel.array import _parse_shape_or_strides
shape = _parse_shape_or_strides(shape)

if base_indices is None and shape is not auto and shape is not None:
assert isinstance(shape, tuple)
base_indices = (0,) * len(shape)

if not read_only and initializer is not None:
Expand Down Expand Up @@ -775,7 +791,7 @@ def __init__(self, name, dtype=None, shape=auto, address_space=None,
_base_storage_access_may_be_aliasing),
**kwargs)

def copy(self, **kwargs):
def copy(self, **kwargs: Any) -> TemporaryVariable:
address_space = kwargs.pop("address_space", None)

if address_space is not None:
Expand All @@ -784,15 +800,23 @@ def copy(self, **kwargs):
return super().copy(**kwargs)

@property
def nbytes(self):
shape = self.shape
def nbytes(self) -> ExpressionT:
if self.storage_shape is not None:
shape = self.storage_shape
else:
if self.shape is None:
raise ValueError("shape is None")
if self.shape is auto:
raise ValueError("shape is auto")
shape = cast(Tuple[ExpressionT], self.shape)

if self.dtype is None:
raise ValueError("data type is indeterminate")

from pytools import product
return product(si for si in shape)*self.dtype.itemsize

def __str__(self):
def __str__(self) -> str:
if self.address_space is auto:
aspace_str = "auto"
else:
Expand Down
5 changes: 2 additions & 3 deletions loopy/kernel/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,8 +402,6 @@ def __init__(self,
# The Taggable constructor call does extra validation.
tags=tags)

Taggable.__init__(self, tags)

def get_copy_kwargs(self, **kwargs):
passed_depends_on = "depends_on" in kwargs

Expand Down Expand Up @@ -938,7 +936,8 @@ def __init__(self,
predicates: Optional[FrozenSet[str]] = None,
tags: Optional[FrozenSet[Tag]] = None,
temp_var_type: Union[
Type[_not_provided], None, LoopyOptional] = _not_provided,
Type[_not_provided], None, LoopyOptional,
LoopyType] = _not_provided,
atomicity: Tuple[VarAtomicity, ...] = (),
*,
depends_on: Union[FrozenSet[str], str, None] = None,
Expand Down

0 comments on commit 123caba

Please sign in to comment.