Skip to content

Commit

Permalink
do not force host transfers when computing norms
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Oct 19, 2021
1 parent d9cd93c commit 58c74c8
Showing 1 changed file with 47 additions and 18 deletions.
65 changes: 47 additions & 18 deletions meshmode/dof_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,10 +663,32 @@ def _unflatten_from_numpy(subary):

# {{{ flat_norm

def _reduce_norm(actx, arys, ord):
from numbers import Number
from functools import reduce

# NOTE: these are ordered by an expected usage frequency
if ord == 2:
return actx.np.sqrt(sum(subary*subary for subary in arys))
elif ord == np.inf:
return reduce(actx.np.maximum, arys)
elif ord == -np.inf:
return reduce(actx.np.minimum, arys)
elif isinstance(ord, Number) and ord > 0:
return sum(subary**ord for subary in arys)**(1/ord)
else:
raise NotImplementedError(f"unsupported value of 'ord': {ord}")


def flat_norm(ary, ord=None) -> Any:
r"""Return an element-wise :math:`\ell^{\text{ord}}` norm of *ary*.
:arg ary: may be a :class:`DOFArray` or a
Unlike :attr:`arraycontext.ArrayContext.np`, this function handles
:class:`DOFArray`\ s by taking a norm of their flattened values
(in the sense of :func:`flatten`) regardless of how the group arrays
are stored.
:arg ary: may be a :class:`DOFArray` or an
:class:`~arraycontext.ArrayContainer` containing them.
"""

Expand All @@ -678,24 +700,31 @@ def flat_norm(ary, ord=None) -> Any:
ord = 2

from arraycontext import is_array_container
actx = None

def _rec(_ary):
nonlocal actx

if isinstance(_ary, DOFArray):
if actx is None:
actx = _ary.array_context
else:
assert actx is _ary.array_context

assert actx is not None

return _reduce_norm(actx, [
actx.np.linalg.norm(actx.np.ravel(subary, order="A"), ord=ord)
for _, subary in serialize_container(_ary)
], ord=ord)

elif is_array_container(_ary):
arys = [_rec(subary) for _, subary in serialize_container(_ary)]
return _reduce_norm(actx, arys, ord=ord)

raise TypeError(f"unsupported array type: '{type(_ary).__name__}'")

from arraycontext.fake_numpy import _scalar_list_norm
if isinstance(ary, DOFArray):
actx = ary.array_context
return _scalar_list_norm(
[
actx.np.linalg.norm(actx.np.ravel(subary, order="A"), ord=ord)
for _, subary in serialize_container(ary)],
ord=ord)

elif is_array_container(ary):
return _scalar_list_norm(
[flat_norm(subary, ord=ord)
for _, subary in serialize_container(ary)],
ord=ord)

raise TypeError(
f"unsupported array type passed to flat_norm: '{type(ary).__name__}'")
return _rec(ary)

# }}}

Expand Down

0 comments on commit 58c74c8

Please sign in to comment.