@@ -28,15 +28,17 @@ def _binop(self, val, *, op):
28
28
op :
29
29
Python operator.
30
30
"""
31
- val = _input_validation (self , val )
31
+ affine , header = self .affine , self .header
32
+ self_ , val_ = _input_validation (self , val )
32
33
# numerical operator should work work
34
+
33
35
if op .__name__ in ["add" , "sub" , "mul" , "truediv" , "floordiv" ]:
34
- dataobj = op (np . asanyarray ( self . dataobj ), val )
36
+ dataobj = op (self_ , val_ )
35
37
if op .__name__ in ["and_" , "or_" ]:
36
- self_ = self . dataobj .astype (bool )
37
- other_ = val .astype (bool )
38
- dataobj = op (self_ , other_ ).astype (int )
39
- return self .__class__ (dataobj , self . affine , self . header )
38
+ self_ = self_ .astype (bool )
39
+ val_ = val_ .astype (bool )
40
+ dataobj = op (self_ , val_ ).astype (int )
41
+ return self .__class__ (dataobj , affine , header )
40
42
41
43
42
44
def _unop (self , * , op ):
@@ -87,7 +89,8 @@ def __abs__(self):
87
89
def _input_validation (self , val ):
88
90
"""Check images orientation, affine, and shape muti-images operation."""
89
91
_type_check (self )
90
- if type (val ) not in [float , int ]:
92
+ if isinstance (val , self .__class__ ):
93
+ _type_check (val )
91
94
# Check orientations are the same
92
95
if aff2axcodes (self .affine ) != aff2axcodes (val .affine ):
93
96
raise ValueError ("Two images should have the same orientation" )
@@ -99,9 +102,27 @@ def _input_validation(self, val):
99
102
raise ValueError ("Two images should have the same shape except"
100
103
"the time dimension." )
101
104
102
- _type_check (val )
103
- val = np .asanyarray (val .dataobj )
104
- return val
105
+ # if 4th dim exist in a image,
106
+ # reshape the 3d image to ensure valid projection
107
+ ndims = (len (self .shape ), len (val .shape ))
108
+ if 4 not in ndims :
109
+ self_ = np .asanyarray (self .dataobj )
110
+ val_ = np .asanyarray (val .dataobj )
111
+ return self_ , val_
112
+
113
+ reference = None
114
+ imgs = []
115
+ for ndim , img in zip (ndims , (self , val )):
116
+ img_ = np .asanyarray (img .dataobj )
117
+ if ndim == 3 :
118
+ reference = tuple (list (img .shape ) + [1 ])
119
+ img_ = np .reshape (img_ , reference )
120
+ imgs .append (img_ )
121
+ return imgs
122
+ else :
123
+ self_ = np .asanyarray (self .dataobj )
124
+ val_ = val
125
+ return self_ , val_
105
126
106
127
def _type_check (* args ):
107
128
"""Ensure image contains correct nifti data type."""
0 commit comments