Skip to content

Commit ae0017a

Browse files
authored
Merge pull request #227 from mtsokol/update-test-sign
Update `test_sign` for complex inputs
2 parents 8d4ccd6 + e791115 commit ae0017a

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ venv.bak/
117117
# Rope project settings
118118
.ropeproject
119119

120+
# IDE
121+
.idea/
122+
.vscode/
123+
120124
# mkdocs documentation
121125
/site
122126

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ def unary_assert_against_refimpl(
267267
f_i = sh.fmt_idx("x", idx)
268268
f_o = sh.fmt_idx("out", idx)
269269
expr = expr_template.format(f_i, expected)
270+
# TODO: strict check floating results too
270271
if strict_check == False or res.dtype in dh.all_float_dtypes:
271272
msg = (
272273
f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
@@ -1383,8 +1384,14 @@ def test_sign(x):
13831384
out = xp.sign(x)
13841385
ph.assert_dtype("sign", in_dtype=x.dtype, out_dtype=out.dtype)
13851386
ph.assert_shape("sign", out_shape=out.shape, expected=x.shape)
1387+
refimpl = lambda x: x / math.abs(x) if x != 0 else 0
13861388
unary_assert_against_refimpl(
1387-
"sign", x, out, lambda s: math.copysign(1, s), filter_=lambda s: s != 0
1389+
"sign",
1390+
x,
1391+
out,
1392+
refimpl,
1393+
filter_=lambda s: s != 0,
1394+
strict_check=True,
13881395
)
13891396

13901397

0 commit comments

Comments
 (0)