Skip to content

Commit 873bc9c

Browse files
committed
Fix some logic for pow special case tests
This also includes a fixed special case from data-apis/array-api#108.
1 parent bc330ed commit 873bc9c

File tree

2 files changed

+34
-32
lines changed

2 files changed

+34
-32
lines changed

array_api_tests/special_cases/test_pow.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"""
99

1010
from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, greater, infinity, isfinite,
11-
isintegral, isodd, less, logical_and, logical_not, one, zero)
11+
isintegral, isodd, less, logical_and, logical_not, notequal, one, zero)
1212
from ..hypothesis_helpers import numeric_arrays
1313
from .._array_module import pow
1414

@@ -37,7 +37,7 @@ def test_pow_special_cases_two_args_even_if_1(arg1, arg2):
3737
3838
"""
3939
res = pow(arg1, arg2)
40-
mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
40+
mask = exactly_equal(arg2, zero(arg2.shape, arg2.dtype))
4141
assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
4242

4343

@@ -50,7 +50,7 @@ def test_pow_special_cases_two_args_even_if_2(arg1, arg2):
5050
5151
"""
5252
res = pow(arg1, arg2)
53-
mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
53+
mask = exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))
5454
assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
5555

5656

@@ -63,7 +63,7 @@ def test_pow_special_cases_two_args_equal__notequal_1(arg1, arg2):
6363
6464
"""
6565
res = pow(arg1, arg2)
66-
mask = logical_and(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), logical_not(exactly_equal(arg2, zero(arg2.shape, arg2.dtype))))
66+
mask = logical_and(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), notequal(arg2, zero(arg2.shape, arg2.dtype)))
6767
assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
6868

6969

@@ -173,19 +173,6 @@ def test_pow_special_cases_two_args_equal__greater_1(arg1, arg2):
173173

174174
@given(numeric_arrays, numeric_arrays)
175175
def test_pow_special_cases_two_args_equal__greater_2(arg1, arg2):
176-
"""
177-
Special case test for `pow(x1, x2)`:
178-
179-
- If `x1_i` is `-infinity` and `x2_i` is greater than `0`, the result is `-infinity`.
180-
181-
"""
182-
res = pow(arg1, arg2)
183-
mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype)))
184-
assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
185-
186-
187-
@given(numeric_arrays, numeric_arrays)
188-
def test_pow_special_cases_two_args_equal__greater_3(arg1, arg2):
189176
"""
190177
Special case test for `pow(x1, x2)`:
191178
@@ -223,6 +210,32 @@ def test_pow_special_cases_two_args_equal__less_2(arg1, arg2):
223210
assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
224211

225212

213+
@given(numeric_arrays, numeric_arrays)
214+
def test_pow_special_cases_two_args_equal__greater_equal_1(arg1, arg2):
215+
"""
216+
Special case test for `pow(x1, x2)`:
217+
218+
- If `x1_i` is `-infinity`, `x2_i` is greater than `0`, and `x2_i` is an odd integer value, the result is `-infinity`.
219+
220+
"""
221+
res = pow(arg1, arg2)
222+
mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2)))
223+
assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
224+
225+
226+
@given(numeric_arrays, numeric_arrays)
227+
def test_pow_special_cases_two_args_equal__greater_equal_2(arg1, arg2):
228+
"""
229+
Special case test for `pow(x1, x2)`:
230+
231+
- If `x1_i` is `-0`, `x2_i` is greater than `0`, and `x2_i` is an odd integer value, the result is `-0`.
232+
233+
"""
234+
res = pow(arg1, arg2)
235+
mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2)))
236+
assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
237+
238+
226239
@given(numeric_arrays, numeric_arrays)
227240
def test_pow_special_cases_two_args_equal__greater_notequal_1(arg1, arg2):
228241
"""
@@ -301,19 +314,6 @@ def test_pow_special_cases_two_args_equal__less_notequal_2(arg1, arg2):
301314
assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
302315

303316

304-
@given(numeric_arrays, numeric_arrays)
305-
def test_pow_special_cases_two_args_equal__greater_equal(arg1, arg2):
306-
"""
307-
Special case test for `pow(x1, x2)`:
308-
309-
- If `x1_i` is `-0`, `x2_i` is greater than `0`, and `x2_i` is an odd integer value, the result is `-0`.
310-
311-
"""
312-
res = pow(arg1, arg2)
313-
mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2)))
314-
assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
315-
316-
317317
@given(numeric_arrays, numeric_arrays)
318318
def test_pow_special_cases_two_args_less_equal__equal_notequal(arg1, arg2):
319319
"""

generate_stubs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,8 @@ def _check_exactly_equal(typ, value):
268268

269269
def get_mask(typ, arg, value):
270270
if typ.startswith("not"):
271+
if value.startswith('zero('):
272+
return f"notequal({arg}, {value})"
271273
return f"logical_not({get_mask(typ[len('not'):], arg, value)})"
272274
if typ.startswith("abs"):
273275
return get_mask(typ[len("abs"):], f"abs({arg})", value)
@@ -517,8 +519,8 @@ def generate_special_case_test(func, typ, m, test_name_extra, sigs):
517519
assertion = get_assert("exactly_equal", result)
518520
elif typ == "TWO_ARGS_EVEN_IF":
519521
value1, result, value2 = m.groups()
520-
value1 = parse_value(value1, "arg1")
521-
mask = get_mask("exactly_equal", "arg1", value1)
522+
value1 = parse_value(value1, "arg2")
523+
mask = get_mask("exactly_equal", "arg2", value1)
522524
assertion = get_assert("exactly_equal", result)
523525
else:
524526
raise ValueError(f"Unrecognized special value type {typ}")

0 commit comments

Comments
 (0)