16
16
'assert_integral' , 'isodd' , 'assert_isinf' , 'same_sign' ,
17
17
'assert_same_sign' ]
18
18
19
- def zero (dtype ):
19
+ def zero (shape , dtype ):
20
20
"""
21
21
Returns a scalar 0 of the given dtype.
22
22
@@ -28,9 +28,9 @@ def zero(dtype):
28
28
To get -0, use -zero(dtype) (note that -0 is only defined for floating
29
29
point dtypes).
30
30
"""
31
- return zeros (() , dtype = dtype )
31
+ return zeros (shape , dtype = dtype )
32
32
33
- def one (dtype ):
33
+ def one (shape , dtype ):
34
34
"""
35
35
Returns a scalar 1 of the given dtype.
36
36
@@ -41,19 +41,19 @@ def one(dtype):
41
41
42
42
To get -1, use -one(dtype).
43
43
"""
44
- return ones (() , dtype = dtype )
44
+ return ones (shape , dtype = dtype )
45
45
46
- def NaN (dtype ):
46
+ def NaN (shape , dtype ):
47
47
"""
48
48
Returns a scalar nan of the given dtype.
49
49
50
50
Note that this is only defined for floating point dtypes.
51
51
"""
52
52
if dtype not in [float32 , float64 ]:
53
- raise RuntimeError (f"Unexpected dtype { dtype } in nan ()." )
54
- return full (() , nan , dtype = dtype )
53
+ raise RuntimeError (f"Unexpected dtype { dtype } in NaN ()." )
54
+ return full (shape , nan , dtype = dtype )
55
55
56
- def infinity (dtype ):
56
+ def infinity (shape , dtype ):
57
57
"""
58
58
Returns a scalar positive infinity of the given dtype.
59
59
@@ -64,9 +64,9 @@ def infinity(dtype):
64
64
"""
65
65
if dtype not in [float32 , float64 ]:
66
66
raise RuntimeError (f"Unexpected dtype { dtype } in infinity()." )
67
- return full (() , inf , dtype = dtype )
67
+ return full (shape , inf , dtype = dtype )
68
68
69
- def π (dtype ):
69
+ def π (shape , dtype ):
70
70
"""
71
71
Returns a scalar π.
72
72
@@ -76,24 +76,26 @@ def π(dtype):
76
76
77
77
"""
78
78
if dtype not in [float32 , float64 ]:
79
- raise RuntimeError (f"Unexpected dtype { dtype } in infinity ()." )
80
- return full (() , pi , dtype = dtype )
79
+ raise RuntimeError (f"Unexpected dtype { dtype } in π ()." )
80
+ return full (shape , pi , dtype = dtype )
81
81
82
82
def isnegzero (x ):
83
83
"""
84
84
Returns a mask where x is -0.
85
85
"""
86
86
# TODO: If copysign or signbit are added to the spec, use those instead.
87
+ shape = x .shape
87
88
dtype = x .dtype
88
- return equal (divide (one (dtype ), x ), - infinity (dtype ))
89
+ return equal (divide (one (shape , dtype ), x ), - infinity (shape , dtype ))
89
90
90
91
def isposzero (x ):
91
92
"""
92
93
Returns a mask where x is +0 (but not -0).
93
94
"""
94
95
# TODO: If copysign or signbit are added to the spec, use those instead.
96
+ shape = x .shape
95
97
dtype = x .dtype
96
- return equal (divide (one (dtype ), x ), infinity (dtype ))
98
+ return equal (divide (one (shape , dtype ), x ), infinity (shape , dtype ))
97
99
98
100
def exactly_equal (x , y ):
99
101
"""
@@ -147,13 +149,13 @@ def assert_nonzero(x):
147
149
assert all (nonzero (x )), "The input array is not nonzero"
148
150
149
151
def ispositive (x ):
150
- return greater (x , zero (x .dtype ))
152
+ return greater (x , zero (x .shape , x . dtype ))
151
153
152
154
def assert_positive (x ):
153
155
assert all (ispositive (x )), "The input array is not positive"
154
156
155
157
def isnegative (x ):
156
- return less (x , zero (x .dtype ))
158
+ return less (x , zero (x .shape , x . dtype ))
157
159
158
160
def assert_negative (x ):
159
161
assert all (isnegative (x )), "The input array is not negative"
@@ -168,7 +170,7 @@ def isintegral(x):
168
170
if x .dtype in [int8 , int16 , int32 , int64 , uint8 , uint16 , uint32 , uint64 ]:
169
171
return full (x .shape , True , dtype = bool )
170
172
elif x .dtype in [float32 , float64 ]:
171
- return equal (remainder (x , one (x .dtype )), zero (x .dtype ))
173
+ return equal (remainder (x , one (x .shape , x . dtype )), zero (x . shape , x .dtype ))
172
174
else :
173
175
return full (x .shape , False , dtype = bool )
174
176
@@ -179,7 +181,11 @@ def assert_integral(x):
179
181
assert all (isintegral (x )), "The input array has nonintegral values"
180
182
181
183
def isodd (x ):
182
- return logical_and (isintegral (x ), equal (remainder (x , 2 * one (x .dtype )), one (x .dtype )))
184
+ return logical_and (
185
+ isintegral (x ),
186
+ equal (
187
+ remainder (x , 2 * one (x .shape , x .dtype )),
188
+ one (x .shape , x .dtype )))
183
189
184
190
def assert_isinf (x ):
185
191
"""
0 commit comments