15
15
16
16
17
17
class OperableImage :
18
- def _op (self , other , op ):
18
+ def _binop (self , val , * , op ):
19
19
"""Apply operator to Nifti1Image.
20
20
21
21
Arithmetic and logical operation on Nifti image.
@@ -28,49 +28,85 @@ def _op(self, other, op):
28
28
op :
29
29
Python operator.
30
30
"""
31
- # Check orientations are the same
32
- if aff2axcodes (self .affine ) != aff2axcodes (other .affine ):
33
- raise ValueError ("Two images should have the same orientation" )
34
- # Check affine
35
- if (self .affine != other .affine ).all ():
36
- raise ValueError ("Two images should have the same affine." )
37
- # Check shape. Handle identical stuff for now.
38
- if self .shape != other .shape :
39
- raise ValueError ("Two images should have the same shape." )
40
-
41
- # Check types? Problematic types will be caught by numpy,
42
- # but might be cheaper to check before loading data.
43
- # collect dtype
44
- dtypes = [img .get_data_dtype ().type for img in (self , other )]
45
- # check allowed dtype based on the operator
46
- if set (support_np_type ).union (dtypes ) == 0 :
47
- raise ValueError ("Image contains illegal datatype for Nifti1Image." )
48
-
31
+ val = _input_validation (self , val )
32
+ # numerical operator should work work
49
33
if op .__name__ in ["add" , "sub" , "mul" , "truediv" , "floordiv" ]:
50
- dataobj = op (np .asanyarray (self .dataobj ), np . asanyarray ( other . dataobj ) )
34
+ dataobj = op (np .asanyarray (self .dataobj ), val )
51
35
if op .__name__ in ["and_" , "or_" ]:
52
36
self_ = self .dataobj .astype (bool )
53
- other_ = other . dataobj .astype (bool )
37
+ other_ = val .astype (bool )
54
38
dataobj = op (self_ , other_ ).astype (int )
55
39
return self .__class__ (dataobj , self .affine , self .header )
56
40
41
+
42
+ def _unop (self , * , op ):
43
+ """
44
+ Parameters
45
+ ----------
46
+ op :
47
+ Python operator.
48
+ """
49
+ _type_check (self )
50
+ if op .__name__ in ["pos" , "neg" , "abs" ]:
51
+ dataobj = op (np .asanyarray (self .dataobj ))
52
+ return self .__class__ (dataobj , self .affine , self .header )
53
+
54
+
57
55
def __add__ (self , other ):
58
- return self ._op (other , operator .__add__ )
56
+ return self ._binop (other , op = operator .__add__ )
59
57
60
58
def __sub__ (self , other ):
61
- return self ._op (other , operator .__sub__ )
59
+ return self ._binop (other , op = operator .__sub__ )
62
60
63
61
def __mul__ (self , other ):
64
- return self ._op (other , operator .__mul__ )
62
+ return self ._binop (other , op = operator .__mul__ )
65
63
66
64
def __truediv__ (self , other ):
67
- return self ._op (other , operator .__truediv__ )
65
+ return self ._binop (other , op = operator .__truediv__ )
68
66
69
67
def __floordiv__ (self , other ):
70
- return self ._op (other , operator .__floordiv__ )
68
+ return self ._binop (other , op = operator .__floordiv__ )
71
69
72
70
def __and__ (self , other ):
73
- return self ._op (other , operator .__and__ )
71
+ return self ._binop (other , op = operator .__and__ )
74
72
75
73
def __or__ (self , other ):
76
- return self ._op (other , operator .__or__ )
74
+ return self ._binop (other , op = operator .__or__ )
75
+
76
+ def __pos__ (self ):
77
+ return self ._unop (op = operator .__pos__ )
78
+
79
+ def __neg__ (self ):
80
+ return self ._unop (op = operator .__neg__ )
81
+
82
+ def __abs__ (self ):
83
+ return self ._unop (op = operator .__abs__ )
84
+
85
+
86
+
87
+ def _input_validation (self , val ):
88
+ """Check images orientation, affine, and shape muti-images operation."""
89
+ _type_check (self )
90
+ if type (val ) not in [float , int ]:
91
+ # Check orientations are the same
92
+ if aff2axcodes (self .affine ) != aff2axcodes (val .affine ):
93
+ raise ValueError ("Two images should have the same orientation" )
94
+ # Check affine
95
+ if (self .affine != val .affine ).all ():
96
+ raise ValueError ("Two images should have the same affine." )
97
+ # Check shape.
98
+ if self .shape [:3 ] != val .shape [:3 ]:
99
+ raise ValueError ("Two images should have the same shape except"
100
+ "the time dimension." )
101
+
102
+ _type_check (val )
103
+ val = np .asanyarray (val .dataobj )
104
+ return val
105
+
106
+ def _type_check (* args ):
107
+ """Ensure image contains correct nifti data type."""
108
+ # Check types
109
+ dtypes = [img .get_data_dtype ().type for img in args ]
110
+ # check allowed dtype based on the operator
111
+ if set (support_np_type ).union (dtypes ) == 0 :
112
+ raise ValueError ("Image contains illegal datatype for Nifti1Image." )
0 commit comments