Skip to content

Commit

Permalink
Address type errors from more precisely typed pymbolic
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Oct 3, 2024
1 parent 0a6db51 commit 904544e
Show file tree
Hide file tree
Showing 13 changed files with 77 additions and 38 deletions.
3 changes: 3 additions & 0 deletions loopy/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import islpy as isl
from islpy import dim_type
from pymbolic.primitives import Variable
from pytools import memoize_method

from loopy.diagnostic import (
Expand Down Expand Up @@ -1669,6 +1670,8 @@ def _are_sub_array_refs_equivalent(
if len(sar1.swept_inames) != len(sar2.swept_inames):
return False

assert isinstance(sar1.subscript.aggregate, Variable)
assert isinstance(sar2.subscript.aggregate, Variable)
if sar1.subscript.aggregate.name != sar2.subscript.aggregate.name:
return False

Expand Down
10 changes: 9 additions & 1 deletion loopy/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from immutables import Map

from loopy.codegen.result import CodeGenerationResult
from loopy.library.reduction import ReductionOpFunction
from loopy.translation_unit import CallablesTable, TranslationUnit


Expand Down Expand Up @@ -661,8 +662,15 @@ def generate_code_v2(t_unit: TranslationUnit) -> CodeGenerationResult:
ast=t_unit.target.get_device_ast_builder().ast_module.Collection(
callee_fdecls+[device_programs[0].ast]))] +
device_programs[1:])

def not_reduction_op(name: str | ReductionOpFunction) -> str:
assert isinstance(name, str)
return name

cgr = TranslationUnitCodeGenerationResult(
host_programs=host_programs,
host_programs={
not_reduction_op(name): prg
for name, prg in host_programs.items()},
device_programs=device_programs,
device_preambles=device_preambles)

Expand Down
3 changes: 2 additions & 1 deletion loopy/frontend/fortran/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ class FortranExpressionParser(ExpressionParserBase):
(_not, pytools.lex.RE(r"\.not\.", re.I)),
(_and, pytools.lex.RE(r"\.and\.", re.I)),
(_or, pytools.lex.RE(r"\.or\.", re.I)),
] + ExpressionParserBase.lex_table
*ExpressionParserBase.lex_table,
]

def __init__(self, tree_walker):
self.tree_walker = tree_walker
Expand Down
17 changes: 11 additions & 6 deletions loopy/kernel/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,16 +623,20 @@ def _parse_shape_or_strides(
return auto

if isinstance(x, str):
x = parse(x)
x_parsed = parse(x)
else:
x_parsed = x

if isinstance(x, list):
if isinstance(x_parsed, list):
raise ValueError("shape can't be a list")

if not isinstance(x, tuple):
if isinstance(x_parsed, tuple):
x_tup: tuple[ExpressionT, ...] = x_parsed
else:
assert x is not auto
x = (x,)
x_tup = (x_parsed,)

return tuple(parse(xi) if isinstance(xi, str) else xi for xi in x)
return tuple(parse(xi) if isinstance(xi, str) else xi for xi in x_tup)


class ArrayBase(ImmutableRecord, Taggable):
Expand Down Expand Up @@ -1219,11 +1223,12 @@ def get_access_info(kernel: "LoopKernel",

import loopy as lp

def eval_expr_assert_integer_constant(i, expr):
def eval_expr_assert_integer_constant(i, expr) -> int:
from pymbolic.mapper.evaluator import UnknownVariableError
try:
result = eval_expr(expr)
except UnknownVariableError as e:
assert ary.dim_tags is not None
raise LoopyError("When trying to index the array '%s' along axis "
"%d (tagged '%s'), the index was not a compile-time "
"constant (but it has to be in order for code to be "
Expand Down
6 changes: 4 additions & 2 deletions loopy/kernel/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,9 @@ def __init__(self,
if isinstance(assignee, str):
assignee = parse(assignee)
if isinstance(expression, str):
expression = parse(expression)
parsed_expression = parse(expression)
else:
parsed_expression = expression

from pymbolic.primitives import Lookup, Subscript, Variable

Expand All @@ -962,7 +964,7 @@ def __init__(self,
raise LoopyError("invalid lvalue '%s'" % assignee)

self.assignee = assignee
self.expression = expression
self.expression = parsed_expression

self.temp_var_type = _check_and_fix_temp_var_type(temp_var_type)
self.atomicity = atomicity
Expand Down
2 changes: 1 addition & 1 deletion loopy/target/c/codegen/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def wrap_in_typecast(self, actual_type: LoopyType, needed_type: LoopyType, s):

return s

def rec(self, expr, type_context=None, needed_type: Optional[LoopyType] = None):
def rec(self, expr, type_context=None, needed_type: Optional[LoopyType] = None): # type: ignore[override]
result = RecursiveMapper.rec(self, expr, type_context)

if needed_type is None:
Expand Down
11 changes: 6 additions & 5 deletions loopy/target/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from loopy.tools import LoopyKeyBuilder, caches
from loopy.translation_unit import TranslationUnit
from loopy.types import LoopyType, NumpyType
from loopy.typing import ExpressionT
from loopy.typing import ExpressionT, integer_expr_or_err
from loopy.version import DATA_MODEL_VERSION


Expand Down Expand Up @@ -187,7 +187,7 @@ def generate_integer_arg_finding_from_array_data(
if shape_i is not None:
equations.append(
_ArgFindingEquation(
lhs=var(arg.name).attr("shape").index(axis_nr),
lhs=var(arg.name).attr("shape")[axis_nr],
rhs=shape_i,
order=0,
based_on_names=frozenset({arg.name})))
Expand All @@ -198,7 +198,7 @@ def generate_integer_arg_finding_from_array_data(
equations.append(
_ArgFindingEquation(
lhs=var("_lpy_even_div")(
var(arg.name).attr("strides").index(axis_nr),
var(arg.name).attr("strides")[axis_nr],
arg.dtype.itemsize),
rhs=_str_to_expr(stride_i),
order=0,
Expand All @@ -211,8 +211,9 @@ def generate_integer_arg_finding_from_array_data(
# arguments.
equations.append(
_ArgFindingEquation(
lhs=(strides[axis_nr + 1]
* arg.shape[axis_nr + 1])
lhs=(integer_expr_or_err(strides[axis_nr + 1])
* integer_expr_or_err(
arg.shape[axis_nr + 1]))
if axis_nr + 1 < len(strides)
else 1,
rhs=_str_to_expr(stride_i),
Expand Down
28 changes: 16 additions & 12 deletions loopy/target/pyopencl_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,21 @@

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence

import numpy as np
from immutables import Map

from pytools import memoize_method
from pytools.codegen import CodeGenerator, Indentation

from loopy.codegen.result import CodeGenerationResult
from loopy.kernel import LoopKernel
from loopy.kernel.data import ArrayArg
from loopy.schedule.tools import KernelArgInfo
from loopy.target.execution import ExecutionWrapperGeneratorBase, ExecutorBase
from loopy.types import LoopyType
from loopy.typing import ExpressionT
from loopy.typing import ExpressionT, integer_expr_or_err


logger = logging.getLogger(__name__)
Expand All @@ -58,15 +59,15 @@ class PyOpenCLExecutionWrapperGenerator(ExecutionWrapperGeneratorBase):
pyopencl execution
"""

def __init__(self):
def __init__(self) -> None:
system_args = [
"_lpy_cl_kernels", "queue", "allocator=None", "wait_for=None",
# ignored if options.no_numpy
"out_host=None"
]
super().__init__(system_args)

def python_dtype_str_inner(self, dtype):
def python_dtype_str_inner(self, dtype: np.dtype) -> str:
import pyopencl.tools as cl_tools
# Test for types built into numpy. dtype.isbuiltin does not work:
# https://github.com/numpy/numpy/issues/4317
Expand All @@ -82,7 +83,7 @@ def python_dtype_str_inner(self, dtype):

# {{{ handle non-numpy args

def handle_non_numpy_arg(self, gen, arg):
def handle_non_numpy_arg(self, gen: CodeGenerator, arg: ArrayArg) -> None:
gen("if isinstance(%s, _lpy_np.ndarray):" % arg.name)
with Indentation(gen):
gen("# retain originally passed array")
Expand All @@ -108,7 +109,7 @@ def handle_non_numpy_arg(self, gen, arg):

def handle_alloc(
self, gen: CodeGenerator, arg: ArrayArg,
strify: Callable[[Union[ExpressionT, Tuple[ExpressionT]]], str],
strify: Callable[[ExpressionT], str],
skip_arg_checks: bool) -> None:
"""
Handle allocation of non-specified arguments for pyopencl execution
Expand Down Expand Up @@ -136,9 +137,10 @@ def handle_alloc(
for i in range(num_axes))
sym_shape = tuple(arg.shape[i] for i in range(num_axes))

size_expr = (sum(astrd*(alen-1)
for alen, astrd in zip(sym_shape, sym_ustrides))
+ 1)
size_expr = 1 + sum(
integer_expr_or_err(astrd)*(integer_expr_or_err(alen)-1)
for alen, astrd in zip(sym_shape, sym_ustrides)
)

gen("_lpy_size = %s" % strify(size_expr))
sym_strides = tuple(itemsize*s_i for s_i in sym_ustrides)
Expand All @@ -158,7 +160,7 @@ def handle_alloc(

# }}}

def target_specific_preamble(self, gen):
def target_specific_preamble(self, gen: CodeGenerator) -> None:
"""
Add default pyopencl imports to preamble
"""
Expand All @@ -170,7 +172,7 @@ def target_specific_preamble(self, gen):
from loopy.target.c.c_execution import DEF_EVEN_DIV_FUNCTION
gen.add_to_preamble(DEF_EVEN_DIV_FUNCTION)

def initialize_system_args(self, gen):
def initialize_system_args(self, gen: CodeGenerator) -> None:
"""
Initializes possibly empty system arguments
"""
Expand Down Expand Up @@ -259,7 +261,9 @@ def generate_output_handler(self, gen: CodeGenerator,

# }}}

def generate_host_code(self, gen, codegen_result):
def generate_host_code(
self, gen: CodeGenerator, codegen_result: CodeGenerationResult
) -> None:
gen.add_to_preamble(codegen_result.host_code())

def get_arg_pass(self, arg):
Expand Down
3 changes: 2 additions & 1 deletion loopy/transform/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,8 @@ def rename_callable(
raise LoopyError(f"callables named '{new_name}' already exists")

if new_name is None:
namegen = UniqueNameGenerator(t_unit.callables_table.keys())
namegen = UniqueNameGenerator(
{n for n in t_unit.callables_table if isinstance(n, str)})
new_name = namegen(old_name)

assert isinstance(new_name, str)
Expand Down
8 changes: 7 additions & 1 deletion loopy/transform/concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

from typing import List, Optional, Sequence

import numpy as np

import pymbolic.primitives as prim
from pytools import all_equal

Expand Down Expand Up @@ -82,7 +84,11 @@ def concatenate_arrays(
offsets[array_name] = axis_length
ary = kernel.temporary_variables[array_name]
assert isinstance(ary.shape, tuple)
axis_length += ary.shape[axis_nr]
shape_ax = ary.shape[axis_nr]
if not isinstance(shape_ax, (int, np.integer)):
raise TypeError(f"array '{array_name}' axis {axis_nr+1} (1-based) has "
"non-constant length.")
axis_length += shape_ax

new_ary = arrays[0]
if not isinstance(new_ary.shape, tuple):
Expand Down
3 changes: 1 addition & 2 deletions loopy/transform/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,8 +991,7 @@ class _BaseStorageInfo:
def _sym_max(a: ExpressionT, b: ExpressionT) -> ExpressionT:
from numbers import Number
if isinstance(a, Number) and isinstance(b, Number):
# https://github.com/python/mypy/issues/3186
return max(a, b) # type: ignore[call-overload]
return max(a, b)
else:
from pymbolic.primitives import Max
return Max((a, b))
Expand Down
3 changes: 2 additions & 1 deletion loopy/transform/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def func_map(i, func, args, allowed_nonsmoothness):
raise NotImplementedError("derivative of '%s'" % func.name)


class LoopyDiffMapper(DifferentiationMapper, RuleAwareIdentityMapper):
# It has a point: https://github.com/inducer/pymbolic/issues/149
class LoopyDiffMapper(DifferentiationMapper, RuleAwareIdentityMapper): # type: ignore[misc]
def __init__(self, rule_mapping_context, diff_context, diff_inames,
allowed_nonsmoothness=None):
RuleAwareIdentityMapper.__init__(self, rule_mapping_context)
Expand Down
18 changes: 13 additions & 5 deletions loopy/transform/precompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from dataclasses import dataclass
from typing import FrozenSet, List, Optional, Sequence, Type, Union

import numpy as np
from immutables import Map

import islpy as isl
Expand Down Expand Up @@ -57,7 +58,13 @@
)
from loopy.translation_unit import CallablesTable, TranslationUnit
from loopy.types import LoopyType, ToLoopyTypeConvertible, to_loopy_type
from loopy.typing import ExpressionT, auto, not_none
from loopy.typing import (
ExpressionT,
auto,
integer_expr_or_err,
integer_or_err,
not_none,
)


# {{{ contains_subst_rule_invocation
Expand Down Expand Up @@ -527,7 +534,7 @@ def precompute_for_single_kernel(

if isinstance(subst_name_as_expr, TaggedVariable):
new_subst_name = subst_name_as_expr.name
new_subst_tag = subst_name_as_expr.tag
new_subst_tag, = subst_name_as_expr.tags
elif isinstance(subst_name_as_expr, Variable):
new_subst_name = subst_name_as_expr.name
new_subst_tag = None
Expand Down Expand Up @@ -568,7 +575,7 @@ def precompute_for_single_kernel(

for fpg in footprint_generators:
if isinstance(fpg, Variable):
args = ()
args: tuple[ExpressionT, ...] = ()
elif isinstance(fpg, Call):
args = fpg.parameters
else:
Expand Down Expand Up @@ -928,7 +935,7 @@ def add_assumptions(d):

storage_axis_subst_dict[
prior_storage_axis_name_dict.get(arg_name, arg_name)] = \
arg+base_index
arg+integer_expr_or_err(base_index)

rule_mapping_context = SubstitutionRuleMappingContext(
kernel.substitutions, kernel.get_var_name_generator())
Expand Down Expand Up @@ -1114,7 +1121,8 @@ def add_assumptions(d):
len(temp_var.shape), len(new_temp_shape)))

new_temp_shape = tuple(
max(i, ex_i)
# https://github.com/numpy/numpy/issues/27251
np.max(integer_or_err(i), integer_or_err(ex_i))
for i, ex_i in zip(new_temp_shape, temp_var.shape))

temp_var = temp_var.copy(shape=new_temp_shape)
Expand Down

0 comments on commit 904544e

Please sign in to comment.