Skip to content

Commit 9ded76a

Browse files
committed
EHN allow 3D image projects along the 4th dimension
1 parent b29b783 commit 9ded76a

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

nibabel/arrayops.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,17 @@ def _binop(self, val, *, op):
2828
op :
2929
Python operator.
3030
"""
31-
val = _input_validation(self, val)
31+
affine, header = self.affine, self.header
32+
self_, val_ = _input_validation(self, val)
3233
# numerical operator should work work
34+
3335
if op.__name__ in ["add", "sub", "mul", "truediv", "floordiv"]:
34-
dataobj = op(np.asanyarray(self.dataobj), val)
36+
dataobj = op(self_, val_)
3537
if op.__name__ in ["and_", "or_"]:
36-
self_ = self.dataobj.astype(bool)
37-
other_ = val.astype(bool)
38-
dataobj = op(self_, other_).astype(int)
39-
return self.__class__(dataobj, self.affine, self.header)
38+
self_ = self_.astype(bool)
39+
val_ = val_.astype(bool)
40+
dataobj = op(self_, val_).astype(int)
41+
return self.__class__(dataobj, affine, header)
4042

4143

4244
def _unop(self, *, op):
@@ -87,7 +89,8 @@ def __abs__(self):
8789
def _input_validation(self, val):
8890
"""Check images orientation, affine, and shape muti-images operation."""
8991
_type_check(self)
90-
if type(val) not in [float, int]:
92+
if isinstance(val, self.__class__):
93+
_type_check(val)
9194
# Check orientations are the same
9295
if aff2axcodes(self.affine) != aff2axcodes(val.affine):
9396
raise ValueError("Two images should have the same orientation")
@@ -99,9 +102,27 @@ def _input_validation(self, val):
99102
raise ValueError("Two images should have the same shape except"
100103
"the time dimension.")
101104

102-
_type_check(val)
103-
val = np.asanyarray(val.dataobj)
104-
return val
105+
# if 4th dim exist in a image,
106+
# reshape the 3d image to ensure valid projection
107+
ndims = (len(self.shape), len(val.shape))
108+
if 4 not in ndims:
109+
self_ = np.asanyarray(self.dataobj)
110+
val_ = np.asanyarray(val.dataobj)
111+
return self_, val_
112+
113+
reference = None
114+
imgs = []
115+
for ndim, img in zip(ndims, (self, val)):
116+
img_ = np.asanyarray(img.dataobj)
117+
if ndim == 3:
118+
reference = tuple(list(img.shape) + [1])
119+
img_ = np.reshape(img_, reference)
120+
imgs.append(img_)
121+
return imgs
122+
else:
123+
self_ = np.asanyarray(self.dataobj)
124+
val_ = val
125+
return self_, val_
105126

106127
def _type_check(*args):
107128
"""Ensure image contains correct nifti data type."""

nibabel/tests/test_arrayops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@ def test_binary_operations():
4545
assert_array_equal(output.dataobj, (data1.astype(bool) | data2.astype(bool)).astype(int))
4646

4747

48+
def test_binary_operations_4d():
49+
data1 = np.random.rand(5, 5, 2, 3)
50+
data2 = np.random.rand(5, 5, 2)
51+
img1 = Nifti1Image(data1, np.eye(4))
52+
img2 = Nifti1Image(data2, np.eye(4))
53+
data2_ = np.reshape(data2, (5, 5, 2, 1))
54+
output = img1 * img2
55+
assert_array_equal(output.dataobj, data1 * data2_)
56+
57+
4858
def test_unary_operations():
4959
data = np.random.rand(5, 5, 2)
5060
img = Nifti1Image(data, np.eye(4))

0 commit comments

Comments
 (0)