Skip to content

Commit d187759

Browse files
committed
Support "same sign except" special cases
See data-apis/array-api#92.
1 parent 8b74423 commit d187759

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

array_api_tests/special_cases/test_multiply.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,28 +31,28 @@ def test_multiply_special_cases_two_args_either(arg1, arg2):
3131

3232

3333
@given(numeric_arrays, numeric_arrays)
34-
def test_multiply_special_cases_two_args_same_sign(arg1, arg2):
34+
def test_multiply_special_cases_two_args_same_sign_except(arg1, arg2):
3535
"""
3636
Special case test for `multiply(x1, x2)`:
3737
38-
- If `x1_i` and `x2_i` have the same mathematical sign, the result has a positive mathematical sign.
38+
- If `x1_i` and `x2_i` have the same mathematical sign, the result has a positive mathematical sign, except where it is `NaN` as in the rules below.
3939
4040
"""
4141
res = multiply(arg1, arg2)
42-
mask = same_sign(arg1, arg2)
42+
mask = logical_and(same_sign(arg1, arg2), logical_not(exactly_equal(res, NaN(res.shape, res.dtype))))
4343
assert_positive_mathematical_sign(res[mask])
4444

4545

4646
@given(numeric_arrays, numeric_arrays)
47-
def test_multiply_special_cases_two_args_different_signs(arg1, arg2):
47+
def test_multiply_special_cases_two_args_different_signs_except(arg1, arg2):
4848
"""
4949
Special case test for `multiply(x1, x2)`:
5050
51-
- If `x1_i` and `x2_i` have different mathematical signs, the result has a negative mathematical sign.
51+
- If `x1_i` and `x2_i` have different mathematical signs, the result has a negative mathematical sign, except where it is `NaN` as in the rules below.
5252
5353
"""
5454
res = multiply(arg1, arg2)
55-
mask = logical_not(same_sign(arg1, arg2))
55+
mask = logical_and(logical_not(same_sign(arg1, arg2)), logical_not(exactly_equal(res, NaN(res.shape, res.dtype))))
5656
assert_negative_mathematical_sign(res[mask])
5757

5858

generate_stubs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,10 @@ def {sig}:{doc}
212212
TWO_ARGS_EQUAL__EITHER = regex.compile(rf'^- +If `x1_i` is {_value} and `x2_i` is either {_value} or {_value}, the result is {_value}\.$'),
213213
TWO_ARGS_EITHER__EITHER = regex.compile(rf'^- +If `x1_i` is either {_value} or {_value} and `x2_i` is either {_value} or {_value}, the result is {_value}\.$'),
214214
TWO_ARGS_SAME_SIGN = regex.compile(rf'^- +If `x1_i` and `x2_i` have the same mathematical sign, the result has a {_value}\.$'),
215+
TWO_ARGS_SAME_SIGN_EXCEPT = regex.compile(rf'^- +If `x1_i` and `x2_i` have the same mathematical sign, the result has a {_value}, except where it is {_value} as in the rules below\.$'),
215216
TWO_ARGS_SAME_SIGN_BOTH = regex.compile(rf'^- +If `x1_i` and `x2_i` have the same mathematical sign and are both {_value}, the result has a {_value}\.$'),
216217
TWO_ARGS_DIFFERENT_SIGNS = regex.compile(rf'^- +If `x1_i` and `x2_i` have different mathematical signs, the result has a {_value}\.$'),
218+
TWO_ARGS_DIFFERENT_SIGNS_EXCEPT = regex.compile(rf'^- +If `x1_i` and `x2_i` have different mathematical signs, the result has a {_value}, except where it is {_value} as in the rules below\.$'),
217219
TWO_ARGS_DIFFERENT_SIGNS_BOTH = regex.compile(rf'^- +If `x1_i` and `x2_i` have different mathematical signs and are both {_value}, the result has a {_value}\.$'),
218220
TWO_ARGS_EVEN_IF = regex.compile(rf'^- +If `x2_i` is {_value}, the result is {_value}, even if `x1_i` is {_value}\.$'),
219221

@@ -483,6 +485,11 @@ def generate_special_case_test(func, typ, m, test_name_extra, sigs):
483485
result, = m.groups()
484486
mask = "same_sign(arg1, arg2)"
485487
assertion = get_assert("exactly_equal", result)
488+
elif typ == "TWO_ARGS_SAME_SIGN_EXCEPT":
489+
result, value = m.groups()
490+
value = parse_value(value, "res")
491+
mask = f"logical_and(same_sign(arg1, arg2), logical_not(exactly_equal(res, {value})))"
492+
assertion = get_assert("exactly_equal", result)
486493
elif typ == "TWO_ARGS_SAME_SIGN_BOTH":
487494
value, result = m.groups()
488495
mask1 = get_mask("exactly_equal", "arg1", value)
@@ -493,6 +500,11 @@ def generate_special_case_test(func, typ, m, test_name_extra, sigs):
493500
result, = m.groups()
494501
mask = "logical_not(same_sign(arg1, arg2))"
495502
assertion = get_assert("exactly_equal", result)
503+
elif typ == "TWO_ARGS_DIFFERENT_SIGNS_EXCEPT":
504+
result, value = m.groups()
505+
value = parse_value(value, "res")
506+
mask = f"logical_and(logical_not(same_sign(arg1, arg2)), logical_not(exactly_equal(res, {value})))"
507+
assertion = get_assert("exactly_equal", result)
496508
elif typ == "TWO_ARGS_DIFFERENT_SIGNS_BOTH":
497509
value, result = m.groups()
498510
mask1 = get_mask("exactly_equal", "arg1", value)

0 commit comments

Comments
 (0)