Skip to content

Commit 3b3ee2a

Browse files
committed
Update test_sign for complex inputs
1 parent 8d4ccd6 commit 3b3ee2a

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ venv.bak/
117117
# Rope project settings
118118
.ropeproject
119119

120+
# IDE
121+
.idea/
122+
.vscode/
123+
120124
# mkdocs documentation
121125
/site
122126

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,9 +1383,8 @@ def test_sign(x):
13831383
out = xp.sign(x)
13841384
ph.assert_dtype("sign", in_dtype=x.dtype, out_dtype=out.dtype)
13851385
ph.assert_shape("sign", out_shape=out.shape, expected=x.shape)
1386-
unary_assert_against_refimpl(
1387-
"sign", x, out, lambda s: math.copysign(1, s), filter_=lambda s: s != 0
1388-
)
1386+
refimpl = lambda x: x / xp.abs(x) if x != 0 else 0
1387+
unary_assert_against_refimpl("sign", x, out, refimpl, filter_=lambda s: s != 0)
13891388

13901389

13911390
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))

0 commit comments

Comments
 (0)