Skip to content

Commit e791115

Browse files
committed
use math.abs for test_sign's refimpl
1 parent 3b3ee2a commit e791115

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def unary_assert_against_refimpl(
267267
f_i = sh.fmt_idx("x", idx)
268268
f_o = sh.fmt_idx("out", idx)
269269
expr = expr_template.format(f_i, expected)
270+
# TODO: strict check floating results too
270271
if strict_check == False or res.dtype in dh.all_float_dtypes:
271272
msg = (
272273
f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
@@ -1383,8 +1384,15 @@ def test_sign(x):
13831384
out = xp.sign(x)
13841385
ph.assert_dtype("sign", in_dtype=x.dtype, out_dtype=out.dtype)
13851386
ph.assert_shape("sign", out_shape=out.shape, expected=x.shape)
1386-
refimpl = lambda x: x / xp.abs(x) if x != 0 else 0
1387-
unary_assert_against_refimpl("sign", x, out, refimpl, filter_=lambda s: s != 0)
1387+
refimpl = lambda x: x / math.abs(x) if x != 0 else 0
1388+
unary_assert_against_refimpl(
1389+
"sign",
1390+
x,
1391+
out,
1392+
refimpl,
1393+
filter_=lambda s: s != 0,
1394+
strict_check=True,
1395+
)
13881396

13891397

13901398
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))

0 commit comments

Comments
 (0)