diff --git a/nibabel/tests/test_utils.py b/nibabel/tests/test_utils.py new file mode 100644 index 0000000000..f6f11a469b --- /dev/null +++ b/nibabel/tests/test_utils.py @@ -0,0 +1,22 @@ +""" Test for utils module +""" + +import numpy as np + +from nibabel.utils import to_scalar + +from nose.tools import assert_equal, assert_true, assert_false, assert_raises + + +def test_to_scalar(): + for pass_thru in (2, 2.3, 'foo', b'foo', [], (), [2], (2,), object()): + assert_true(to_scalar(pass_thru) is pass_thru) + for arr_contents in (2, 2.3, 'foo', b'foo'): + arr = np.array(arr_contents) + out = to_scalar(arr) + assert_false(to_scalar(arr) is arr) + assert_equal(out, arr_contents) + # Promote to 1 and 2D and check contents + assert_equal(to_scalar(np.atleast_1d(arr)), arr_contents) + assert_equal(to_scalar(np.atleast_2d(arr)), arr_contents) + assert_raises(ValueError, to_scalar, np.array([1, 2])) diff --git a/nibabel/utils.py b/nibabel/utils.py new file mode 100644 index 0000000000..b5a5f4fb82 --- /dev/null +++ b/nibabel/utils.py @@ -0,0 +1,20 @@ +""" Code support routines, not otherwise classified +""" + +def to_scalar(val): + """ Return scalar representation of `val` + + Return scalar value from numpy array, or pass through value if not numpy + array. + + Parameters + ---------- + val : object + numpy array or other object. + + Returns + ------- + out : object + Result of ``val.item()`` if `val` has an ``item`` method, else `val`. + """ + return val.item() if hasattr(val, 'item') else val