Skip to content

Commit 50a768b

Browse files
committed
Pass in the shape to the array helper creation functions
Using a scalar works for broadcasting, but it breaks masking.
1 parent 35f8906 commit 50a768b

28 files changed

+394
-388
lines changed

array_api_tests/array_helpers.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
'assert_integral', 'isodd', 'assert_isinf', 'same_sign',
1717
'assert_same_sign']
1818

19-
def zero(dtype):
19+
def zero(shape, dtype):
2020
"""
2121
Returns a scalar 0 of the given dtype.
2222
@@ -28,9 +28,9 @@ def zero(dtype):
2828
To get -0, use -zero(dtype) (note that -0 is only defined for floating
2929
point dtypes).
3030
"""
31-
return zeros((), dtype=dtype)
31+
return zeros(shape, dtype=dtype)
3232

33-
def one(dtype):
33+
def one(shape, dtype):
3434
"""
3535
Returns a scalar 1 of the given dtype.
3636
@@ -41,19 +41,19 @@ def one(dtype):
4141
4242
To get -1, use -one(dtype).
4343
"""
44-
return ones((), dtype=dtype)
44+
return ones(shape, dtype=dtype)
4545

46-
def NaN(dtype):
46+
def NaN(shape, dtype):
4747
"""
4848
Returns a scalar nan of the given dtype.
4949
5050
Note that this is only defined for floating point dtypes.
5151
"""
5252
if dtype not in [float32, float64]:
53-
raise RuntimeError(f"Unexpected dtype {dtype} in nan().")
54-
return full((), nan, dtype=dtype)
53+
raise RuntimeError(f"Unexpected dtype {dtype} in NaN().")
54+
return full(shape, nan, dtype=dtype)
5555

56-
def infinity(dtype):
56+
def infinity(shape, dtype):
5757
"""
5858
Returns a scalar positive infinity of the given dtype.
5959
@@ -64,9 +64,9 @@ def infinity(dtype):
6464
"""
6565
if dtype not in [float32, float64]:
6666
raise RuntimeError(f"Unexpected dtype {dtype} in infinity().")
67-
return full((), inf, dtype=dtype)
67+
return full(shape, inf, dtype=dtype)
6868

69-
def π(dtype):
69+
def π(shape, dtype):
7070
"""
7171
Returns a scalar π.
7272
@@ -76,24 +76,26 @@ def π(dtype):
7676
7777
"""
7878
if dtype not in [float32, float64]:
79-
raise RuntimeError(f"Unexpected dtype {dtype} in infinity().")
80-
return full((), pi, dtype=dtype)
79+
raise RuntimeError(f"Unexpected dtype {dtype} in π().")
80+
return full(shape, pi, dtype=dtype)
8181

8282
def isnegzero(x):
8383
"""
8484
Returns a mask where x is -0.
8585
"""
8686
# TODO: If copysign or signbit are added to the spec, use those instead.
87+
shape = x.shape
8788
dtype = x.dtype
88-
return equal(divide(one(dtype), x), -infinity(dtype))
89+
return equal(divide(one(shape, dtype), x), -infinity(shape, dtype))
8990

9091
def isposzero(x):
9192
"""
9293
Returns a mask where x is +0 (but not -0).
9394
"""
9495
# TODO: If copysign or signbit are added to the spec, use those instead.
96+
shape = x.shape
9597
dtype = x.dtype
96-
return equal(divide(one(dtype), x), infinity(dtype))
98+
return equal(divide(one(shape, dtype), x), infinity(shape, dtype))
9799

98100
def exactly_equal(x, y):
99101
"""
@@ -147,13 +149,13 @@ def assert_nonzero(x):
147149
assert all(nonzero(x)), "The input array is not nonzero"
148150

149151
def ispositive(x):
150-
return greater(x, zero(x.dtype))
152+
return greater(x, zero(x.shape, x.dtype))
151153

152154
def assert_positive(x):
153155
assert all(ispositive(x)), "The input array is not positive"
154156

155157
def isnegative(x):
156-
return less(x, zero(x.dtype))
158+
return less(x, zero(x.shape, x.dtype))
157159

158160
def assert_negative(x):
159161
assert all(isnegative(x)), "The input array is not negative"
@@ -168,7 +170,7 @@ def isintegral(x):
168170
if x.dtype in [int8, int16, int32, int64, uint8, uint16, uint32, uint64]:
169171
return full(x.shape, True, dtype=bool)
170172
elif x.dtype in [float32, float64]:
171-
return equal(remainder(x, one(x.dtype)), zero(x.dtype))
173+
return equal(remainder(x, one(x.shape, x.dtype)), zero(x.shape, x.dtype))
172174
else:
173175
return full(x.shape, False, dtype=bool)
174176

@@ -179,7 +181,11 @@ def assert_integral(x):
179181
assert all(isintegral(x)), "The input array has nonintegral values"
180182

181183
def isodd(x):
182-
return logical_and(isintegral(x), equal(remainder(x, 2*one(x.dtype)), one(x.dtype)))
184+
return logical_and(
185+
isintegral(x),
186+
equal(
187+
remainder(x, 2*one(x.shape, x.dtype)),
188+
one(x.shape, x.dtype)))
183189

184190
def assert_isinf(x):
185191
"""

array_api_tests/special_cases/test_abs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def test_abs_special_cases_one_arg_equal_1(arg1):
2323
2424
"""
2525
res = abs(arg1)
26-
mask = exactly_equal(arg1, NaN(arg1.dtype))
27-
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
26+
mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
27+
assert_exactly_equal(res[mask], NaN(arg1.shape, arg1.dtype)[mask])
2828

2929

3030
@given(numeric_arrays)
@@ -36,8 +36,8 @@ def test_abs_special_cases_one_arg_equal_2(arg1):
3636
3737
"""
3838
res = abs(arg1)
39-
mask = exactly_equal(arg1, -zero(arg1.dtype))
40-
assert_exactly_equal(res[mask], zero(arg1.dtype)[mask])
39+
mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
40+
assert_exactly_equal(res[mask], zero(arg1.shape, arg1.dtype)[mask])
4141

4242

4343
@given(numeric_arrays)
@@ -49,5 +49,5 @@ def test_abs_special_cases_one_arg_equal_3(arg1):
4949
5050
"""
5151
res = abs(arg1)
52-
mask = exactly_equal(arg1, -infinity(arg1.dtype))
53-
assert_exactly_equal(res[mask], infinity(arg1.dtype)[mask])
52+
mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))
53+
assert_exactly_equal(res[mask], infinity(arg1.shape, arg1.dtype)[mask])

array_api_tests/special_cases/test_acos.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def test_acos_special_cases_one_arg_equal_1(arg1):
2323
2424
"""
2525
res = acos(arg1)
26-
mask = exactly_equal(arg1, NaN(arg1.dtype))
27-
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
26+
mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
27+
assert_exactly_equal(res[mask], NaN(arg1.shape, arg1.dtype)[mask])
2828

2929

3030
@given(numeric_arrays)
@@ -36,8 +36,8 @@ def test_acos_special_cases_one_arg_equal_2(arg1):
3636
3737
"""
3838
res = acos(arg1)
39-
mask = exactly_equal(arg1, one(arg1.dtype))
40-
assert_exactly_equal(res[mask], zero(arg1.dtype)[mask])
39+
mask = exactly_equal(arg1, one(arg1.shape, arg1.dtype))
40+
assert_exactly_equal(res[mask], zero(arg1.shape, arg1.dtype)[mask])
4141

4242

4343
@given(numeric_arrays)
@@ -49,8 +49,8 @@ def test_acos_special_cases_one_arg_greater(arg1):
4949
5050
"""
5151
res = acos(arg1)
52-
mask = greater(arg1, one(arg1.dtype))
53-
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
52+
mask = greater(arg1, one(arg1.shape, arg1.dtype))
53+
assert_exactly_equal(res[mask], NaN(arg1.shape, arg1.dtype)[mask])
5454

5555

5656
@given(numeric_arrays)
@@ -62,5 +62,5 @@ def test_acos_special_cases_one_arg_less(arg1):
6262
6363
"""
6464
res = acos(arg1)
65-
mask = less(arg1, -one(arg1.dtype))
66-
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
65+
mask = less(arg1, -one(arg1.shape, arg1.dtype))
66+
assert_exactly_equal(res[mask], NaN(arg1.shape, arg1.dtype)[mask])

array_api_tests/special_cases/test_acosh.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def test_acosh_special_cases_one_arg_equal_1(arg1):
2323
2424
"""
2525
res = acosh(arg1)
26-
mask = exactly_equal(arg1, NaN(arg1.dtype))
27-
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
26+
mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
27+
assert_exactly_equal(res[mask], NaN(arg1.shape, arg1.dtype)[mask])
2828

2929

3030
@given(numeric_arrays)
@@ -36,8 +36,8 @@ def test_acosh_special_cases_one_arg_equal_2(arg1):
3636
3737
"""
3838
res = acosh(arg1)
39-
mask = exactly_equal(arg1, one(arg1.dtype))
40-
assert_exactly_equal(res[mask], zero(arg1.dtype)[mask])
39+
mask = exactly_equal(arg1, one(arg1.shape, arg1.dtype))
40+
assert_exactly_equal(res[mask], zero(arg1.shape, arg1.dtype)[mask])
4141

4242

4343
@given(numeric_arrays)
@@ -49,8 +49,8 @@ def test_acosh_special_cases_one_arg_equal_3(arg1):
4949
5050
"""
5151
res = acosh(arg1)
52-
mask = exactly_equal(arg1, infinity(arg1.dtype))
53-
assert_exactly_equal(res[mask], infinity(arg1.dtype)[mask])
52+
mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
53+
assert_exactly_equal(res[mask], infinity(arg1.shape, arg1.dtype)[mask])
5454

5555

5656
@given(numeric_arrays)
@@ -62,5 +62,5 @@ def test_acosh_special_cases_one_arg_less(arg1):
6262
6363
"""
6464
res = acosh(arg1)
65-
mask = less(arg1, one(arg1.dtype))
66-
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
65+
mask = less(arg1, one(arg1.shape, arg1.dtype))
66+
assert_exactly_equal(res[mask], NaN(arg1.shape, arg1.dtype)[mask])

0 commit comments

Comments
 (0)