-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add explicit scopes to CSE #150
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -777,8 +777,7 @@ def __eq__(self, other) -> bool: | |
depr_key = (type(self), "__eq__") | ||
if depr_key not in self._deprecation_warnings_issued: | ||
warn(f"Expression.__eq__ is used by {self.__class__}. This is deprecated. " | ||
"Use equality comparison supplied by expr_dataclass" | ||
"instead. " | ||
"Use equality comparison supplied by expr_dataclass instead. " | ||
"This will stop working in 2025.", | ||
DeprecationWarning, stacklevel=2) | ||
self._deprecation_warnings_issued.add(depr_key) | ||
|
@@ -994,7 +993,8 @@ def {cls.__name__}_eq(self, other): | |
|
||
return self.__class__ == other.__class__ and {comparison} | ||
|
||
cls.__eq__ = {cls.__name__}_eq | ||
if {hash}: | ||
cls.__eq__ = {cls.__name__}_eq | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This also guards the Not sure if it's worth adding a separate There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would these better default to |
||
|
||
|
||
def {cls.__name__}_hash(self): | ||
|
@@ -1024,8 +1024,9 @@ def {cls.__name__}_hash(self): | |
def {cls.__name__}_init_arg_names(self): | ||
depr_key = (type(self), "init_arg_names") | ||
if depr_key not in self._deprecation_warnings_issued: | ||
warn("__getinitargs__ is deprecated and will be removed in 2025. " | ||
"Use dataclasses.fields instead.", | ||
warn("Attribute 'init_arg_names' of {cls} is deprecated and will " | ||
"not have a default implementation starting from 2025. " | ||
"Use 'dataclasses.fields' instead.", | ||
DeprecationWarning, stacklevel=2) | ||
|
||
self._deprecation_warnings_issued.add(depr_key) | ||
|
@@ -1038,8 +1039,9 @@ def {cls.__name__}_init_arg_names(self): | |
def {cls.__name__}_getinitargs(self): | ||
depr_key = (type(self), "__getinitargs__") | ||
if depr_key not in self._deprecation_warnings_issued: | ||
warn("__getinitargs__ is deprecated and will be removed in 2025. " | ||
"Use dataclasses.fields instead.", | ||
warn("Method '__getinitargs__' of {cls} is deprecated and will " | ||
"not have a default implementation starting from 2025. " | ||
"Use 'dataclasses.fields' instead.", | ||
DeprecationWarning, stacklevel=2) | ||
|
||
self._deprecation_warnings_issued.add(depr_key) | ||
|
@@ -1947,89 +1949,103 @@ def is_zero(value): | |
return not is_nonzero(value) | ||
|
||
|
||
def wrap_in_cse(expr, prefix=None): | ||
def wrap_in_cse(expr: Expression, | ||
prefix: str | None = None, | ||
scope: str | None = None) -> Expression: | ||
if isinstance(expr, (Variable, Subscript)): | ||
return expr | ||
|
||
if scope is None: | ||
scope = cse_scope.EVALUATION | ||
|
||
if isinstance(expr, CommonSubexpression): | ||
if prefix is None: | ||
return expr | ||
|
||
if expr.prefix is None and type(expr) is CommonSubexpression: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use same logic as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm, I'm not sure what you had in mind here? At least at a first glance, these seem to have different logic: |
||
return CommonSubexpression(expr.child, prefix) | ||
return CommonSubexpression(expr.child, prefix, scope) | ||
|
||
# existing prefix wins | ||
return expr | ||
|
||
else: | ||
return CommonSubexpression(expr, prefix) | ||
return CommonSubexpression(expr, prefix, scope) | ||
|
||
|
||
def make_common_subexpression(field, prefix=None, scope=None): | ||
"""Wrap *field* in a :class:`CommonSubexpression` with | ||
*prefix*. If *field* is a :mod:`numpy` object array, | ||
each individual entry is instead wrapped. If *field* is a | ||
:class:`pymbolic.geometric_algebra.MultiVector`, each | ||
coefficient is individually wrapped. | ||
def make_common_subexpression(expr: ExpressionT, | ||
prefix: str | None = None, | ||
scope: str | None = None) -> ExpressionT: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Left |
||
"""Wrap *expr* in a :class:`CommonSubexpression` with *prefix*. | ||
|
||
See :class:`CommonSubexpression` for the meaning of *prefix* | ||
and *scope*. | ||
If *expr* is a :mod:`numpy` object array, each individual entry is instead | ||
wrapped. If *expr* is a :class:`pymbolic.geometric_algebra.MultiVector`, each | ||
coefficient is individually wrapped. In general, the function tries to avoid | ||
re-wrapping existing :class:`CommonSubexpression` if the same scope is given. | ||
|
||
See :class:`CommonSubexpression` for the meaning of *prefix* and *scope*. The | ||
scope defaults to :attr:`cse_scope.EVALUATION`. | ||
""" | ||
|
||
if isinstance(field, CommonSubexpression) and ( | ||
scope is None or scope == cse_scope.EVALUATION | ||
or field.scope == scope): | ||
if scope is None: | ||
scope = cse_scope.EVALUATION | ||
|
||
if (isinstance(expr, CommonSubexpression) | ||
and (scope == cse_scope.EVALUATION or expr.scope == scope)): | ||
# Don't re-wrap | ||
return field | ||
return expr | ||
|
||
try: | ||
import numpy | ||
have_obj_array = ( | ||
isinstance(field, numpy.ndarray) | ||
and field.dtype.char == "O") | ||
logical_shape = ( | ||
field.shape | ||
if isinstance(field, numpy.ndarray) | ||
else ()) | ||
|
||
if isinstance(expr, numpy.ndarray) and expr.dtype.char == "O": | ||
is_obj_array = True | ||
logical_shape = expr.shape | ||
else: | ||
is_obj_array = False | ||
logical_shape = () | ||
except ImportError: | ||
have_obj_array = False | ||
is_obj_array = False | ||
logical_shape = () | ||
|
||
from pymbolic.geometric_algebra import MultiVector | ||
if isinstance(field, MultiVector): | ||
|
||
if isinstance(expr, MultiVector): | ||
new_data = {} | ||
for bits, coeff in field.data.items(): | ||
for bits, coeff in expr.data.items(): | ||
if prefix is not None: | ||
blade_str = field.space.blade_bits_to_str(bits, "") | ||
blade_str = expr.space.blade_bits_to_str(bits, "") | ||
component_prefix = prefix+"_"+blade_str | ||
else: | ||
component_prefix = None | ||
|
||
new_data[bits] = make_common_subexpression( | ||
coeff, component_prefix, scope) | ||
|
||
return MultiVector(new_data, field.space) | ||
return MultiVector(new_data, expr.space) | ||
|
||
elif is_obj_array and logical_shape != (): | ||
assert isinstance(expr, numpy.ndarray) | ||
|
||
elif have_obj_array and logical_shape != (): | ||
result = numpy.zeros(logical_shape, dtype=object) | ||
for i in numpy.ndindex(logical_shape): | ||
if prefix is not None: | ||
component_prefix = prefix+"_".join(str(i_i) for i_i in i) | ||
else: | ||
component_prefix = None | ||
|
||
if is_constant(field[i]): | ||
result[i] = field[i] | ||
if is_constant(expr[i]): | ||
result[i] = expr[i] | ||
else: | ||
result[i] = make_common_subexpression( | ||
field[i], component_prefix, scope) | ||
expr[i], component_prefix, scope) | ||
|
||
return result | ||
|
||
else: | ||
if is_constant(field): | ||
return field | ||
if is_constant(expr): | ||
return expr | ||
else: | ||
return CommonSubexpression(field, prefix, scope) | ||
return CommonSubexpression(expr, prefix, scope) | ||
|
||
|
||
def make_sym_vector(name, components, var_factory=Variable): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
? Is this right?