Skip to content

Commit 7f5b687

Browse files
committed
Add true() and false() helpers
1 parent babbb9c commit 7f5b687

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

array_api_tests/array_helpers.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
'assert_negative_mathematical_sign', 'same_sign',
2828
'assert_same_sign', 'ndindex', 'promote_dtypes', 'float64',
2929
'asarray', 'is_integer_dtype', 'is_float_dtype', 'dtype_ranges',
30-
'full']
30+
'full', 'true', 'false']
3131

3232
def zero(shape, dtype):
3333
"""
34-
Returns a scalar 0 of the given dtype.
34+
Returns a full 0 array of the given dtype.
3535
3636
This should be used in place of the literal "0" in the test suite, as the
3737
spec does not require any behavior with Python literals (and in
@@ -45,7 +45,7 @@ def zero(shape, dtype):
4545

4646
def one(shape, dtype):
4747
"""
48-
Returns a scalar 1 of the given dtype.
48+
Returns a full 1 array of the given dtype.
4949
5050
This should be used in place of the literal "1" in the test suite, as the
5151
spec does not require any behavior with Python literals (and in
@@ -58,7 +58,7 @@ def one(shape, dtype):
5858

5959
def NaN(shape, dtype):
6060
"""
61-
Returns a scalar nan of the given dtype.
61+
Returns a full nan array of the given dtype.
6262
6363
Note that this is only defined for floating point dtypes.
6464
"""
@@ -68,7 +68,7 @@ def NaN(shape, dtype):
6868

6969
def infinity(shape, dtype):
7070
"""
71-
Returns a scalar positive infinity of the given dtype.
71+
Returns a full positive infinity array of the given dtype.
7272
7373
Note that this is only defined for floating point dtypes.
7474
@@ -81,7 +81,7 @@ def infinity(shape, dtype):
8181

8282
def π(shape, dtype):
8383
"""
84-
Returns a scalar π.
84+
Returns a full π array of the given dtype.
8585
8686
Note that this function is only defined for floating point dtype.
8787
@@ -92,6 +92,18 @@ def π(shape, dtype):
9292
raise RuntimeError(f"Unexpected dtype {dtype} in π().")
9393
return full(shape, pi, dtype=dtype)
9494

95+
def true(shape):
96+
"""
97+
Returns a full True array with dtype=bool.
98+
"""
99+
return full(shape, True, dtype=bool)
100+
101+
def false(shape):
102+
"""
103+
Returns a full False array with dtype=bool.
104+
"""
105+
return full(shape, False, dtype=bool)
106+
95107
def isnegzero(x):
96108
"""
97109
Returns a mask where x is -0.

0 commit comments

Comments
 (0)