Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port to new pymbolic #242

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@
"pytools": ("https://documen.tician.de/pytools", None),
"scipy": ("https://docs.scipy.org/doc/scipy", None),
"sumpy": ("https://documen.tician.de/sumpy", None),
"sympy": ("https://docs.sympy.org/latest/", None),
}
10 changes: 6 additions & 4 deletions pytential/linalg/direct_solver_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,12 @@ def map_int_g(self, expr):
if name not in source_args
}

return expr.copy(target_kernel=target_kernel,
source_kernels=source_kernels,
densities=self.rec(expr.densities),
kernel_arguments=kernel_arguments)
from dataclasses import replace
return replace(expr,
target_kernel=target_kernel,
source_kernels=source_kernels,
densities=self.rec(expr.densities),
kernel_arguments=kernel_arguments)

# }}}

Expand Down
30 changes: 16 additions & 14 deletions pytential/linalg/skeletonization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@

from dataclasses import dataclass
from typing import (
Any, Callable, Dict, Hashable, Optional, Sequence, Tuple, Union)
Any, Callable, Dict, Hashable, Optional, Sequence, Tuple, Type, Union)

import numpy as np

from arraycontext import PyOpenCLArrayContext, Array

from pytential import GeometryCollection, sym
from pytential.symbolic.matrix import ClusterMatrixBuilderBase
from pytential.linalg.utils import IndexList, TargetAndSourceClusterList
from pytential.linalg.proxy import ProxyGeneratorBase, ProxyClusterGeometryData
from pytential.linalg.direct_solver_symbolic import (
Expand Down Expand Up @@ -136,7 +137,7 @@ def prg():
lang_version=lp.MOST_RECENT_LANGUAGE_VERSION,
)

return knl
return knl.executor(actx.context)

waa = bind(places, sym.weights_and_area_elements(
places.ambient_dim, dofdesc=domain))(actx)
Expand Down Expand Up @@ -253,17 +254,17 @@ class SkeletonizationWrangler:
domains: Tuple[sym.DOFDescriptor, ...]
context: Dict[str, Any]

neighbor_cluster_builder: Callable[..., np.ndarray]
neighbor_cluster_builder: Type[ClusterMatrixBuilderBase]

# target skeletonization
weighted_targets: bool
target_proxy_exprs: np.ndarray
proxy_target_cluster_builder: Callable[..., np.ndarray]
proxy_target_cluster_builder: Type[ClusterMatrixBuilderBase]

# source skeletonization
weighted_sources: bool
source_proxy_exprs: np.ndarray
proxy_source_cluster_builder: Callable[..., np.ndarray]
proxy_source_cluster_builder: Type[ClusterMatrixBuilderBase]

@property
def nrows(self) -> int:
Expand Down Expand Up @@ -386,35 +387,36 @@ def make_skeletonization_wrangler(

# internal
_weighted_proxy: Optional[Union[bool, Tuple[bool, bool]]] = None,
_proxy_source_cluster_builder: Optional[Callable[..., np.ndarray]] = None,
_proxy_target_cluster_builder: Optional[Callable[..., np.ndarray]] = None,
_neighbor_cluster_builder: Optional[Callable[..., np.ndarray]] = None,
_proxy_source_cluster_builder: Optional[Type[ClusterMatrixBuilderBase]] = None,
_proxy_target_cluster_builder: Optional[Type[ClusterMatrixBuilderBase]] = None,
_neighbor_cluster_builder: Optional[Type[ClusterMatrixBuilderBase]] = None,
) -> SkeletonizationWrangler:
if context is None:
context = {}

# {{{ setup expressions

try:
exprs = list(exprs)
lpot_exprs = list(exprs)
except TypeError:
exprs = [exprs]
lpot_exprs = [exprs]

try:
input_exprs = list(input_exprs)
except TypeError:
assert not isinstance(input_exprs, Sequence)
input_exprs = [input_exprs]

from pytential.symbolic.execution import _prepare_auto_where, _prepare_domains

auto_where = _prepare_auto_where(auto_where, places)
domains = _prepare_domains(len(input_exprs), places, domains, auto_where[0])

exprs = prepare_expr(places, exprs, auto_where)
prepared_lpot_exprs = prepare_expr(places, lpot_exprs, auto_where)
source_proxy_exprs = prepare_proxy_expr(
places, exprs, (auto_where[0], PROXY_SKELETONIZATION_TARGET))
places, prepared_lpot_exprs, (auto_where[0], PROXY_SKELETONIZATION_TARGET))
target_proxy_exprs = prepare_proxy_expr(
places, exprs, (PROXY_SKELETONIZATION_SOURCE, auto_where[1]))
places, prepared_lpot_exprs, (PROXY_SKELETONIZATION_SOURCE, auto_where[1]))

# }}}

Expand Down Expand Up @@ -449,7 +451,7 @@ def make_skeletonization_wrangler(

return SkeletonizationWrangler(
# operator
exprs=exprs,
exprs=prepared_lpot_exprs,
input_exprs=tuple(input_exprs),
domains=tuple(domains),
context=context,
Expand Down
2 changes: 1 addition & 1 deletion pytential/linalg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def prg():
lang_version=MOST_RECENT_LANGUAGE_VERSION)

knl = lp.split_iname(knl, "icluster", 128, outer_tag="g.0")
return knl
return knl.executor(actx.context)

@memoize_in(mindex, (make_index_cluster_cartesian_product, "index_product"))
def _product():
Expand Down
4 changes: 2 additions & 2 deletions pytential/qbx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,8 @@ def preprocess_optemplate(self, name, discretizations, expr):
def op_group_features(self, expr):
from pytential.utils import sort_arrays_together
result = (
expr.source, *sort_arrays_together(expr.source_kernels,
expr.densities, key=str)
expr.source,
*sort_arrays_together(expr.source_kernels, expr.densities, key=str)
)

return result
Expand Down
12 changes: 7 additions & 5 deletions pytential/qbx/refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def element_prop_threshold_checker(self):
lang_version=MOST_RECENT_LANGUAGE_VERSION)

knl = lp.split_iname(knl, "ielement", 128, inner_tag="l.0", outer_tag="g.0")
return knl
return knl.executor(self.array_context.context)

def get_wrangler(self):
return RefinerWrangler(self.array_context, self)
Expand Down Expand Up @@ -388,8 +388,8 @@ def check_sufficient_source_quadrature_resolution(self,
sym.ElementwiseMax(
sym._source_danger_zone_radii(
stage2_density_discr.ambient_dim,
dofdesc=sym.QBX_SOURCE_STAGE2),
dofdesc=sym.GRANULARITY_ELEMENT)
dofdesc=sym.as_dofdesc(sym.QBX_SOURCE_STAGE2)),
dofdesc=sym.as_dofdesc(sym.GRANULARITY_ELEMENT))
)(self.array_context), self.array_context)
unwrap_args = AreaQueryElementwiseTemplate.unwrap_args

Expand Down Expand Up @@ -633,7 +633,8 @@ def _refine_qbx_stage1(lpot_source, density_discr,
quad_resolution_by_element = bind(stage1_density_discr,
sym.ElementwiseMax(
sym._quad_resolution(stage1_density_discr.ambient_dim),
dofdesc=sym.GRANULARITY_ELEMENT))(actx)
dofdesc=sym.as_dofdesc(sym.GRANULARITY_ELEMENT)
))(actx)

violates_kernel_length_scale = \
wrangler.check_element_prop_threshold(
Expand All @@ -653,7 +654,8 @@ def _refine_qbx_stage1(lpot_source, density_discr,
scaled_max_curvature_by_element = bind(stage1_density_discr,
sym.ElementwiseMax(
sym._scaled_max_curvature(stage1_density_discr.ambient_dim),
dofdesc=sym.GRANULARITY_ELEMENT))(actx)
dofdesc=sym.as_dofdesc(sym.GRANULARITY_ELEMENT)
))(actx)

violates_scaled_max_curv = \
wrangler.check_element_prop_threshold(
Expand Down
2 changes: 1 addition & 1 deletion pytential/qbx/target_assoc.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ def make_target_flags(self, target_discrs_and_qbx_sides):
return target_flags

def make_default_target_association(self, ntargets):
target_to_center = self.array_context.zeros(ntargets, dtype=np.int32)
target_to_center = self.array_context.np.zeros(ntargets, dtype=np.int32)
target_to_center.fill(-1)
target_to_center.finish()

Expand Down
14 changes: 6 additions & 8 deletions pytential/symbolic/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
from functools import reduce
from typing import (
AbstractSet, Any, Collection, Tuple, Dict, Hashable, List,
Optional, Sequence, Set, Iterator)
Optional, Sequence, Set, Iterator, Union)

import numpy as np

from pymbolic.primitives import cse_scope, Expression, Variable
from pymbolic.primitives import cse_scope, Expression, Variable, Subscript
from sumpy.kernel import Kernel

from pytential.symbolic.primitives import (
Expand All @@ -53,7 +53,7 @@ def get_assignees(self) -> Set[str]:
raise NotImplementedError(
f"get_assignees for '{self.__class__.__name__}'")

def get_dependencies(self, dep_mapper: DependencyMapper) -> Set[Expression]:
def get_dependencies(self, dep_mapper: DependencyMapper) -> Set[Variable]:
raise NotImplementedError(
f"get_dependencies for '{self.__class__.__name__}'")

Expand Down Expand Up @@ -81,7 +81,7 @@ def __post_init__(self):
def get_assignees(self):
return set(self.names)

def get_dependencies(self, dep_mapper: DependencyMapper) -> Set[Expression]:
def get_dependencies(self, dep_mapper: DependencyMapper) -> Set[Variable]:
from operator import or_
deps = reduce(or_, (dep_mapper(expr) for expr in self.exprs))

Expand Down Expand Up @@ -189,7 +189,7 @@ class ComputePotential(Statement):
def get_assignees(self):
return {o.name for o in self.outputs}

def get_dependencies(self, dep_mapper: DependencyMapper) -> Set[Expression]:
def get_dependencies(self, dep_mapper: DependencyMapper) -> Set[Variable]:
result = dep_mapper(self.densities[0])
for density in self.densities[1:]:
result.update(dep_mapper(density))
Expand Down Expand Up @@ -546,9 +546,7 @@ def make_assign(

def assign_to_new_var(
self, expr: Expression, priority: int = 0, prefix: Optional[str] = None,
) -> Variable:
from pymbolic.primitives import Subscript

) -> Union[Variable, Subscript]:
# Observe that the only things that can be legally subscripted
# are variables. All other expressions are broken down into
# their scalar components.
Expand Down
2 changes: 2 additions & 0 deletions pytential/symbolic/dof_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,8 @@ def as_dofdesc(desc: "DOFDescriptorLike") -> "DOFDescriptor":

# {{{ type annotations

EMPTY_DESCRIPTOR = DOFDescriptor(geometry=None)

DiscretizationStages = Union[
Type[QBX_SOURCE_STAGE1],
Type[QBX_SOURCE_STAGE2],
Expand Down
51 changes: 25 additions & 26 deletions pytential/symbolic/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
THE SOFTWARE.
"""

from dataclasses import replace
from functools import reduce

from pymbolic.mapper.stringifier import (
Expand Down Expand Up @@ -129,9 +130,7 @@ def map_int_g(self, expr):
if not changed:
return expr

return expr.copy(
densities=densities,
kernel_arguments=kernel_arguments)
return replace(expr, densities=densities, kernel_arguments=kernel_arguments)

def map_interpolation(self, expr):
operand = self.rec(expr.operand)
Expand Down Expand Up @@ -261,10 +260,7 @@ def map_int_g(self, expr):
if not changed:
return expr

return expr.copy(
densities=densities,
kernel_arguments=kernel_arguments,
)
return replace(expr, densities=densities, kernel_arguments=kernel_arguments)

def map_common_subexpression(self, expr):
child = self.rec(expr.child)
Expand Down Expand Up @@ -522,9 +518,9 @@ def map_product(self, expr):

def map_int_g(self, expr):
from sumpy.kernel import AxisTargetDerivative
return expr.copy(
target_kernel=AxisTargetDerivative(
self.ambient_axis, expr.target_kernel))

target_kernel = AxisTargetDerivative(self.ambient_axis, expr.target_kernel)
return replace(expr, target_kernel=target_kernel)


class DerivativeSourceAndNablaComponentCollector(
Expand Down Expand Up @@ -570,15 +566,15 @@ def map_int_g(self, expr):
raise ValueError(
"Unregularized evaluation does not support one-sided limits")

expr = expr.copy(
qbx_forced_limit=None,
densities=self.rec(expr.densities),
kernel_arguments={
name: self.rec(arg_expr)
for name, arg_expr in expr.kernel_arguments.items()
})

return expr
return replace(
expr,
qbx_forced_limit=None,
densities=self.rec(expr.densities),
kernel_arguments={
name: self.rec(arg_expr)
for name, arg_expr in expr.kernel_arguments.items()
}
)

# }}}

Expand Down Expand Up @@ -626,7 +622,7 @@ def map_num_reference_derivative(self, expr):

def map_int_g(self, expr):
if expr.target.discr_stage is None:
expr = expr.copy(target=expr.target.to_stage1())
expr = replace(expr, target=expr.target.to_stage1())

if expr.source.discr_stage is not None:
return expr
Expand All @@ -638,16 +634,18 @@ def map_int_g(self, expr):

from_dd = expr.source.to_stage1()
to_dd = from_dd.to_quad_stage2()
densities = [prim.interp(from_dd, to_dd, self.rec(density)) for
density in expr.densities]
densities = tuple(
prim.interp(from_dd, to_dd, self.rec(density)) for
density in expr.densities)

from_dd = from_dd.copy(discr_stage=self.from_discr_stage)
kernel_arguments = {
name: prim.interp(from_dd, to_dd,
self.rec(self.tagger(arg_expr)))
for name, arg_expr in expr.kernel_arguments.items()}

return expr.copy(
return replace(
expr,
densities=densities,
kernel_arguments=kernel_arguments,
source=to_dd)
Expand Down Expand Up @@ -678,7 +676,8 @@ def map_int_g(self, expr):

is_self = source_discr is target_discr

expr = expr.copy(
expr = replace(
expr,
densities=self.rec(expr.densities),
kernel_arguments={
name: self.rec(arg_expr)
Expand Down Expand Up @@ -707,8 +706,8 @@ def map_int_g(self, expr):

if expr.qbx_forced_limit == "avg":
return 0.5*(
expr.copy(qbx_forced_limit=+1)
+ expr.copy(qbx_forced_limit=-1))
replace(expr, qbx_forced_limit=+1)
+ replace(expr, qbx_forced_limit=-1))
else:
return expr

Expand Down
Loading
Loading