Skip to content

Commit b29b783

Browse files
committed
ADD unary operator
1 parent 9f46fba commit b29b783

File tree

2 files changed

+82
-30
lines changed

2 files changed

+82
-30
lines changed

nibabel/arrayops.py

Lines changed: 64 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
class OperableImage:
18-
def _op(self, other, op):
18+
def _binop(self, val, *, op):
1919
"""Apply operator to Nifti1Image.
2020
2121
Arithmetic and logical operation on Nifti image.
@@ -28,49 +28,85 @@ def _op(self, other, op):
2828
op :
2929
Python operator.
3030
"""
31-
# Check orientations are the same
32-
if aff2axcodes(self.affine) != aff2axcodes(other.affine):
33-
raise ValueError("Two images should have the same orientation")
34-
# Check affine
35-
if (self.affine != other.affine).all():
36-
raise ValueError("Two images should have the same affine.")
37-
# Check shape. Handle identical stuff for now.
38-
if self.shape != other.shape:
39-
raise ValueError("Two images should have the same shape.")
40-
41-
# Check types? Problematic types will be caught by numpy,
42-
# but might be cheaper to check before loading data.
43-
# collect dtype
44-
dtypes = [img.get_data_dtype().type for img in (self, other)]
45-
# check allowed dtype based on the operator
46-
if set(support_np_type).union(dtypes) == 0:
47-
raise ValueError("Image contains illegal datatype for Nifti1Image.")
48-
31+
val = _input_validation(self, val)
32+
# numerical operator should work work
4933
if op.__name__ in ["add", "sub", "mul", "truediv", "floordiv"]:
50-
dataobj = op(np.asanyarray(self.dataobj), np.asanyarray(other.dataobj))
34+
dataobj = op(np.asanyarray(self.dataobj), val)
5135
if op.__name__ in ["and_", "or_"]:
5236
self_ = self.dataobj.astype(bool)
53-
other_ = other.dataobj.astype(bool)
37+
other_ = val.astype(bool)
5438
dataobj = op(self_, other_).astype(int)
5539
return self.__class__(dataobj, self.affine, self.header)
5640

41+
42+
def _unop(self, *, op):
43+
"""
44+
Parameters
45+
----------
46+
op :
47+
Python operator.
48+
"""
49+
_type_check(self)
50+
if op.__name__ in ["pos", "neg", "abs"]:
51+
dataobj = op(np.asanyarray(self.dataobj))
52+
return self.__class__(dataobj, self.affine, self.header)
53+
54+
5755
def __add__(self, other):
58-
return self._op(other, operator.__add__)
56+
return self._binop(other, op=operator.__add__)
5957

6058
def __sub__(self, other):
61-
return self._op(other, operator.__sub__)
59+
return self._binop(other, op=operator.__sub__)
6260

6361
def __mul__(self, other):
64-
return self._op(other, operator.__mul__)
62+
return self._binop(other, op=operator.__mul__)
6563

6664
def __truediv__(self, other):
67-
return self._op(other, operator.__truediv__)
65+
return self._binop(other, op=operator.__truediv__)
6866

6967
def __floordiv__(self, other):
70-
return self._op(other, operator.__floordiv__)
68+
return self._binop(other, op=operator.__floordiv__)
7169

7270
def __and__(self, other):
73-
return self._op(other, operator.__and__)
71+
return self._binop(other, op=operator.__and__)
7472

7573
def __or__(self, other):
76-
return self._op(other, operator.__or__)
74+
return self._binop(other, op=operator.__or__)
75+
76+
def __pos__(self):
77+
return self._unop(op=operator.__pos__)
78+
79+
def __neg__(self):
80+
return self._unop(op=operator.__neg__)
81+
82+
def __abs__(self):
83+
return self._unop(op=operator.__abs__)
84+
85+
86+
87+
def _input_validation(self, val):
88+
"""Check images orientation, affine, and shape muti-images operation."""
89+
_type_check(self)
90+
if type(val) not in [float, int]:
91+
# Check orientations are the same
92+
if aff2axcodes(self.affine) != aff2axcodes(val.affine):
93+
raise ValueError("Two images should have the same orientation")
94+
# Check affine
95+
if (self.affine != val.affine).all():
96+
raise ValueError("Two images should have the same affine.")
97+
# Check shape.
98+
if self.shape[:3] != val.shape[:3]:
99+
raise ValueError("Two images should have the same shape except"
100+
"the time dimension.")
101+
102+
_type_check(val)
103+
val = np.asanyarray(val.dataobj)
104+
return val
105+
106+
def _type_check(*args):
107+
"""Ensure image contains correct nifti data type."""
108+
# Check types
109+
dtypes = [img.get_data_dtype().type for img in args]
110+
# check allowed dtype based on the operator
111+
if set(support_np_type).union(dtypes) == 0:
112+
raise ValueError("Image contains illegal datatype for Nifti1Image.")

nibabel/tests/test_arrayops.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,17 @@
33
from numpy.testing import assert_array_equal
44
import pytest
55

6-
def test_operations():
6+
7+
def test_binary_operations():
78
data1 = np.random.rand(5, 5, 2)
89
data2 = np.random.rand(5, 5, 2)
910
data1[0, 0, :] = 0
1011
img1 = Nifti1Image(data1, np.eye(4))
1112
img2 = Nifti1Image(data2, np.eye(4))
13+
14+
output = img1 + 2
15+
assert_array_equal(output.dataobj, data1 + 2)
16+
1217
output = img1 + img2
1318
assert_array_equal(output.dataobj, data1 + data2)
1419

@@ -37,4 +42,15 @@ def test_operations():
3742
assert_array_equal(output.dataobj, (data1.astype(bool) & data2.astype(bool)).astype(int))
3843

3944
output = img1 | img2
40-
assert_array_equal(output.dataobj, (data1.astype(bool) | data2.astype(bool)).astype(int))
45+
assert_array_equal(output.dataobj, (data1.astype(bool) | data2.astype(bool)).astype(int))
46+
47+
48+
def test_unary_operations():
49+
data = np.random.rand(5, 5, 2)
50+
img = Nifti1Image(data, np.eye(4))
51+
52+
output = -img
53+
assert_array_equal(output.dataobj, -data)
54+
55+
output = abs(img)
56+
assert_array_equal(output.dataobj, abs(data))

0 commit comments

Comments
 (0)