Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add axis tags for reshapes, etc., in direction connection code #379

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 49 additions & 30 deletions meshmode/discretization/connection/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
import loopy as lp
from meshmode.transform_metadata import (
ConcurrentElementInameTag, ConcurrentDOFInameTag,
DiscretizationElementAxisTag, DiscretizationDOFAxisTag)
DiscretizationElementAxisTag, DiscretizationDOFAxisTag,
DiscretizationDOFPickListAxisTag)
from pytools import memoize_in, keyed_memoize_method
from arraycontext import (
ArrayContext, ArrayT, ArrayOrContainerT, NotAnArrayContainerError,
Expand Down Expand Up @@ -166,12 +167,14 @@ def _global_from_element_indices(
np_full_from_element_indices[~np_from_el_present] = 0

from_el_present = actx.freeze(
actx.tag(NameHint("from_el_present"),
actx.from_numpy(
np_from_el_present.astype(np.int8))))
actx.tag_axis(0, DiscretizationElementAxisTag(),
actx.tag(NameHint("from_el_present"),
actx.from_numpy(
np_from_el_present.astype(np.int8)))))
full_from_element_indices = actx.freeze(
actx.tag(NameHint("from_el_indices"),
actx.from_numpy(np_full_from_element_indices)))
actx.tag_axis(0, DiscretizationElementAxisTag(),
actx.tag(NameHint("from_el_indices"),
actx.from_numpy(np_full_from_element_indices))))

self._global_from_element_indices_cache = (
from_el_present, full_from_element_indices)
Expand Down Expand Up @@ -553,17 +556,22 @@ def _per_target_group_pick_info(
_FromGroupPickData(
from_group_index=source_group_index,
dof_pick_lists=actx.freeze(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would suffice to just tag the element axis in the indirect access where this is used, below.

actx.tag(NameHint("dof_pick_lists"),
actx.from_numpy(dof_pick_lists))),
actx.tag_axis(0, DiscretizationDOFPickListAxisTag(),
actx.tag(NameHint("dof_pick_lists"),
actx.from_numpy(dof_pick_lists)))),
dof_pick_list_indices=actx.freeze(
actx.tag(NameHint("dof_pick_list_indices"),
actx.from_numpy(dof_pick_list_indices))),
actx.tag_axis(0, DiscretizationElementAxisTag(),
actx.tag(NameHint("dof_pick_list_indices"),
actx.from_numpy(dof_pick_list_indices)))),
from_el_present=actx.freeze(
actx.tag(NameHint("from_el_present"),
actx.from_numpy(from_el_present.astype(np.int8)))),
actx.tag_axis(0, DiscretizationElementAxisTag(),
actx.tag(NameHint("from_el_present"),
actx.from_numpy(
from_el_present.astype(np.int8))))),
from_element_indices=actx.freeze(
actx.tag(NameHint("from_el_indices"),
actx.from_numpy(from_el_indices))),
actx.tag_axis(0, DiscretizationElementAxisTag(),
actx.tag(NameHint("from_el_indices"),
actx.from_numpy(from_el_indices)))),
is_surjective=from_el_present.all()
))

Expand Down Expand Up @@ -732,25 +740,27 @@ def group_pick_knl(is_surjective: bool):
group_pick_info = None

if group_pick_info is not None:
group_array_contributions = []

if actx.permits_advanced_indexing and not _force_use_loopy:
for fgpd in group_pick_info:
from_element_indices = actx.thaw(fgpd.from_element_indices)

if ary[fgpd.from_group_index].size:
grp_ary_contrib = ary[fgpd.from_group_index][
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_element_indices, (-1, 1)),
actx.thaw(fgpd.dof_pick_lists)[
actx.thaw(fgpd.dof_pick_list_indices)]
]
actx, from_element_indices, (-1, 1))),
actx.thaw(fgpd.dof_pick_lists)[
actx.thaw(fgpd.dof_pick_list_indices)]
]

if not fgpd.is_surjective:
from_el_present = actx.thaw(fgpd.from_el_present)
grp_ary_contrib = actx.np.where(
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1))),
grp_ary_contrib,
0)

Expand Down Expand Up @@ -800,8 +810,10 @@ def group_pick_knl(is_surjective: bool):
mat = self._resample_matrix(actx, i_tgrp, i_batch)
if actx.permits_advanced_indexing and not _force_use_loopy:
batch_result = actx.np.where(
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1))),
actx.einsum("ij,ej->ei",
mat, grp_ary[from_element_indices]),
0)
Expand All @@ -822,11 +834,15 @@ def group_pick_knl(is_surjective: bool):

if actx.permits_advanced_indexing and not _force_use_loopy:
batch_result = actx.np.where(
_reshape_and_preserve_tags(
actx, from_el_present, (-1, 1)),
from_vec[
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_element_indices, (-1, 1)),
actx, from_el_present, (-1, 1))),
from_vec[
tag_axes(actx, {
1: DiscretizationDOFAxisTag()},
_reshape_and_preserve_tags(
actx, from_element_indices, (-1, 1))),
pick_list],
0)
else:
Expand All @@ -853,10 +869,13 @@ def group_pick_knl(is_surjective: bool):
else:
# If no batched data at all, return zeros for this
# particular group array
group_array = actx.zeros(
group_array = tag_axes(actx, {
0: DiscretizationElementAxisTag(),
1: DiscretizationDOFAxisTag()},
actx.zeros(
shape=(self.to_discr.groups[i_tgrp].nelements,
self.to_discr.groups[i_tgrp].nunit_dofs),
dtype=ary.entry_dtype)
dtype=ary.entry_dtype))

group_arrays.append(group_array)

Expand Down
10 changes: 10 additions & 0 deletions meshmode/transform_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
.. autoclass:: DiscretizationDOFAxisTag
.. autoclass:: DiscretizationAmbientDimAxisTag
.. autoclass:: DiscretizationTopologicalDimAxisTag
.. autoclass:: DiscretizationDOFPickListAxisTag
"""

__copyright__ = """
Expand Down Expand Up @@ -121,3 +122,12 @@ class DiscretizationTopologicalDimAxisTag(DiscretizationDimAxisTag):
Array dimensions tagged with this tag type describe an axis indexing over
the discretization's physical coordinate dimensions.
"""


@tag_dataclass
class DiscretizationDOFPickListAxisTag(DiscretizationEntityAxisTag):
"""
Array dimensions tagged with this tag type describe an axis indexing over
DOF pick lists in
:class:`meshmode.discretization.connection.DirectDiscretizationConnection`.
"""