diff --git a/grudge/array_context.py b/grudge/array_context.py index a3c92b40..2cd265b9 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -35,6 +35,7 @@ TYPE_CHECKING, Mapping, Tuple, Any, Callable, Optional, Type, FrozenSet) from dataclasses import dataclass +from pytools import to_identifier from pytools.tag import Tag from meshmode.array_context import ( PyOpenCLArrayContext as _PyOpenCLArrayContextBase, @@ -162,11 +163,6 @@ class MPIBasedArrayContext: # {{{ distributed + pytato -def _to_identifier(s: str) -> str: - # Only allow digits, letters, and underscores in identifiers - return "".join(ch for ch in s if ch.isalnum() or ch == "_") - - @dataclass(frozen=True) class _DistributedPartProgramID: f: Callable[..., Any] @@ -175,9 +171,9 @@ class _DistributedPartProgramID: def __str__(self): name = getattr(self.f, "__name__", "anonymous") if not name.isidentifier(): - name = _to_identifier(name) + name = to_identifier(name) - part = _to_identifier(str(self.part_id)) + part = to_identifier(str(self.part_id)) if part: return f"{name}_part{part}" else: diff --git a/grudge/dof_desc.py b/grudge/dof_desc.py index 8d5089c2..10c7e2e3 100644 --- a/grudge/dof_desc.py +++ b/grudge/dof_desc.py @@ -85,16 +85,7 @@ from meshmode.mesh import ( BTAG_PARTITION, BTAG_ALL, BTAG_REALLY_ALL, BTAG_NONE, BoundaryTag) - -# {{{ _to_identifier - -def _to_identifier(name: str) -> str: - if not name.isidentifier(): - return "".join(ch for ch in name if ch.isidentifier()) - else: - return name - -# }}} +from pytools import to_identifier # {{{ volume tags @@ -357,24 +348,24 @@ def as_identifier(self) -> str: if isinstance(vtag, type): vtag = vtag.__name__.replace("VTAG_", "").lower() elif isinstance(vtag, str): - vtag = _to_identifier(vtag) + vtag = to_identifier(vtag) else: - vtag = _to_identifier(str(vtag)) + vtag = to_identifier(str(vtag)) dom_id = f"v_{vtag}" elif isinstance(self.domain_tag, BoundaryDomainTag): btag = self.domain_tag.tag if isinstance(btag, type): btag = btag.__name__.replace("BTAG_", "").lower() elif isinstance(btag, str): - btag = _to_identifier(btag) + btag = to_identifier(btag) else: - btag = _to_identifier(str(btag)) + btag = to_identifier(str(btag)) dom_id = f"b_{btag}" else: raise ValueError(f"unexpected domain tag: '{self.domain_tag}'") if isinstance(self.discretization_tag, str): - discr_id = _to_identifier(self.discretization_tag) + discr_id = to_identifier(self.discretization_tag) elif issubclass(self.discretization_tag, DISCR_TAG_QUAD): discr_id = "_quad" elif self.discretization_tag is DISCR_TAG_BASE: