Skip to content

Commit bfdad3f

Browse files
committed
Complex-related updates
* Update test_sign for complex inputs * Use result dtype for res_type inference in assert-against-refimpl utils * Test `xp.real()` and `xp.imag()`
1 parent 0208b1f commit bfdad3f

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def unary_assert_against_refimpl(
262262
expr_template = func_name + "({})={}"
263263
in_stype = dh.get_scalar_type(in_.dtype)
264264
if res_stype is None:
265-
res_stype = in_stype
265+
res_stype = dh.get_scalar_type(res.dtype)
266266
if res.dtype == xp.bool:
267267
m, M = (None, None)
268268
elif res.dtype in dh.complex_dtypes:
@@ -334,7 +334,7 @@ def binary_assert_against_refimpl(
334334
expr_template = func_name + "({}, {})={}"
335335
in_stype = dh.get_scalar_type(left.dtype)
336336
if res_stype is None:
337-
res_stype = in_stype
337+
res_stype = dh.get_scalar_type(left.dtype)
338338
if res_stype is None:
339339
res_stype = in_stype
340340
if res.dtype == xp.bool:
@@ -412,7 +412,7 @@ def right_scalar_assert_against_refimpl(
412412
return # short-circuit here as there will be nothing to test
413413
in_stype = dh.get_scalar_type(left.dtype)
414414
if res_stype is None:
415-
res_stype = in_stype
415+
res_stype = dh.get_scalar_type(left.dtype)
416416
if res_stype is None:
417417
res_stype = in_stype
418418
if res.dtype == xp.bool:
@@ -1100,6 +1100,14 @@ def test_greater_equal(ctx, data):
11001100
)
11011101

11021102

1103+
@given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
1104+
def test_imag(x):
1105+
out = xp.imag(x)
1106+
ph.assert_dtype("imag", x.dtype, out.dtype, dh.dtype_components[x.dtype])
1107+
ph.assert_shape("imag", out.shape, x.shape)
1108+
unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag"))
1109+
1110+
11031111
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
11041112
def test_isfinite(x):
11051113
out = xp.isfinite(x)
@@ -1341,6 +1349,14 @@ def test_pow(ctx, data):
13411349
# Values testing pow is too finicky
13421350

13431351

1352+
@given(xps.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
1353+
def test_real(x):
1354+
out = xp.real(x)
1355+
ph.assert_dtype("real", x.dtype, out.dtype, dh.dtype_components[x.dtype])
1356+
ph.assert_shape("real", out.shape, x.shape)
1357+
unary_assert_against_refimpl("real", x, out, operator.attrgetter("real"))
1358+
1359+
13441360
@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.real_dtypes))
13451361
@given(data=st.data())
13461362
def test_remainder(ctx, data):
@@ -1366,8 +1382,7 @@ def test_round(x):
13661382
unary_assert_against_refimpl("round", x, out, round, strict_check=True)
13671383

13681384

1369-
# TODO: https://github.com/data-apis/array-api/issues/545
1370-
@given(xps.arrays(dtype=xps.real_dtypes(), shape=hh.shapes(), elements=finite_kw))
1385+
@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(), elements=finite_kw))
13711386
def test_sign(x):
13721387
out = xp.sign(x)
13731388
ph.assert_dtype("sign", x.dtype, out.dtype)

0 commit comments

Comments
 (0)