Skip to content

Commit

Permalink
Merge pull request #384 from jcapriot/random_warnings
Browse files Browse the repository at this point in the history
Warn for non-repeatable random tests in a testing environment
  • Loading branch information
jcapriot authored Nov 5, 2024
2 parents 636e3ff + 58ff79e commit 93d48ec
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
25 changes: 25 additions & 0 deletions discretize/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
assert_isadjoint
""" # NOQA D205

import warnings

import numpy as np
import scipy.sparse as sp

Expand Down Expand Up @@ -81,6 +83,23 @@
_happiness_rng = np.random.default_rng()


def _warn_random_test():
stack = inspect.stack()
in_pytest = any(x[0].f_globals["__name__"].startswith("_pytest.") for x in stack)
in_nosetest = any(x[0].f_globals["__name__"].startswith("nose.") for x in stack)

if in_pytest or in_nosetest:
test = "pytest" if in_pytest else "nosetest"
warnings.warn(
f"You are running a {test} without setting a random seed, the results might not be"
"repeatable. For repeatable tests please pass an argument to `random seed` that is"
"not `None`.",
UserWarning,
stacklevel=3,
)
return in_pytest or in_nosetest


def setup_mesh(mesh_type, nC, nDim, random_seed=None):
"""Generate arbitrary mesh for testing.
Expand Down Expand Up @@ -110,6 +129,8 @@ def setup_mesh(mesh_type, nC, nDim, random_seed=None):
A discretize mesh of class specified by the input argument *mesh_type*
"""
if "random" in mesh_type:
if random_seed is None:
_warn_random_test()
rng = np.random.default_rng(random_seed)
if "TensorMesh" in mesh_type:
if "uniform" in mesh_type:
Expand Down Expand Up @@ -649,6 +670,8 @@ def check_derivative(
x0 = mkvc(x0)

if dx is None:
if random_seed is None:
_warn_random_test()
rng = np.random.default_rng(random_seed)
dx = rng.standard_normal(len(x0))

Expand Down Expand Up @@ -867,6 +890,8 @@ def assert_isadjoint(
"""
__tracebackhide__ = True

if random_seed is None:
_warn_random_test()
rng = np.random.default_rng(random_seed)

def random(size, iscomplex):
Expand Down
27 changes: 26 additions & 1 deletion tests/base/test_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import subprocess
import numpy as np
import scipy.sparse as sp
from discretize.tests import assert_isadjoint, check_derivative, assert_expected_order
from discretize.tests import (
assert_isadjoint,
check_derivative,
assert_expected_order,
_warn_random_test,
setup_mesh,
)


class TestAssertIsAdjoint:
Expand Down Expand Up @@ -166,3 +172,22 @@ def test_import_time():

# Currently we check t < 1.0s.
assert float(out.stderr.decode("utf-8")[:-1]) < 1.0


def test_random_test_warning():

match = r"You are running a pytest without setting a random seed.*"
with pytest.warns(UserWarning, match=match):
_warn_random_test()

def simple_deriv(x):
return np.sin(x), lambda y: np.cos(x) * y

with pytest.warns(UserWarning, match=match):
check_derivative(simple_deriv, np.zeros(10), plotIt=False)

with pytest.warns(UserWarning, match=match):
setup_mesh("randomTensorMesh", 10, 1)

with pytest.warns(UserWarning, match=match):
assert_isadjoint(lambda x: x, lambda x: x, 5, 5)

0 comments on commit 93d48ec

Please sign in to comment.