Skip to content

Commit e10f45e

Browse files
committed
Fix usage of "mathematical sign" in the special cases tests
It refers to the sign bit.
1 parent 31d07a8 commit e10f45e

File tree

4 files changed

+62
-21
lines changed

4 files changed

+62
-21
lines changed

array_api_tests/array_helpers.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
'isposzero', 'exactly_equal', 'assert_exactly_equal',
1515
'assert_finite', 'assert_non_zero', 'ispositive',
1616
'assert_positive', 'isnegative', 'assert_negative', 'isintegral',
17-
'assert_integral', 'isodd', 'assert_isinf', 'same_sign',
18-
'assert_same_sign']
17+
'assert_integral', 'isodd', 'assert_isinf',
18+
'positive_mathematical_sign', 'assert_positive_mathematical_sign',
19+
'negative_mathematical_sign', 'assert_negative_mathematical_sign',
20+
'same_sign', 'assert_same_sign']
1921

2022
def zero(shape, dtype):
2123
"""
@@ -194,6 +196,34 @@ def assert_isinf(x):
194196
"""
195197
assert all(isinf(x)), "The input array is not infinite"
196198

199+
def positive_mathematical_sign(x):
200+
"""
201+
Check if x has a positive "mathematical sign"
202+
203+
The "mathematical sign" here means the sign bit is 0. This includes 0,
204+
positive finite numbers, and positive infinity. It does not include any
205+
nans, as signed nans are not required by the spec.
206+
207+
"""
208+
return logical_or(greater(x, 0), isposzero(x))
209+
210+
def assert_positive_mathematical_sign(x):
211+
assert all(positive_mathematical_sign(x)), "The input arrays do not have a positive mathematical sign"
212+
213+
def negative_mathematical_sign(x):
214+
"""
215+
Check if x has a negative "mathematical sign"
216+
217+
The "mathematical sign" here means the sign bit is 1. This includes -0,
218+
negative finite numbers, and negative infinity. It does not include any
219+
nans, as signed nans are not required by the spec.
220+
221+
"""
222+
return logical_or(less(x, 0), isnegzero(x))
223+
224+
def assert_negative_mathematical_sign(x):
225+
assert all(negative_mathematical_sign(x)), "The input arrays do not have a negative mathematical sign"
226+
197227
def same_sign(x, y):
198228
"""
199229
Check if x and y have the "same sign"
@@ -203,13 +233,9 @@ def same_sign(x, y):
203233
have the same sign. The value of this function is False if either x or y
204234
is nan, as signed nans are not required by the spec.
205235
"""
206-
logical_or(
207-
logical_and(
208-
logical_or(greater(x, 0), isposzero(x)),
209-
logical_or(greater(y, 0), isposzero(y))),
210-
logical_and(
211-
logical_or(less(x, 0), isnegzero(x)),
212-
logical_or(less(y, 0), isnegzero(y))))
236+
return logical_or(
237+
logical_and(positive_mathematical_sign(x), positive_mathematical_sign(y)),
238+
logical_and(negative_mathematical_sign(x), negative_mathematical_sign(y)))
213239

214240
def assert_same_sign(x, y):
215241
assert all(same_sign(x, y)), "The input arrays do not have the same sign"

array_api_tests/special_cases/test_divide.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
not modify it directly.
88
"""
99

10-
from ..array_helpers import (NaN, assert_exactly_equal, assert_negative, assert_positive,
11-
exactly_equal, greater, infinity, isfinite, isnegative, ispositive,
12-
less, logical_and, logical_not, logical_or, non_zero, same_sign, zero)
10+
from ..array_helpers import (NaN, assert_exactly_equal, assert_negative_mathematical_sign,
11+
assert_positive_mathematical_sign, exactly_equal, greater, infinity,
12+
isfinite, isnegative, ispositive, less, logical_and, logical_not,
13+
logical_or, non_zero, same_sign, zero)
1314
from ..hypothesis_helpers import numeric_arrays
1415
from .._array_module import divide
1516

@@ -273,7 +274,7 @@ def test_divide_special_cases_two_args_same_sign_both(arg1, arg2):
273274
"""
274275
res = divide(arg1, arg2)
275276
mask = logical_and(same_sign(arg1, arg2), logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_and(isfinite(arg2), non_zero(arg2))))
276-
assert_positive(res[mask])
277+
assert_positive_mathematical_sign(res[mask])
277278

278279

279280
@given(numeric_arrays, numeric_arrays)
@@ -286,7 +287,7 @@ def test_divide_special_cases_two_args_different_signs_both(arg1, arg2):
286287
"""
287288
res = divide(arg1, arg2)
288289
mask = logical_and(logical_not(same_sign(arg1, arg2)), logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_and(isfinite(arg2), non_zero(arg2))))
289-
assert_negative(res[mask])
290+
assert_negative_mathematical_sign(res[mask])
290291

291292
# TODO: Implement REMAINING test for:
292293
# - In the remaining cases, where neither `-infinity`, `+0`, `-0`, nor `NaN` is involved, the quotient must be computed and rounded to the nearest representable value according to IEEE 754-2019 and a supported rounding mode. If the magnitude is too larger to represent, the operation overflows and the result is an `infinity` of appropriate mathematical sign. If the magnitude is too small to represent, the operation underflows and the result is a zero of appropriate mathematical sign.

array_api_tests/special_cases/test_multiply.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
not modify it directly.
88
"""
99

10-
from ..array_helpers import (NaN, assert_exactly_equal, assert_isinf, assert_negative,
11-
assert_positive, exactly_equal, infinity, isfinite, logical_and,
12-
logical_not, logical_or, non_zero, same_sign, zero)
10+
from ..array_helpers import (NaN, assert_exactly_equal, assert_isinf,
11+
assert_negative_mathematical_sign, assert_positive_mathematical_sign,
12+
exactly_equal, infinity, isfinite, logical_and, logical_not,
13+
logical_or, non_zero, same_sign, zero)
1314
from ..hypothesis_helpers import numeric_arrays
1415
from .._array_module import multiply
1516

@@ -39,7 +40,7 @@ def test_multiply_special_cases_two_args_same_sign(arg1, arg2):
3940
"""
4041
res = multiply(arg1, arg2)
4142
mask = same_sign(arg1, arg2)
42-
assert_positive(res[mask])
43+
assert_positive_mathematical_sign(res[mask])
4344

4445

4546
@given(numeric_arrays, numeric_arrays)
@@ -52,7 +53,7 @@ def test_multiply_special_cases_two_args_different_signs(arg1, arg2):
5253
"""
5354
res = multiply(arg1, arg2)
5455
mask = logical_not(same_sign(arg1, arg2))
55-
assert_negative(res[mask])
56+
assert_negative_mathematical_sign(res[mask])
5657

5758

5859
@given(numeric_arrays, numeric_arrays)

generate_stubs.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def {sig}:{doc}
182182
# (?|...) is a branch reset (regex module only feature). It works like (?:...)
183183
# except only the matched alternative is assigned group numbers, so \1, \2, and
184184
# so on will always refer to a single match from _value.
185-
_value = r"(?|`([^`]*)`|a (finite) number|a (positive \(i\.e\., greater than `0`\) finite) number|a (negative \(i\.e\., less than `0`\) finite) number|(finite)|(positive)|(negative)|(nonzero)|(?:a )?(nonzero finite) numbers?|an (integer) value|already (integer)-valued|an (odd integer) value|an implementation-dependent approximation to `([^`]*)`(?: \(rounded\))?|a (signed (?:infinity|zero)) with the mathematical sign determined by the rule already stated above|(positive) mathematical sign|(negative) mathematical sign)"
185+
_value = r"(?|`([^`]*)`|a (finite) number|a (positive \(i\.e\., greater than `0`\) finite) number|a (negative \(i\.e\., less than `0`\) finite) number|(finite)|(positive)|(negative)|(nonzero)|(?:a )?(nonzero finite) numbers?|an (integer) value|already (integer)-valued|an (odd integer) value|an implementation-dependent approximation to `([^`]*)`(?: \(rounded\))?|a (signed (?:infinity|zero)) with the mathematical sign determined by the rule already stated above|(positive mathematical sign)|(negative mathematical sign))"
186186
SPECIAL_CASE_REGEXS = dict(
187187
ONE_ARG_EQUAL = regex.compile(rf'^- +If `x_i` is {_value}, the result is {_value}\.$'),
188188
ONE_ARG_GREATER = regex.compile(rf'^- +If `x_i` is greater than {_value}, the result is {_value}\.$'),
@@ -248,7 +248,8 @@ def parse_value(value, arg):
248248
return value
249249
elif value in ['finite', 'nonzero', 'nonzero finite',
250250
"integer", "odd integer", "positive",
251-
"negative"]:
251+
"negative", "positive mathematical sign",
252+
"negative mathematical sign"]:
252253
return value
253254
# There's no way to remove the parenthetical from the matching group in
254255
# the regular expression.
@@ -287,9 +288,15 @@ def get_mask(typ, arg, value):
287288
elif value == 'positive':
288289
_check_exactly_equal(typ, value)
289290
return f"ispositive({arg})"
291+
elif value == 'positive mathematical sign':
292+
_check_exactly_equal(typ, value)
293+
return f"positive_mathematical_sign({arg})"
290294
elif value == 'negative':
291295
_check_exactly_equal(typ, value)
292296
return f"isnegative({arg})"
297+
elif value == 'negative mathematical sign':
298+
_check_exactly_equal(typ, value)
299+
return f"negative_mathematical_sign({arg})"
293300
elif value == 'integer':
294301
_check_exactly_equal(typ, value)
295302
return f"isintegral({arg})"
@@ -311,9 +318,15 @@ def get_assert(typ, result):
311318
elif result == "positive":
312319
_check_exactly_equal(typ, result)
313320
return "assert_positive(res[mask])"
321+
elif result == "positive mathematical sign":
322+
_check_exactly_equal(typ, result)
323+
return "assert_positive_mathematical_sign(res[mask])"
314324
elif result == "negative":
315325
_check_exactly_equal(typ, result)
316326
return "assert_negative(res[mask])"
327+
elif result == "negative mathematical sign":
328+
_check_exactly_equal(typ, result)
329+
return "assert_negative_mathematical_sign(res[mask])"
317330
elif 'x_i' in result:
318331
return f"assert_{typ}(res[mask], {result.replace('x_i', 'arg1')}[mask])"
319332
elif 'x1_i' in result:

0 commit comments

Comments
 (0)