Skip to content

Commit

Permalink
mention arraycontext.flatten and unflatten in docs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Oct 20, 2021
1 parent 58c74c8 commit c119f9a
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions meshmode/dof_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@
from meshmode.transform_metadata import (
ConcurrentElementInameTag, ConcurrentDOFInameTag)
from arraycontext import (
ArrayContext, make_loopy_program,
ArrayContainer, with_container_arithmetic,
ArrayContext, make_loopy_program, with_container_arithmetic,
serialize_container, deserialize_container,
thaw as _thaw, freeze as _freeze,
rec_map_array_container, rec_multimap_array_container,
mapped_over_array_containers, multimapped_over_array_containers)
from arraycontext.container import ArrayOrContainerT

__doc__ = """
.. autoclass:: DOFArray
Expand Down Expand Up @@ -453,7 +453,7 @@ def _flatten(grp_ary):
return actx.np.concatenate([_flatten(grp_ary) for grp_ary in ary])


def flatten(ary: ArrayContainer, *, strict: bool = True) -> ArrayContainer:
def flatten(ary: ArrayOrContainerT, *, strict: bool = True) -> ArrayOrContainerT:
r"""Convert all :class:`DOFArray`\ s into a "flat" array of degrees of
freedom, where the resulting type of the array is given by the
:attr:`DOFArray.array_context`.
Expand All @@ -463,7 +463,10 @@ def flatten(ary: ArrayContainer, *, strict: bool = True) -> ArrayContainer:
index fastest.
Recurses into the :class:`~arraycontext.ArrayContainer` for all
:class:`DOFArray`\ s.
:class:`DOFArray`\ s and flattens them, but retains the
remaining structure of the container as is. For a more general method,
that produces a one-dimensional "flat" array of the entire container see
:func:`arraycontext.flatten`.
:arg strict: if *True*, only :class:`DOFArray`\ s are allowed as leaves
in the container *ary*. If *False*, any non-:class:`DOFArray` are
Expand Down Expand Up @@ -524,10 +527,10 @@ def _unflatten_group_sizes(discr, ndofs_per_element_per_group):


def unflatten(
actx: ArrayContext, discr, ary: ArrayContainer,
actx: ArrayContext, discr, ary: ArrayOrContainerT,
ndofs_per_element_per_group: Optional[Iterable[int]] = None, *,
strict: bool = True,
) -> ArrayContainer:
) -> ArrayOrContainerT:
r"""Convert all "flat" arrays returned by :func:`flatten` back to
:class:`DOFArray`\ s.
Expand All @@ -537,6 +540,9 @@ def unflatten(
of degrees of freedom, as given by `ndofs_per_element_per_group`
(or via *discr*).
Note that this method only restores flattened :class:`DOFArray`\ s. For
a more general version see :func:`arraycontext.unflatten`.
:arg ndofs_per_element: if given, an iterable of numbers representing
the number of degrees of freedom per element, overriding the numbers
provided by the element groups in *discr*. May be used (for example)
Expand All @@ -558,9 +564,9 @@ def _unflatten(subary):


def unflatten_like(
actx: ArrayContext, ary: ArrayContainer, prototype: ArrayContainer, *,
actx: ArrayContext, ary: ArrayOrContainerT, prototype: ArrayOrContainerT, *,
strict: bool = True,
) -> ArrayContainer:
) -> ArrayOrContainerT:
r"""Convert all "flat" arrays returned by :func:`flatten` back to
:class:`DOFArray`\ s based on a *prototype* container.
Expand All @@ -572,6 +578,8 @@ def unflatten_like(
:class:`~meshmode.discretization.Discretization`\ s within the same
container.
For a more general version, see :func:`arraycontext.unflatten`.
:arg prototype: an array container with the same structure as *ary*,
whose :class:`DOFArray` leaves are used to get the sizes to
unflatten *ary*.
Expand Down Expand Up @@ -615,10 +623,10 @@ def _unflatten_like(_ary, _prototype):
return _unflatten_like(ary, prototype)


def flatten_to_numpy(actx: ArrayContext, ary: ArrayContainer, *,
strict: bool = True) -> ArrayContainer:
def flatten_to_numpy(actx: ArrayContext, ary: ArrayOrContainerT, *,
strict: bool = True) -> ArrayOrContainerT:
r"""Converts all :class:`DOFArray`\ s into "flat" :class:`numpy.ndarray`\ s
using :func:`flatten`.
using :func:`flatten` and :meth:`arraycontext.ArrayContext.to_numpy`.
"""
def _flatten_to_numpy(subary):
if isinstance(subary, DOFArray) and subary.array_context is None:
Expand All @@ -630,12 +638,13 @@ def _flatten_to_numpy(subary):


def unflatten_from_numpy(
actx: ArrayContext, discr, ary: ArrayContainer,
actx: ArrayContext, discr, ary: ArrayOrContainerT,
ndofs_per_element_per_group: Optional[Iterable[int]] = None, *,
strict: bool = True,
) -> ArrayContainer:
) -> ArrayOrContainerT:
r"""Takes "flat" arrays returned by :func:`flatten_to_numpy` and
reconstructs the corresponding :class:`DOFArray`\ s using :func:`unflatten`.
reconstructs the corresponding :class:`DOFArray`\ s using :func:`unflatten`
and :meth:`arraycontext.ArrayContext.from_numpy`.
"""
group_shapes, group_starts = _unflatten_group_sizes(
discr, ndofs_per_element_per_group)
Expand Down

0 comments on commit c119f9a

Please sign in to comment.