@@ -262,7 +262,7 @@ def unary_assert_against_refimpl(
262
262
expr_template = func_name + "({})={}"
263
263
in_stype = dh .get_scalar_type (in_ .dtype )
264
264
if res_stype is None :
265
- res_stype = in_stype
265
+ res_stype = dh . get_scalar_type ( res . dtype )
266
266
if res .dtype == xp .bool :
267
267
m , M = (None , None )
268
268
elif res .dtype in dh .complex_dtypes :
@@ -334,7 +334,7 @@ def binary_assert_against_refimpl(
334
334
expr_template = func_name + "({}, {})={}"
335
335
in_stype = dh .get_scalar_type (left .dtype )
336
336
if res_stype is None :
337
- res_stype = in_stype
337
+ res_stype = dh . get_scalar_type ( left . dtype )
338
338
if res_stype is None :
339
339
res_stype = in_stype
340
340
if res .dtype == xp .bool :
@@ -412,7 +412,7 @@ def right_scalar_assert_against_refimpl(
412
412
return # short-circuit here as there will be nothing to test
413
413
in_stype = dh .get_scalar_type (left .dtype )
414
414
if res_stype is None :
415
- res_stype = in_stype
415
+ res_stype = dh . get_scalar_type ( left . dtype )
416
416
if res_stype is None :
417
417
res_stype = in_stype
418
418
if res .dtype == xp .bool :
@@ -1100,6 +1100,14 @@ def test_greater_equal(ctx, data):
1100
1100
)
1101
1101
1102
1102
1103
+ @given (xps .arrays (dtype = xps .complex_dtypes (), shape = hh .shapes ()))
1104
+ def test_imag (x ):
1105
+ out = xp .imag (x )
1106
+ ph .assert_dtype ("imag" , x .dtype , out .dtype , dh .dtype_components [x .dtype ])
1107
+ ph .assert_shape ("imag" , out .shape , x .shape )
1108
+ unary_assert_against_refimpl ("imag" , x , out , operator .attrgetter ("imag" ))
1109
+
1110
+
1103
1111
@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
1104
1112
def test_isfinite (x ):
1105
1113
out = xp .isfinite (x )
@@ -1341,6 +1349,14 @@ def test_pow(ctx, data):
1341
1349
# Values testing pow is too finicky
1342
1350
1343
1351
1352
+ @given (xps .arrays (dtype = xps .complex_dtypes (), shape = hh .shapes ()))
1353
+ def test_real (x ):
1354
+ out = xp .real (x )
1355
+ ph .assert_dtype ("real" , x .dtype , out .dtype , dh .dtype_components [x .dtype ])
1356
+ ph .assert_shape ("real" , out .shape , x .shape )
1357
+ unary_assert_against_refimpl ("real" , x , out , operator .attrgetter ("real" ))
1358
+
1359
+
1344
1360
@pytest .mark .parametrize ("ctx" , make_binary_params ("remainder" , dh .real_dtypes ))
1345
1361
@given (data = st .data ())
1346
1362
def test_remainder (ctx , data ):
@@ -1366,8 +1382,7 @@ def test_round(x):
1366
1382
unary_assert_against_refimpl ("round" , x , out , round , strict_check = True )
1367
1383
1368
1384
1369
- # TODO: https://github.com/data-apis/array-api/issues/545
1370
- @given (xps .arrays (dtype = xps .real_dtypes (), shape = hh .shapes (), elements = finite_kw ))
1385
+ @given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (), elements = finite_kw ))
1371
1386
def test_sign (x ):
1372
1387
out = xp .sign (x )
1373
1388
ph .assert_dtype ("sign" , x .dtype , out .dtype )
0 commit comments