Skip to content

Commit

Permalink
check_range_local: report failing values (#641)
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener authored Apr 21, 2022
1 parent 97d922b commit 08e439f
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions mirgecom/simutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@

from meshmode.dof_array import DOFArray

from typing import List
from grudge.discretization import DiscretizationCollection


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -271,17 +274,26 @@ def allsync(local_values, comm=None, op=None):
return global_reduce(local_values, op_string, comm=comm)


def check_range_local(discr, dd, field, min_value, max_value):
"""Check for any negative values."""
def check_range_local(discr: DiscretizationCollection, dd: str, field: DOFArray,
min_value: float, max_value: float) -> List[float]:
"""Return the values that are outside the range [min_value, max_value]."""
actx = field.array_context
return (
actx.to_numpy(op.nodal_min_loc(discr, dd, field)) < min_value
or actx.to_numpy(op.nodal_max_loc(discr, dd, field)) > max_value
)
local_min = np.asscalar(actx.to_numpy(op.nodal_min_loc(discr, dd, field)))
local_max = np.asscalar(actx.to_numpy(op.nodal_max_loc(discr, dd, field)))

failing_values = []

if local_min < min_value:
failing_values.append(local_min)
if local_max > max_value:
failing_values.append(local_max)

return failing_values


def check_naninf_local(discr, dd, field):
"""Check for any NANs or Infs in the field."""
def check_naninf_local(discr: DiscretizationCollection, dd: str,
field: DOFArray) -> bool:
"""Return True if there are any NaNs or Infs in the field."""
actx = field.array_context
s = actx.to_numpy(op.nodal_sum_loc(discr, dd, field))
return not np.isfinite(s)
Expand Down

0 comments on commit 08e439f

Please sign in to comment.