Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NF Operator class #1014

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions nibabel/arrayops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import operator

import numpy as np

from .orientations import aff2axcodes
# support_np_type = (
# np.int8,
# np.int64,
# np.float16,
# np.float32,
# np.float64,
# np.complex128)


class OperableImage:
def _binop(self, val, *, op):
"""Apply operator to Nifti1Image.

Arithmetic and logical operation on Nifti image.
Currently support: +, -, *, /, //, &, |
The nifit image should contain the same header information and affine.
Images should be the same shape.

Parameters
----------
op :
Python operator.
"""
affine, header = self.affine, self.header
self_, val_ = _input_validation(self, val)
# numerical operator should work work

if op.__name__ in ["add", "sub", "mul", "truediv", "floordiv"]:
dataobj = op(self_, val_)
if op.__name__ in ["and_", "or_"]:
self_ = self_.astype(bool)
val_ = val_.astype(bool)
dataobj = op(self_, val_).astype(int)
return self.__class__(dataobj, affine, header)

def _unop(self, *, op):
"""
Parameters
----------
op :
Python operator.
"""
# _type_check(self)
if op.__name__ in ["pos", "neg", "abs"]:
dataobj = op(np.asanyarray(self.dataobj))
return self.__class__(dataobj, self.affine, self.header)

def __add__(self, other):
return self._binop(other, op=operator.__add__)

def __sub__(self, other):
return self._binop(other, op=operator.__sub__)

def __mul__(self, other):
return self._binop(other, op=operator.__mul__)

def __truediv__(self, other):
return self._binop(other, op=operator.__truediv__)

def __floordiv__(self, other):
return self._binop(other, op=operator.__floordiv__)

def __and__(self, other):
return self._binop(other, op=operator.__and__)

def __or__(self, other):
return self._binop(other, op=operator.__or__)

def __pos__(self):
return self._unop(op=operator.__pos__)

def __neg__(self):
return self._unop(op=operator.__neg__)

def __abs__(self):
return self._unop(op=operator.__abs__)


def _input_validation(self, val):
"""Check images orientation, affine, and shape muti-images operation."""
# _type_check(self)
if isinstance(val, self.__class__):
# _type_check(val)
# Check orientations are the same
if aff2axcodes(self.affine) != aff2axcodes(val.affine):
raise ValueError("Two images should have the same orientation")
# Check affine
if (self.affine != val.affine).any():
raise ValueError("Two images should have the same affine.")

# Check shape.
if self.shape[:3] != val.shape[:3]:
raise ValueError("Two images should have the same shape except "
"the time dimension.")

# if 4th dim exist in a image,
# reshape the 3d image to ensure valid projection
ndims = (len(self.shape), len(val.shape))
if 4 not in ndims:
self_ = np.asanyarray(self.dataobj)
val_ = np.asanyarray(val.dataobj)
return self_, val_

reference = None
imgs = []
for ndim, img in zip(ndims, (self, val)):
img_ = np.asanyarray(img.dataobj)
if ndim == 3:
reference = tuple(list(img.shape) + [1])
img_ = np.reshape(img_, reference)
imgs.append(img_)
return imgs
else:
self_ = np.asanyarray(self.dataobj)
val_ = val
return self_, val_


# def _type_check(*args):
# """Ensure image contains correct nifti data type."""
# # Check types
# dtypes = [img.get_data_dtype().type for img in args]
# # check allowed dtype based on the operator
# if set(support_np_type).union(dtypes) == 0:
# raise ValueError("Image contains illegal datatype for Nifti1Image.")
3 changes: 2 additions & 1 deletion nibabel/nifti1.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .spm99analyze import SpmAnalyzeHeader
from .casting import have_binary128
from .pydicom_compat import have_dicom, pydicom as pdcm
from .arrayops import OperableImage

# nifti1 flat header definition for Analyze-like first 348 bytes
# first number in comments indicates offset in file header in bytes
Expand Down Expand Up @@ -2011,7 +2012,7 @@ def as_reoriented(self, ornt):
return img


class Nifti1Image(Nifti1Pair, SerializableImage):
class Nifti1Image(Nifti1Pair, SerializableImage, OperableImage):
""" Class for single file NIfTI1 format image
"""
header_class = Nifti1Header
Expand Down
94 changes: 94 additions & 0 deletions nibabel/tests/test_arrayops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import numpy as np
from .. import Nifti1Image
from numpy.testing import assert_array_equal
import pytest


def test_binary_operations():
data1 = np.random.rand(5, 5, 2)
data2 = np.random.rand(5, 5, 2)
data1[0, 0, :] = 0
img1 = Nifti1Image(data1, np.eye(4))
img2 = Nifti1Image(data2, np.eye(4))

output = img1 + 2
assert_array_equal(output.dataobj, data1 + 2)

output = img1 + img2
assert_array_equal(output.dataobj, data1 + data2)

output = img1 + img2 + img2
assert_array_equal(output.dataobj, data1 + data2 + data2)

output = img1 - img2
assert_array_equal(output.dataobj, data1 - data2)

output = img1 * img2
assert_array_equal(output.dataobj, data1 * data2)

output = img1 / img2
assert_array_equal(output.dataobj, data1 / data2)

output = img1 // img2
assert_array_equal(output.dataobj, data1 // data2)

output = img2 / img1
assert_array_equal(output.dataobj, data2 / data1)

output = img2 // img1
assert_array_equal(output.dataobj, data2 // data1)

output = img1 & img2
assert_array_equal(output.dataobj, (data1.astype(bool) & data2.astype(bool)).astype(int))

output = img1 | img2
assert_array_equal(output.dataobj, (data1.astype(bool) | data2.astype(bool)).astype(int))


def test_binary_operations_4d():
data1 = np.random.rand(5, 5, 2, 3)
data2 = np.random.rand(5, 5, 2)
img1 = Nifti1Image(data1, np.eye(4))
img2 = Nifti1Image(data2, np.eye(4))
data2_ = np.reshape(data2, (5, 5, 2, 1))

output = img1 * img2
assert_array_equal(output.dataobj, data1 * data2_)


def test_unary_operations():
data = np.random.rand(5, 5, 2) - 0.5
img = Nifti1Image(data, np.eye(4))

output = +img
assert_array_equal(output.dataobj, +data)

output = -img
assert_array_equal(output.dataobj, -data)

output = abs(img)
assert_array_equal(output.dataobj, abs(data))


def test_error_catching():
data1 = np.random.rand(5, 5, 1)
data2 = np.random.rand(5, 5, 2)
img1 = Nifti1Image(data1, np.eye(4))
img2 = Nifti1Image(data2, np.eye(4))
with pytest.raises(ValueError, match=r'should have the same shape'):
img1 + img2

data1 = np.random.rand(5, 5, 2)
data2 = np.random.rand(5, 5, 2)
img1 = Nifti1Image(data1, np.eye(4) * 2)
img2 = Nifti1Image(data2, np.eye(4))
with pytest.raises(ValueError, match=r'should have the same affine'):
img1 + img2

data = np.random.rand(5, 5, 2)
aff1 = [[0,1,0,10],[-1,0,0,20],[0,0,1,30],[0,0,0,1]]
aff2 = np.eye(4)
img1 = Nifti1Image(data, aff1)
img2 = Nifti1Image(data, aff2)
with pytest.raises(ValueError, match=r'should have the same orientation'):
img1 + img2