diff --git a/loopy/kernel/array.py b/loopy/kernel/array.py index c6b4565a4..199523699 100644 --- a/loopy/kernel/array.py +++ b/loopy/kernel/array.py @@ -45,6 +45,7 @@ from warnings import warn import numpy as np # noqa +from typing_extensions import TypeAlias from pytools import ImmutableRecord from pytools.tag import Tag, Taggable @@ -52,7 +53,7 @@ from loopy.diagnostic import LoopyError from loopy.tools import is_integer from loopy.types import LoopyType -from loopy.typing import ExpressionT, ShapeType +from loopy.typing import ExpressionT, ShapeType, auto if TYPE_CHECKING: @@ -593,29 +594,33 @@ def convert_computed_to_fixed_dim_tags(name, num_user_axes, num_target_axes, # {{{ array base class (for arguments and temporary arrays) -def _pymbolic_parse_if_necessary(x): - if isinstance(x, str): - from pymbolic import parse - return parse(x) - else: - return x +ToShapeLikeConvertible: TypeAlias = (Tuple[ExpressionT | str, ...] + | ExpressionT | type[auto] | str | tuple[str, ...]) -def _parse_shape_or_strides(x): - import loopy as lp +def _parse_shape_or_strides( + x: ToShapeLikeConvertible, + ) -> ShapeType | type[auto]: + from pymbolic import parse + if x == "auto": - warn("use of 'auto' as a shape or stride won't work " - "any more--use loopy.auto instead", - stacklevel=3) - x = _pymbolic_parse_if_necessary(x) - if isinstance(x, lp.auto): - return x - assert not isinstance(x, list) + raise ValueError("use of 'auto' as a shape or stride won't work " + "any more--use loopy.auto instead") + + if x is auto: + return auto + + if isinstance(x, str): + x = parse(x) + + if isinstance(x, list): + raise ValueError("shape can't be a list") + if not isinstance(x, tuple): - assert x is not lp.auto + assert x is not auto x = (x,) - return tuple(_pymbolic_parse_if_necessary(xi) for xi in x) + return tuple(parse(xi) if isinstance(xi, str) else xi for xi in x) class ArrayBase(ImmutableRecord, Taggable): diff --git a/loopy/kernel/data.py b/loopy/kernel/data.py index 22c9ce562..aec7c6d97 100644 --- a/loopy/kernel/data.py +++ b/loopy/kernel/data.py @@ -682,11 +682,28 @@ class TemporaryVariable(ArrayBase): "_base_storage_access_may_be_aliasing", ) - def __init__(self, name, dtype=None, shape=auto, address_space=None, - dim_tags=None, offset=0, dim_names=None, strides=None, order=None, - base_indices=None, storage_shape=None, - base_storage=None, initializer=None, read_only=False, - _base_storage_access_may_be_aliasing=False, **kwargs): + def __init__( + self, + name: str, + dtype: ToLoopyTypeConvertible = None, + shape: Union[ShapeType, Type["auto"], None] = auto, + address_space: Union[AddressSpace, Type[auto], None] = None, + dim_tags: Optional[Sequence[ArrayDimImplementationTag]] = None, + offset: Union[ExpressionT, str, None] = 0, + dim_names: Optional[Tuple[str, ...]] = None, + strides: Optional[Tuple[ExpressionT, ...]] = None, + order: str | None = None, + + base_indices: Optional[Tuple[ExpressionT, ...]] = None, + storage_shape: ShapeType | None = None, + + base_storage: Optional[str] = None, + initializer: Optional[np.ndarray] = None, + read_only: bool = False, + + _base_storage_access_may_be_aliasing: bool = False, + **kwargs: Any + ) -> None: """ :arg dtype: :class:`loopy.auto` or a :class:`numpy.dtype` :arg shape: :class:`loopy.auto` or a shape tuple @@ -696,12 +713,6 @@ def __init__(self, name, dtype=None, shape=auto, address_space=None, if address_space is None: address_space = auto - if address_space is None: - raise LoopyError( - "temporary variable '%s': " - "address_space must not be None" - % name) - if initializer is None: pass elif isinstance(initializer, np.ndarray): @@ -736,7 +747,12 @@ def __init__(self, name, dtype=None, shape=auto, address_space=None, if order is None: order = "C" - if base_indices is None and shape is not auto: + if shape is not None: + from loopy.kernel.array import _parse_shape_or_strides + shape = _parse_shape_or_strides(shape) + + if base_indices is None and shape is not auto and shape is not None: + assert isinstance(shape, tuple) base_indices = (0,) * len(shape) if not read_only and initializer is not None: @@ -775,7 +791,7 @@ def __init__(self, name, dtype=None, shape=auto, address_space=None, _base_storage_access_may_be_aliasing), **kwargs) - def copy(self, **kwargs): + def copy(self, **kwargs: Any) -> TemporaryVariable: address_space = kwargs.pop("address_space", None) if address_space is not None: @@ -784,15 +800,23 @@ def copy(self, **kwargs): return super().copy(**kwargs) @property - def nbytes(self): - shape = self.shape + def nbytes(self) -> ExpressionT: if self.storage_shape is not None: shape = self.storage_shape + else: + if self.shape is None: + raise ValueError("shape is None") + if self.shape is auto: + raise ValueError("shape is auto") + shape = cast(Tuple[ExpressionT], self.shape) + + if self.dtype is None: + raise ValueError("data type is indeterminate") from pytools import product return product(si for si in shape)*self.dtype.itemsize - def __str__(self): + def __str__(self) -> str: if self.address_space is auto: aspace_str = "auto" else: diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index 41a82d4b5..a577c115e 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -402,8 +402,6 @@ def __init__(self, # The Taggable constructor call does extra validation. tags=tags) - Taggable.__init__(self, tags) - def get_copy_kwargs(self, **kwargs): passed_depends_on = "depends_on" in kwargs @@ -938,7 +936,8 @@ def __init__(self, predicates: Optional[FrozenSet[str]] = None, tags: Optional[FrozenSet[Tag]] = None, temp_var_type: Union[ - Type[_not_provided], None, LoopyOptional] = _not_provided, + Type[_not_provided], None, LoopyOptional, + LoopyType] = _not_provided, atomicity: Tuple[VarAtomicity, ...] = (), *, depends_on: Union[FrozenSet[str], str, None] = None,