Skip to content

Commit d7305d0

Browse files
committed
Fix the special cases for logaddexp
data-apis/array-api#239
1 parent 449ddfb commit d7305d0

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

array_api_tests/special_cases/test_logaddexp.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,20 @@
77
not modify it directly.
88
"""
99

10-
from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, logical_or
10+
from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, infinity, logical_and,
11+
logical_not, logical_or)
1112
from ..hypothesis_helpers import numeric_arrays
1213
from .._array_module import logaddexp
1314

1415
from hypothesis import given
1516

1617

1718
@given(numeric_arrays, numeric_arrays)
18-
def test_logaddexp_special_cases_two_args_either_1(arg1, arg2):
19+
def test_logaddexp_special_cases_two_args_either(arg1, arg2):
1920
"""
2021
Special case test for `logaddexp(x1, x2)`:
2122
22-
- If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`.
23+
- If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`.
2324
2425
"""
2526
res = logaddexp(arg1, arg2)
@@ -28,13 +29,26 @@ def test_logaddexp_special_cases_two_args_either_1(arg1, arg2):
2829

2930

3031
@given(numeric_arrays, numeric_arrays)
31-
def test_logaddexp_special_cases_two_args_either_2(arg1, arg2):
32+
def test_logaddexp_special_cases_two_args_equal__notequal(arg1, arg2):
3233
"""
3334
Special case test for `logaddexp(x1, x2)`:
3435
35-
- If either `x1_i` or `x2_i` is `+infinity`, the result is `+infinity`.
36+
- If `x1_i` is `+infinity` and `x2_i` is not `NaN`, the result is `+infinity`.
3637
3738
"""
3839
res = logaddexp(arg1, arg2)
39-
mask = logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg1.shape, arg1.dtype)))
40+
mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), logical_not(exactly_equal(arg2, NaN(arg2.shape, arg2.dtype))))
41+
assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
42+
43+
44+
@given(numeric_arrays, numeric_arrays)
45+
def test_logaddexp_special_cases_two_args_notequal__equal(arg1, arg2):
46+
"""
47+
Special case test for `logaddexp(x1, x2)`:
48+
49+
- If `x1_i` is not `NaN` and `x2_i` is `+infinity`, the result is `+infinity`.
50+
51+
"""
52+
res = logaddexp(arg1, arg2)
53+
mask = logical_and(logical_not(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)))
4054
assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])

generate_stubs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def {annotated_sig}:{doc}
317317
TWO_ARGS_EQUAL__LESS_NOTEQUAL = regex.compile(rf'^- +If `x1_i` is {_value}, `x2_i` is less than {_value}, and `x2_i` is not {_value}, the result is {_value}\.$'),
318318
TWO_ARGS_EQUAL__GREATER_EQUAL = regex.compile(rf'^- +If `x1_i` is {_value}, `x2_i` is greater than {_value}, and `x2_i` is {_value}, the result is {_value}\.$'),
319319
TWO_ARGS_EQUAL__GREATER_NOTEQUAL = regex.compile(rf'^- +If `x1_i` is {_value}, `x2_i` is greater than {_value}, and `x2_i` is not {_value}, the result is {_value}\.$'),
320-
TWO_ARGS_NOTEQUAL__EQUAL = regex.compile(rf'^- +If `x1_i` is not equal to {_value} and `x2_i` is {_value}, the result is {_value}\.$'),
320+
TWO_ARGS_NOTEQUAL__EQUAL = regex.compile(rf'^- +If `x1_i` is not (?:equal to )?{_value} and `x2_i` is {_value}, the result is {_value}\.$'),
321321
TWO_ARGS_ABSEQUAL__EQUAL = regex.compile(rf'^- +If `abs\(x1_i\)` is {_value} and `x2_i` is {_value}, the result is {_value}\.$'),
322322
TWO_ARGS_ABSGREATER__EQUAL = regex.compile(rf'^- +If `abs\(x1_i\)` is greater than {_value} and `x2_i` is {_value}, the result is {_value}\.$'),
323323
TWO_ARGS_ABSLESS__EQUAL = regex.compile(rf'^- +If `abs\(x1_i\)` is less than {_value} and `x2_i` is {_value}, the result is {_value}\.$'),

0 commit comments

Comments
 (0)