From 58c74c83c4993542fe28c7b315cb99f177119548 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sun, 17 Oct 2021 10:55:07 -0500 Subject: [PATCH] do not force host transfers when computing norms --- meshmode/dof_array.py | 65 +++++++++++++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 18 deletions(-) diff --git a/meshmode/dof_array.py b/meshmode/dof_array.py index f6ba325ab..8c283c624 100644 --- a/meshmode/dof_array.py +++ b/meshmode/dof_array.py @@ -663,10 +663,32 @@ def _unflatten_from_numpy(subary): # {{{ flat_norm +def _reduce_norm(actx, arys, ord): + from numbers import Number + from functools import reduce + + # NOTE: these are ordered by an expected usage frequency + if ord == 2: + return actx.np.sqrt(sum(subary*subary for subary in arys)) + elif ord == np.inf: + return reduce(actx.np.maximum, arys) + elif ord == -np.inf: + return reduce(actx.np.minimum, arys) + elif isinstance(ord, Number) and ord > 0: + return sum(subary**ord for subary in arys)**(1/ord) + else: + raise NotImplementedError(f"unsupported value of 'ord': {ord}") + + def flat_norm(ary, ord=None) -> Any: r"""Return an element-wise :math:`\ell^{\text{ord}}` norm of *ary*. - :arg ary: may be a :class:`DOFArray` or a + Unlike :attr:`arraycontext.ArrayContext.np`, this function handles + :class:`DOFArray`\ s by taking a norm of their flattened values + (in the sense of :func:`flatten`) regardless of how the group arrays + are stored. + + :arg ary: may be a :class:`DOFArray` or an :class:`~arraycontext.ArrayContainer` containing them. """ @@ -678,24 +700,31 @@ def flat_norm(ary, ord=None) -> Any: ord = 2 from arraycontext import is_array_container + actx = None + + def _rec(_ary): + nonlocal actx + + if isinstance(_ary, DOFArray): + if actx is None: + actx = _ary.array_context + else: + assert actx is _ary.array_context + + assert actx is not None + + return _reduce_norm(actx, [ + actx.np.linalg.norm(actx.np.ravel(subary, order="A"), ord=ord) + for _, subary in serialize_container(_ary) + ], ord=ord) + + elif is_array_container(_ary): + arys = [_rec(subary) for _, subary in serialize_container(_ary)] + return _reduce_norm(actx, arys, ord=ord) + + raise TypeError(f"unsupported array type: '{type(_ary).__name__}'") - from arraycontext.fake_numpy import _scalar_list_norm - if isinstance(ary, DOFArray): - actx = ary.array_context - return _scalar_list_norm( - [ - actx.np.linalg.norm(actx.np.ravel(subary, order="A"), ord=ord) - for _, subary in serialize_container(ary)], - ord=ord) - - elif is_array_container(ary): - return _scalar_list_norm( - [flat_norm(subary, ord=ord) - for _, subary in serialize_container(ary)], - ord=ord) - - raise TypeError( - f"unsupported array type passed to flat_norm: '{type(ary).__name__}'") + return _rec(ary) # }}}